Files
YouduWiki/backend/pkg/cas/cas.go
2026-05-21 19:52:45 +08:00

230 lines
6.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package cas
import (
"context"
"crypto/tls"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/chaitin/panda-wiki/log"
)
// Client is a CAS client. It bundles the logger, context, configuration and
// HTTP client that NewClient assembles.
type Client struct {
logger *log.Logger // module-scoped logger created in NewClient
ctx context.Context // stored by NewClient; not read by the methods visible in this file
config *Config // configuration with the defaults from NewClient applied
httpClient *http.Client // used by ValidateTicket to call the CAS server
}
// Config holds the settings used to build a CAS Client.
type Config struct {
ServerURL string `json:"server_url"` // CAS server URL, e.g. https://cas.example.com/cas
ServiceURL string `json:"service_url"` // service callback URL
LoginPath string `json:"login_path"` // login path; defaults to /login
ValidatePath string `json:"validate_path"` // validation path; by default chosen from the protocol version
Version string `json:"version"` // CAS protocol version: "2" or "3"
CASUrl string `json:"cas_url"` // NOTE(review): not referenced by the code visible in this file — confirm its purpose
}
// UserInfo is the user identity returned after a ticket has been validated.
type UserInfo struct {
Username string `json:"username"` // taken from the <user> element of the validation response
Attributes map[string]string `json:"attributes"` // string attributes; the parsers in this file always set "name"
}
// CAS2ServiceResponse is the XML structure of a CAS 2 service validation
// response. Success and Failure are pointers, so an element that is absent
// from the response is left nil after decoding.
type CAS2ServiceResponse struct {
XMLName xml.Name `xml:"serviceResponse"`
Success *CAS2AuthenticationSuccess `xml:"authenticationSuccess"`
Failure *AuthenticationFailure `xml:"authenticationFailure"`
}
// CAS2AuthenticationSuccess is the authenticationSuccess element of a CAS 2
// response; only the <user> child is decoded.
type CAS2AuthenticationSuccess struct {
User string `xml:"user"`
}
// CAS3ServiceResponse is the XML structure of a CAS 3 service validation
// response. Success and Failure are pointers, so an element that is absent
// from the response is left nil after decoding.
type CAS3ServiceResponse struct {
XMLName xml.Name `xml:"serviceResponse"`
Success *CAS3AuthenticationSuccess `xml:"authenticationSuccess"`
Failure *AuthenticationFailure `xml:"authenticationFailure"`
}
// CAS3AuthenticationSuccess is the authenticationSuccess element of a CAS 3
// response: the <user> child plus the <attributes> child.
type CAS3AuthenticationSuccess struct {
User string `xml:"user"`
Attributes CAS3Attributes `xml:"attributes"`
}
// AuthenticationFailure is the authenticationFailure element shared by the
// CAS 2 and CAS 3 response structures.
type AuthenticationFailure struct {
Code string `xml:"code,attr"` // value of the element's "code" attribute
Message string `xml:",chardata"` // the element's character data; the parsers trim it before reporting
}
// CAS3Attributes lists the child elements of <attributes> that are decoded
// from a CAS 3 response. Elements missing from the response leave the
// corresponding field as the empty string.
type CAS3Attributes struct {
Email string `xml:"email"`
Name string `xml:"name"`
AvatarURL string `xml:"avatar_url"`
}
// Default paths applied by NewClient.
const (
// defaultLoginPath is used when Config.LoginPath is empty.
defaultLoginPath = "/login"
// defaultValidatePathCAS2 is used when Config.ValidatePath is empty and the version is "2".
defaultValidatePathCAS2 = "/serviceValidate"
// defaultValidatePathCAS3 is used when Config.ValidatePath is empty and the version is "3".
defaultValidatePathCAS3 = "/p3/serviceValidate"
// callbackPath replaces the path of Config.ServiceURL in NewClient.
callbackPath = "/share/pro/v1/openapi/cas/callback"
)
// NewClient creates a CAS client from the given configuration.
//
// config is received by value, so the defaults applied here never modify the
// caller's copy:
//   - an empty LoginPath becomes "/login";
//   - an empty Version becomes "3";
//   - an empty ValidatePath becomes the default validation path of the version;
//   - a non-empty ServiceURL has its path replaced by the fixed CAS callback
//     path (the rest of the URL is kept as parsed).
//
// It returns an error when Version is neither "2" nor "3", or when ServiceURL
// cannot be parsed. The version is checked even when ValidatePath is set
// explicitly: ValidateTicket can only parse version "2" and "3" responses, so
// a client with any other version could never complete a login.
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
	if config.LoginPath == "" {
		config.LoginPath = defaultLoginPath
	}
	// An unspecified version means CAS 3.
	if config.Version == "" {
		config.Version = "3"
	}

	// Resolve the version first so that an unsupported value fails fast,
	// regardless of whether a custom validation path was supplied.
	var defaultValidatePath string
	switch config.Version {
	case "2":
		defaultValidatePath = defaultValidatePathCAS2
	case "3":
		defaultValidatePath = defaultValidatePathCAS3
	default:
		return nil, fmt.Errorf("unsupported CAS version: %s, supported versions are '2' and '3'", config.Version)
	}
	if config.ValidatePath == "" {
		config.ValidatePath = defaultValidatePath
	}

	// Point the service URL at the CAS callback endpoint.
	if config.ServiceURL != "" {
		serviceURL, err := url.Parse(config.ServiceURL)
		if err != nil {
			return nil, fmt.Errorf("invalid service URL: %w", err)
		}
		serviceURL.Path = callbackPath
		config.ServiceURL = serviceURL.String()
	}

	return &Client{
		ctx:    ctx,
		logger: logger.WithModule("pkg.cas"),
		config: &config,
		httpClient: &http.Client{
			// NOTE(review): no Timeout is set, so a CAS server that never
			// answers blocks ValidateTicket indefinitely. Consider adding one.
			Transport: &http.Transport{
				// NOTE(review): certificate verification is disabled for every
				// request to the CAS server, so validation responses are not
				// protected against an on-path attacker. Left unchanged to
				// keep deployments with self-signed certificates working;
				// consider making this opt-in through Config.
				TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
			},
		},
	}, nil
}
// serviceURLWithState returns the service URL presented to the CAS server:
// the configured callback URL with the caller's state added as a query
// parameter.
//
// The state is query-escaped so that characters such as '&', '+', '#' and '%'
// come back unchanged on the redirect from the CAS server. If the callback URL
// already has a query string, the state is appended with '&' instead of a
// second '?'. GetLoginURL and ValidateTicket both build the service through
// this helper so that the value sent at validation is identical to the one
// sent at login.
func (c *Client) serviceURLWithState(state string) string {
	sep := "?"
	if strings.Contains(c.config.ServiceURL, "?") {
		sep = "&"
	}
	return c.config.ServiceURL + sep + "state=" + url.QueryEscape(state)
}

// GetLoginURL returns the CAS login URL the user should be redirected to.
// The URL carries a "service" parameter holding the callback URL with state
// embedded in it.
func (c *Client) GetLoginURL(state string) string {
	loginURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.LoginPath
	params := url.Values{}
	params.Set("service", c.serviceURLWithState(state))
	return loginURL + "?" + params.Encode()
}

// ValidateTicket validates a CAS ticket against the server and returns the
// user it belongs to.
//
// state must be the same value that was passed to GetLoginURL, because it is
// part of the service URL sent with the validation request.
//
// An error is returned when the ticket is empty, the request fails, the server
// answers with a non-2xx status, the body cannot be read or parsed, the server
// reports an authentication failure, or the configured version is unsupported.
func (c *Client) ValidateTicket(ticket, state string) (*UserInfo, error) {
	// Validation responses are small XML documents; cap what is read so a
	// misbehaving endpoint cannot exhaust memory.
	const maxResponseBytes = 1 << 20

	if ticket == "" {
		return nil, fmt.Errorf("failed to validate ticket: ticket is empty")
	}

	validateURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.ValidatePath
	params := url.Values{}
	params.Set("service", c.serviceURLWithState(state))
	params.Set("ticket", ticket)
	fullURL := validateURL + "?" + params.Encode()

	c.logger.Info("validating CAS ticket",
		log.String("url", fullURL),
		log.String("version", c.config.Version))

	resp, err := c.httpClient.Get(fullURL)
	if err != nil {
		return nil, fmt.Errorf("failed to validate ticket: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	c.logger.Info("CAS validation response", log.String("response", string(body)))

	// Reject error pages up front instead of reporting them as XML parse
	// failures (the body has already been logged above for diagnosis).
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("failed to validate ticket: unexpected HTTP status %d", resp.StatusCode)
	}

	// The response format depends on the CAS protocol version.
	switch c.config.Version {
	case "2":
		return c.parseCAS2Response(body)
	case "3":
		return c.parseCAS3Response(body)
	default:
		return nil, fmt.Errorf("unsupported CAS version: %s", c.config.Version)
	}
}
// parseCAS2Response decodes a CAS 2 validation response body.
//
// It returns an error when the body is not valid XML, when the response
// contains an authenticationFailure element, when it contains neither a
// success nor a failure element, or when the success element carries no user
// name. On success the user name (with surrounding whitespace removed) is
// returned both as Username and as the "name" attribute, since the CAS 2
// response structure decoded here carries nothing but the user name.
func (c *Client) parseCAS2Response(body []byte) (*UserInfo, error) {
	var serviceResp CAS2ServiceResponse
	if err := xml.Unmarshal(body, &serviceResp); err != nil {
		return nil, fmt.Errorf("failed to parse CAS2 response: %w", err)
	}

	if failure := serviceResp.Failure; failure != nil {
		return nil, fmt.Errorf("CAS validation failed: %s - %s",
			failure.Code, strings.TrimSpace(failure.Message))
	}
	if serviceResp.Success == nil {
		return nil, fmt.Errorf("invalid CAS2 response: no success or failure element")
	}

	// A success element without a user must not be treated as a login:
	// it would authenticate a user with an empty name.
	username := strings.TrimSpace(serviceResp.Success.User)
	if username == "" {
		return nil, fmt.Errorf("invalid CAS2 response: success element has no user")
	}

	return &UserInfo{
		Username: username,
		Attributes: map[string]string{
			"name": username,
		},
	}, nil
}
// parseCAS3Response decodes a CAS 3 validation response body.
//
// It returns an error when the body is not valid XML, when the response
// contains an authenticationFailure element, when it contains neither a
// success nor a failure element, or when the success element carries no user
// name. On success the returned attributes always contain the keys "email",
// "name" and "avatar_url"; "name" falls back to the user name when the
// response provides no display name.
func (c *Client) parseCAS3Response(body []byte) (*UserInfo, error) {
	var serviceResp CAS3ServiceResponse
	if err := xml.Unmarshal(body, &serviceResp); err != nil {
		return nil, fmt.Errorf("failed to parse CAS3 response: %w", err)
	}

	if failure := serviceResp.Failure; failure != nil {
		return nil, fmt.Errorf("CAS validation failed: %s - %s",
			failure.Code, strings.TrimSpace(failure.Message))
	}
	if serviceResp.Success == nil {
		return nil, fmt.Errorf("invalid CAS3 response: no success or failure element")
	}

	// A success element without a user must not be treated as a login:
	// it would authenticate a user with an empty name.
	username := strings.TrimSpace(serviceResp.Success.User)
	if username == "" {
		return nil, fmt.Errorf("invalid CAS3 response: success element has no user")
	}

	attrs := serviceResp.Success.Attributes
	displayName := attrs.Name
	if displayName == "" {
		// No display name supplied; use the user name instead.
		displayName = username
	}

	return &UserInfo{
		Username: username,
		Attributes: map[string]string{
			"email":      attrs.Email,
			"name":       displayName,
			"avatar_url": attrs.AvatarURL,
		},
	}, nil
}