230 lines
6.2 KiB
Go
230 lines
6.2 KiB
Go
package cas
|
||
|
||
import (
	"context"
	"crypto/tls"
	"encoding/xml"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/chaitin/panda-wiki/log"
)
|
||
|
||
// Client talks to a CAS server: GetLoginURL builds the login redirect and
// ValidateTicket verifies a service ticket. Construct it with NewClient.
type Client struct {
	logger     *log.Logger     // module-scoped logger ("pkg.cas")
	ctx        context.Context // context passed to NewClient
	config     *Config         // config after NewClient applied its defaults
	httpClient *http.Client    // used for ticket-validation requests
}
|
||
|
||
// Config holds the CAS client settings. NewClient fills in defaults for an
// empty LoginPath, Version and ValidatePath.
type Config struct {
	ServerURL    string `json:"server_url"`    // CAS server base URL, e.g. https://cas.example.com/cas
	ServiceURL   string `json:"service_url"`   // service callback URL; NewClient replaces its path with the fixed callback path
	LoginPath    string `json:"login_path"`    // login path, defaults to /login
	ValidatePath string `json:"validate_path"` // ticket validation path, chosen from Version by default
	Version      string `json:"version"`       // CAS protocol version: "2" or "3" (defaults to "3")
	CASUrl       string `json:"cas_url"`       // NOTE(review): not read anywhere in this file — confirm whether it is still used
}
|
||
|
||
// UserInfo is the authenticated user extracted from a CAS validation response.
//
// Attributes always contains "name" (falling back to Username). For CAS 3
// responses it also contains "email" and "avatar_url", which may be empty.
type UserInfo struct {
	Username   string            `json:"username"`
	Attributes map[string]string `json:"attributes"`
}
|
||
|
||
// CAS2ServiceResponse is the CAS 2.0 service validation response document.
//
// The tags carry no XML namespace, so encoding/xml matches elements by local
// name (e.g. <cas:serviceResponse>). One of Success or Failure is expected to
// be non-nil; the parser treats "neither" as an invalid response.
type CAS2ServiceResponse struct {
	XMLName xml.Name                   `xml:"serviceResponse"`
	Success *CAS2AuthenticationSuccess `xml:"authenticationSuccess"`
	Failure *AuthenticationFailure     `xml:"authenticationFailure"`
}
|
||
|
||
// CAS2AuthenticationSuccess is the <authenticationSuccess> element of a CAS 2.0
// response; only the <user> principal is read.
type CAS2AuthenticationSuccess struct {
	User string `xml:"user"`
}
|
||
|
||
// CAS3ServiceResponse is the CAS 3.0 service validation response document.
//
// As with the CAS 2.0 variant, elements are matched by local name, and one of
// Success or Failure is expected to be non-nil.
type CAS3ServiceResponse struct {
	XMLName xml.Name                   `xml:"serviceResponse"`
	Success *CAS3AuthenticationSuccess `xml:"authenticationSuccess"`
	Failure *AuthenticationFailure     `xml:"authenticationFailure"`
}
|
||
|
||
// CAS3AuthenticationSuccess is the <authenticationSuccess> element of a CAS 3.0
// response: the <user> principal plus the <attributes> this client reads.
type CAS3AuthenticationSuccess struct {
	User       string         `xml:"user"`
	Attributes CAS3Attributes `xml:"attributes"`
}
|
||
|
||
// AuthenticationFailure is the <authenticationFailure> element shared by CAS 2
// and CAS 3 responses. Code comes from the "code" attribute; Message is the
// element's text, which may carry surrounding whitespace (the parsers trim it).
type AuthenticationFailure struct {
	Code    string `xml:"code,attr"`
	Message string `xml:",chardata"`
}
|
||
|
||
// CAS3Attributes lists the user attributes read from <attributes>; any other
// attribute elements are ignored by encoding/xml. Missing elements decode as "".
type CAS3Attributes struct {
	Email     string `xml:"email"`
	Name      string `xml:"name"`
	AvatarURL string `xml:"avatar_url"`
}
|
||
|
||
// defaultLoginPath is appended to Config.ServerURL when no LoginPath is set.
const defaultLoginPath = "/login"

// Ticket validation endpoints, relative to Config.ServerURL, picked by
// protocol version when Config.ValidatePath is empty.
const (
	defaultValidatePathCAS2 = "/serviceValidate"
	defaultValidatePathCAS3 = "/p3/serviceValidate"
)

// callbackPath is the path NewClient forces onto Config.ServiceURL, so the CAS
// server always redirects back to this service's callback endpoint.
const callbackPath = "/share/pro/v1/openapi/cas/callback"
|
||
|
||
// NewClient 创建CAS客户端
|
||
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
|
||
// 设置默认登录路径
|
||
if config.LoginPath == "" {
|
||
config.LoginPath = defaultLoginPath
|
||
}
|
||
|
||
// 如果版本为空,默认使用CAS3
|
||
if config.Version == "" {
|
||
config.Version = "3"
|
||
}
|
||
|
||
// 根据版本设置默认验证路径
|
||
if config.ValidatePath == "" {
|
||
switch config.Version {
|
||
case "3":
|
||
config.ValidatePath = defaultValidatePathCAS3
|
||
case "2", "":
|
||
config.ValidatePath = defaultValidatePathCAS2
|
||
default:
|
||
return nil, fmt.Errorf("unsupported CAS version: %s, supported versions are '2' and '3'", config.Version)
|
||
}
|
||
}
|
||
|
||
// 构建服务回调URL
|
||
if config.ServiceURL != "" {
|
||
serviceURL, err := url.Parse(config.ServiceURL)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("invalid service URL: %w", err)
|
||
}
|
||
serviceURL.Path = callbackPath
|
||
config.ServiceURL = serviceURL.String()
|
||
}
|
||
|
||
return &Client{
|
||
ctx: ctx,
|
||
logger: logger.WithModule("pkg.cas"),
|
||
config: &config,
|
||
httpClient: &http.Client{
|
||
Transport: &http.Transport{
|
||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||
},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// GetLoginURL 获取CAS登录URL
|
||
func (c *Client) GetLoginURL(state string) string {
|
||
loginURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.LoginPath
|
||
|
||
params := url.Values{}
|
||
params.Set("service", c.config.ServiceURL+"?state="+state)
|
||
|
||
return loginURL + "?" + params.Encode()
|
||
}
|
||
|
||
// ValidateTicket 验证CAS票据并获取用户信息
|
||
func (c *Client) ValidateTicket(ticket, state string) (*UserInfo, error) {
|
||
validateURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.ValidatePath
|
||
|
||
params := url.Values{}
|
||
params.Set("service", c.config.ServiceURL+"?state="+state)
|
||
params.Set("ticket", ticket)
|
||
|
||
fullURL := validateURL + "?" + params.Encode()
|
||
|
||
c.logger.Info("validating CAS ticket",
|
||
log.String("url", fullURL),
|
||
log.String("version", c.config.Version))
|
||
|
||
resp, err := c.httpClient.Get(fullURL)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to validate ticket: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||
}
|
||
|
||
c.logger.Info("CAS validation response", log.String("response", string(body)))
|
||
|
||
// 根据CAS版本解析不同的响应格式
|
||
switch c.config.Version {
|
||
case "2":
|
||
return c.parseCAS2Response(body)
|
||
case "3":
|
||
return c.parseCAS3Response(body)
|
||
default:
|
||
return nil, fmt.Errorf("unsupported CAS version: %s", c.config.Version)
|
||
}
|
||
}
|
||
|
||
// parseCAS2Response 解析CAS2响应
|
||
func (c *Client) parseCAS2Response(body []byte) (*UserInfo, error) {
|
||
var serviceResp CAS2ServiceResponse
|
||
if err := xml.Unmarshal(body, &serviceResp); err != nil {
|
||
return nil, fmt.Errorf("failed to parse CAS2 response: %w", err)
|
||
}
|
||
|
||
if serviceResp.Failure != nil {
|
||
return nil, fmt.Errorf("CAS validation failed: %s - %s",
|
||
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
|
||
}
|
||
|
||
if serviceResp.Success == nil {
|
||
return nil, fmt.Errorf("invalid CAS2 response: no success or failure element")
|
||
}
|
||
|
||
userInfo := &UserInfo{
|
||
Username: serviceResp.Success.User,
|
||
Attributes: map[string]string{
|
||
"name": serviceResp.Success.User, // CAS2通常只返回用户名
|
||
},
|
||
}
|
||
|
||
return userInfo, nil
|
||
}
|
||
|
||
// parseCAS3Response 解析CAS3响应
|
||
func (c *Client) parseCAS3Response(body []byte) (*UserInfo, error) {
|
||
var serviceResp CAS3ServiceResponse
|
||
if err := xml.Unmarshal(body, &serviceResp); err != nil {
|
||
return nil, fmt.Errorf("failed to parse CAS3 response: %w", err)
|
||
}
|
||
|
||
if serviceResp.Failure != nil {
|
||
return nil, fmt.Errorf("CAS validation failed: %s - %s",
|
||
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
|
||
}
|
||
|
||
if serviceResp.Success == nil {
|
||
return nil, fmt.Errorf("invalid CAS3 response: no success or failure element")
|
||
}
|
||
|
||
userInfo := &UserInfo{
|
||
Username: serviceResp.Success.User,
|
||
Attributes: map[string]string{
|
||
"email": serviceResp.Success.Attributes.Email,
|
||
"name": serviceResp.Success.Attributes.Name,
|
||
"avatar_url": serviceResp.Success.Attributes.AvatarURL,
|
||
},
|
||
}
|
||
|
||
// 如果没有显示名称,使用用户名
|
||
if userInfo.Attributes["name"] == "" {
|
||
userInfo.Attributes["name"] = userInfo.Username
|
||
}
|
||
|
||
return userInfo, nil
|
||
}
|