init push

This commit is contained in:
2026-05-21 19:52:45 +08:00
commit e3f75311ab
1280 changed files with 179173 additions and 0 deletions

View File

@@ -0,0 +1,341 @@
package anydoc
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/mq"
"github.com/chaitin/panda-wiki/mq/types"
)
type Client struct {
httpClient *http.Client
logger *log.Logger
mqConsumer mq.MQConsumer
taskWaiters map[string]chan *domain.AnydocTaskExportEvent
mutex sync.RWMutex
subscribed bool
subscribeMu sync.Mutex
}
const (
apiUploaderUrl = "http://panda-wiki-api:8000/api/v1/file/upload/anydoc"
uploaderDir = "/image"
crawlerServiceHost = "http://panda-wiki-crawler:8080"
SpaceIdCloud = "cloud_disk"
getUrlPath = "/api/docs/url/list"
UrlExportPath = "/api/docs/url/export"
TaskListPath = "/api/tasks/list"
)
type Status string
const (
StatusPending Status = "pending"
StatusInProgress Status = "in_process"
StatusCompleted Status = "completed"
StatusFailed Status = "failed"
)
type UploaderType uint
const (
uploaderTypeDefault UploaderType = iota
uploaderTypeHTTP
)
func NewClient(logger *log.Logger, mqConsumer mq.MQConsumer) (*Client, error) {
client := &Client{
logger: logger.WithModule("anydoc.client"),
httpClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
taskWaiters: make(map[string]chan *domain.AnydocTaskExportEvent),
mqConsumer: mqConsumer,
}
return client, nil
}
func (c *Client) GetUrlList(ctx context.Context, targetURL, id string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = getUrlPath
q := u.Query()
q.Set("url", targetURL)
q.Set("uuid", id)
u.RawQuery = q.Encode()
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("scrape url", "requestURL:", requestURL, "resp", string(respBody))
var scrapeResp ListDocResponse
err = json.Unmarshal(respBody, &scrapeResp)
if err != nil {
return nil, err
}
if !scrapeResp.Success {
return nil, errors.New(scrapeResp.Msg)
}
return &scrapeResp, nil
}
func (c *Client) UrlExport(ctx context.Context, id, docID, kbId string) (*UrlExportRes, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = UrlExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": id,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("UrlExport", "requestURL:", requestURL, "resp", string(respBody))
var res UrlExportRes
err = json.Unmarshal(respBody, &res)
if err != nil {
return nil, err
}
if !res.Success {
return nil, errors.New(res.Msg)
}
return &res, nil
}
// ensureSubscribed 确保已订阅消息队列,只订阅一次
func (c *Client) ensureSubscribed() error {
c.subscribeMu.Lock()
defer c.subscribeMu.Unlock()
if c.subscribed {
return nil
}
if c.mqConsumer == nil {
return fmt.Errorf("MQ consumer not initialized")
}
err := c.mqConsumer.RegisterHandler(domain.AnydocTaskExportTopic, c.handleTaskExportEvent)
if err != nil {
return fmt.Errorf("failed to register task export handler: %w", err)
}
c.subscribed = true
c.logger.Info("successfully subscribed to anydoc task export topic")
return nil
}
// TaskWaitForCompletion 通过 NATS 消息队列等待任务完成(推荐方式)
func (c *Client) TaskWaitForCompletion(ctx context.Context, taskID string) (*domain.AnydocTaskExportEvent, error) {
if c.mqConsumer == nil {
return nil, fmt.Errorf("MQ consumer not initialized, use NewClientWithMQ instead")
}
// 延迟订阅:只有在需要时才订阅
if err := c.ensureSubscribed(); err != nil {
return nil, err
}
// Create a channel to wait for the specific task
taskChan := make(chan *domain.AnydocTaskExportEvent, 1)
c.mutex.Lock()
c.taskWaiters[taskID] = taskChan
c.mutex.Unlock()
// Cleanup when done
defer func() {
c.mutex.Lock()
delete(c.taskWaiters, taskID)
c.mutex.Unlock()
close(taskChan)
}()
// Wait for task completion or context cancellation
select {
case event := <-taskChan:
return event, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// TaskListPoll 轮询方式(保留兼容性)
func (c *Client) TaskListPoll(ctx context.Context, ids []string) (*TaskRes, error) {
depth := 0
const maxDepth = 10
for depth < maxDepth {
time.Sleep(1000 * time.Millisecond)
resp, err := c.TaskList(ctx, ids)
if err != nil {
return nil, err
}
if resp.Data[0].Status == StatusCompleted {
return resp, nil
}
depth++
}
return nil, fmt.Errorf("task list poll timeout")
}
// handleTaskExportEvent 处理任务导出完成事件
func (c *Client) handleTaskExportEvent(ctx context.Context, msg types.Message) error {
var event domain.AnydocTaskExportEvent
if err := json.Unmarshal(msg.GetData(), &event); err != nil {
c.logger.Error("failed to unmarshal task export event", "error", err)
return err
}
c.logger.Info("received task export event",
"task_id", event.TaskID,
"status", event.Status,
"doc_id", event.DocID)
// Notify waiting goroutines
c.mutex.RLock()
if taskChan, exists := c.taskWaiters[event.TaskID]; exists {
select {
case taskChan <- &event:
default:
// Channel is full or closed, ignore
}
}
c.mutex.RUnlock()
return nil
}
func (c *Client) TaskList(ctx context.Context, ids []string) (*TaskRes, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = TaskListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"ids": ids,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("TaskList url", "requestURL", requestURL, "resp", string(respBody))
var res TaskRes
err = json.Unmarshal(respBody, &res)
if err != nil {
return nil, err
}
if !res.Success {
return nil, errors.New(res.Msg)
}
if len(res.Data) == 0 {
return nil, errors.New("data list is empty")
}
return &res, nil
}
func (c *Client) DownloadDoc(ctx context.Context, filepath string) ([]byte, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = "/api/tasks/download" + filepath
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("DownloadDoc", "requestURL:", requestURL, "resp length", len(respBody))
return respBody, nil
}

View File

@@ -0,0 +1,154 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
ConfluenceListPath = "/api/docs/confluence/list"
ConfluenceExportPath = "/api/docs/confluence/export"
)
// ConfluenceListDocsRequest Confluence 获取文档列表请求
type ConfluenceListDocsRequest struct {
URL string `json:"url"` // Confluence 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// ConfluenceExportDocRequest Confluence 导出文档请求
type ConfluenceExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // confluence-doc-id
}
// ConfluenceExportDocResponse Confluence 导出文档响应
type ConfluenceExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// ConfluenceExportDocData Confluence 导出文档数据
type ConfluenceExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// ConfluenceListDocs 获取 Confluence 文档列表
func (c *Client) ConfluenceListDocs(ctx context.Context, confluenceURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = ConfluenceListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": confluenceURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("ConfluenceListDocs", "requestURL:", requestURL, "resp", string(respBody))
var confluenceResp ListDocResponse
err = json.Unmarshal(respBody, &confluenceResp)
if err != nil {
return nil, err
}
if !confluenceResp.Success {
return nil, errors.New(confluenceResp.Msg)
}
return &confluenceResp, nil
}
// ConfluenceExportDoc 导出 Confluence 文档
func (c *Client) ConfluenceExportDoc(ctx context.Context, uuid, docID, kbId string) (*ConfluenceExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = ConfluenceExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("ConfluenceExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp ConfluenceExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,70 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
dingtalkListPath = "/api/docs/dingtalk/list"
dingtalkExportPath = "/api/docs/dingtalk/export"
)
// DingtalkListDocs 获取 dingtalk 文档列表
func (c *Client) DingtalkListDocs(ctx context.Context, uuid string, dingtalkSetting DingtalkSetting) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = dingtalkListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"app_id": dingtalkSetting.AppID,
"app_secret": dingtalkSetting.AppSecret,
"unionid": dingtalkSetting.UnionID,
"space_id": dingtalkSetting.SpaceID,
"phone": dingtalkSetting.Phone,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("dingtalkListDocs", "requestURL:", requestURL, "resp", string(respBody))
var dingtalkResp ListDocResponse
err = json.Unmarshal(respBody, &dingtalkResp)
if err != nil {
return nil, err
}
if !dingtalkResp.Success {
return nil, errors.New(dingtalkResp.Msg)
}
return &dingtalkResp, nil
}

173
backend/pkg/anydoc/epub.go Normal file
View File

@@ -0,0 +1,173 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
epubpListPath = "/api/docs/epubp/list"
epubpExportPath = "/api/docs/epubp/export"
)
// EpubpListDocsRequest Epubp 获取文档列表请求
type EpubpListDocsRequest struct {
URL string `json:"url"` // Epubp 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// EpubpListDocsResponse Epubp 获取文档列表响应
type EpubpListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data EpubpListDocsData `json:"data"`
}
// EpubpListDocsData Epubp 文档列表数据
type EpubpListDocsData struct {
Docs []EpubpDoc `json:"docs"`
}
// EpubpDoc Epubp 文档信息
type EpubpDoc struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
}
// EpubpExportDocRequest Epubp 导出文档请求
type EpubpExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // epubp-doc-id
}
// EpubpExportDocResponse Epubp 导出文档响应
type EpubpExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// EpubpExportDocData Epubp 导出文档数据
type EpubpExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// EpubpListDocs 获取 Epubp 文档列表
func (c *Client) EpubpListDocs(ctx context.Context, epubpURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = epubpListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": epubpURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("EpubpListDocs", "requestURL:", requestURL, "resp", string(respBody))
var epubpResp ListDocResponse
err = json.Unmarshal(respBody, &epubpResp)
if err != nil {
return nil, err
}
if !epubpResp.Success {
return nil, errors.New(epubpResp.Msg)
}
return &epubpResp, nil
}
// EpubpExportDoc 导出 Epubp 文档
func (c *Client) EpubpExportDoc(ctx context.Context, uuid, docID, kbId string) (*EpubpExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = epubpExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("EpubpExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp EpubpExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,175 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
feishuListPath = "/api/docs/feishu/list"
feishuExportPath = "/api/docs/feishu/export"
)
// FeishuListDocsRequest Feishu 获取文档列表请求
type FeishuListDocsRequest struct {
URL string `json:"url"` // Feishu 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// FeishuListDocsResponse Feishu 获取文档列表响应
type FeishuListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data FeishuListDocsData `json:"data"`
}
// FeishuListDocsData Feishu 文档列表数据
type FeishuListDocsData struct {
Docs []FeishuDoc `json:"docs"`
}
// FeishuDoc Feishu 文档信息
type FeishuDoc struct {
ID string `json:"id"`
FileType string `json:"file_type"`
Title string `json:"title"`
Summary string `json:"summary"`
}
// FeishuExportDocRequest Feishu 导出文档请求
type FeishuExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // feishu-doc-id
}
// FeishuExportDocResponse Feishu 导出文档响应
type FeishuExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// FeishuExportDocData Feishu 导出文档数据
type FeishuExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// FeishuListDocs 获取 Feishu 文档列表
func (c *Client) FeishuListDocs(ctx context.Context, uuid, appId, appSecret, accessToken, spaceId string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = feishuListPath
q := u.Query()
q.Set("uuid", uuid)
q.Set("app_id", appId)
q.Set("app_secret", appSecret)
q.Set("access_token", accessToken)
if spaceId != "" {
q.Set("space_id", spaceId)
}
u.RawQuery = q.Encode()
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("FeishuListDocs", "requestURL:", requestURL, "resp", string(respBody))
var feishuResp ListDocResponse
err = json.Unmarshal(respBody, &feishuResp)
if err != nil {
return nil, err
}
if !feishuResp.Success {
return nil, errors.New(feishuResp.Msg)
}
return &feishuResp, nil
}
// FeishuExportDoc 导出 Feishu 文档
func (c *Client) FeishuExportDoc(ctx context.Context, uuid, docID, fileType, spaceId, kbId string) (*UrlExportRes, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = feishuExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"file_type": fileType,
"space_id": spaceId,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("FeishuDoc", "requestURL:", requestURL, "body", string(jsonData), "resp", string(respBody))
var exportResp UrlExportRes
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,173 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
mindocListPath = "/api/docs/mindoc/list"
mindocExportPath = "/api/docs/mindoc/export"
)
// MindocListDocsRequest Mindoc 获取文档列表请求
type MindocListDocsRequest struct {
URL string `json:"url"` // Mindoc 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// MindocListDocsResponse Mindoc 获取文档列表响应
type MindocListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data MindocListDocsData `json:"data"`
}
// MindocListDocsData Mindoc 文档列表数据
type MindocListDocsData struct {
Docs []MindocDoc `json:"docs"`
}
// MindocDoc Mindoc 文档信息
type MindocDoc struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
}
// MindocExportDocRequest Mindoc 导出文档请求
type MindocExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // mindoc-doc-id
}
// MindocExportDocResponse Mindoc 导出文档响应
type MindocExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// MindocExportDocData Mindoc 导出文档数据
type MindocExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// MindocListDocs 获取 Mindoc 文档列表
func (c *Client) MindocListDocs(ctx context.Context, mindocURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = mindocListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": mindocURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("MindocListDocs", "requestURL:", requestURL, "resp", string(respBody))
var mindocResp ListDocResponse
err = json.Unmarshal(respBody, &mindocResp)
if err != nil {
return nil, err
}
if !mindocResp.Success {
return nil, errors.New(mindocResp.Msg)
}
return &mindocResp, nil
}
// MindocExportDoc 导出 Mindoc 文档
func (c *Client) MindocExportDoc(ctx context.Context, uuid, docID, kbId string) (*MindocExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = mindocExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("MindocExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp MindocExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,148 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
notionListPath = "/api/docs/notion/list"
notionExportPath = "/api/docs/notion/export"
)
// NotionListDocsResponse Notion 获取文档列表响应
type NotionListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data NotionListDocsData `json:"data"`
}
// NotionListDocsData Notion 文档列表数据
type NotionListDocsData struct {
Docs []NotionDoc `json:"docs"`
}
// NotionDoc Notion 文档信息
type NotionDoc struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
}
// NotionExportDocResponse Notion 导出文档响应
type NotionExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// NotionListDocs 获取 Notion 文档列表
func (c *Client) NotionListDocs(ctx context.Context, secret, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = notionListPath
q := u.Query()
q.Set("uuid", uuid)
q.Set("secret", secret)
u.RawQuery = q.Encode()
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("NotionListDocs", "requestURL:", requestURL, "resp", string(respBody))
var notionResp ListDocResponse
err = json.Unmarshal(respBody, &notionResp)
if err != nil {
return nil, err
}
if !notionResp.Success {
return nil, errors.New(notionResp.Msg)
}
return &notionResp, nil
}
// NotionExportDoc 导出 Notion 文档
func (c *Client) NotionExportDoc(ctx context.Context, uuid, docID, kbId string) (*NotionExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = notionExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("NotionExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp NotionExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

16
backend/pkg/anydoc/req.go Normal file
View File

@@ -0,0 +1,16 @@
package anydoc
type FeishuSetting struct {
UserAccessToken string `json:"user_access_token"`
AppID string `json:"app_id"`
AppSecret string `json:"app_secret"`
SpaceId string `json:"space_id"`
}
type DingtalkSetting struct {
AppID string `json:"app_id"`
AppSecret string `json:"app_secret"`
SpaceID string `json:"space_id"`
UnionID string `json:"unionid"`
Phone string `json:"phone"`
}

63
backend/pkg/anydoc/res.go Normal file
View File

@@ -0,0 +1,63 @@
package anydoc
type GetUrlListResponse struct {
Success bool `json:"success"`
Data GetUrlListData `json:"data"`
Msg string `json:"msg"`
Err string `json:"err"`
TraceId interface{} `json:"trace_id"`
}
type GetUrlListData struct {
Docs []struct {
Id string `json:"id"`
FileType string `json:"file_type"`
Title string `json:"title"`
Summary string `json:"summary"`
} `json:"docs"`
}
type UrlExportRes struct {
Success bool `json:"success"`
Data string `json:"data"`
Msg string `json:"msg"`
Err string `json:"err"`
TraceId interface{} `json:"trace_id"`
}
type TaskRes struct {
Success bool `json:"success"`
Data []struct {
TaskId string `json:"task_id"`
PlatformId string `json:"platform_id"`
DocId string `json:"doc_id"`
Status Status `json:"status"`
Err string `json:"err"`
Markdown string `json:"markdown"`
Json string `json:"json"`
} `json:"data"`
Msg string `json:"msg"`
}
type ListDocResponse struct {
Success bool `json:"success"`
Data ListDocsData `json:"data"`
Msg string `json:"msg"`
Err string `json:"err"`
TraceID string `json:"trace_id"`
}
type ListDocsData struct {
Docs Child `json:"docs"`
}
type Value struct {
ID string `json:"id"`
File bool `json:"file"`
FileType string `json:"file_type"`
Title string `json:"title"`
Summary string `json:"summary"`
}
type Child struct {
Value Value `json:"value"`
Children []Child `json:"children"`
}

161
backend/pkg/anydoc/rss.go Normal file
View File

@@ -0,0 +1,161 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
rssListPath = "/api/docs/rss/list"
rssExportPath = "/api/docs/rss/export"
)
// RssListDocsResponse Rss 获取文档列表响应
type RssListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data RssListDocsData `json:"data"`
}
// RssListDocsData Rss 文档列表数据
type RssListDocsData struct {
Docs []RssDoc `json:"docs"`
}
// RssDoc Rss 文档信息
type RssDoc struct {
Id string `json:"id"`
FileType string `json:"file_type"`
Title string `json:"title"`
Summary string `json:"summary"`
}
// RssExportDocRequest Rss 导出文档请求
type RssExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // rss-doc-id
}
// RssExportDocResponse Rss 导出文档响应
type RssExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// RssExportDocData Rss 导出文档数据
type RssExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// RssListDocs 获取 Rss 文档列表
func (c *Client) RssListDocs(ctx context.Context, xmlUrl, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = rssListPath
q := u.Query()
q.Set("uuid", uuid)
q.Set("url", xmlUrl)
u.RawQuery = q.Encode()
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("RssListDocs", "requestURL:", requestURL, "resp", string(respBody))
var rssResp ListDocResponse
err = json.Unmarshal(respBody, &rssResp)
if err != nil {
return nil, err
}
if !rssResp.Success {
return nil, errors.New(rssResp.Msg)
}
return &rssResp, nil
}
// RssExportDoc 导出 Rss 文档
func (c *Client) RssExportDoc(ctx context.Context, uuid, docID, kbId string) (*RssExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = rssExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("RssExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp RssExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,161 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
sitemapListPath = "/api/docs/sitemap/list"
sitemapExportPath = "/api/docs/sitemap/export"
)
// SitemapListDocsResponse Sitemap 获取文档列表响应
type SitemapListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data SitemapListDocsData `json:"data"`
}
// SitemapListDocsData Sitemap 文档列表数据
type SitemapListDocsData struct {
Docs []SitemapDoc `json:"docs"`
}
// SitemapDoc Sitemap 文档信息
type SitemapDoc struct {
Id string `json:"id"`
FileType string `json:"file_type"`
Title string `json:"title"`
Summary string `json:"summary"`
}
// SitemapExportDocRequest Sitemap 导出文档请求
type SitemapExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // sitemap-doc-id
}
// SitemapExportDocResponse Sitemap 导出文档响应
type SitemapExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// SitemapExportDocData Sitemap 导出文档数据
type SitemapExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// SitemapListDocs 获取 Sitemap 文档列表
func (c *Client) SitemapListDocs(ctx context.Context, xmlUrl, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = sitemapListPath
q := u.Query()
q.Set("uuid", uuid)
q.Set("url", xmlUrl)
u.RawQuery = q.Encode()
requestURL := u.String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("SitemapListDocs", "requestURL:", requestURL, "resp", string(respBody))
var sitemapResp ListDocResponse
err = json.Unmarshal(respBody, &sitemapResp)
if err != nil {
return nil, err
}
if !sitemapResp.Success {
return nil, errors.New(sitemapResp.Msg)
}
return &sitemapResp, nil
}
// SitemapExportDoc 导出 Sitemap 文档
func (c *Client) SitemapExportDoc(ctx context.Context, uuid, docID, kbId string) (*SitemapExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = sitemapExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("SitemapExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp SitemapExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,173 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
siyuanListPath = "/api/docs/siyuan/list"
siyuanExportPath = "/api/docs/siyuan/export"
)
// SiyuanListDocsRequest Siyuan 获取文档列表请求
type SiyuanListDocsRequest struct {
URL string `json:"url"` // Siyuan 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// SiyuanListDocsResponse Siyuan 获取文档列表响应
type SiyuanListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data SiyuanListDocsData `json:"data"`
}
// SiyuanListDocsData Siyuan 文档列表数据
type SiyuanListDocsData struct {
Docs []SiyuanDoc `json:"docs"`
}
// SiyuanDoc Siyuan 文档信息
type SiyuanDoc struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
}
// SiyuanExportDocRequest Siyuan 导出文档请求
type SiyuanExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // siyuan-doc-id
}
// SiyuanExportDocResponse Siyuan 导出文档响应
type SiyuanExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// SiyuanExportDocData Siyuan 导出文档数据
type SiyuanExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// SiyuanListDocs 获取 Siyuan 文档列表
func (c *Client) SiyuanListDocs(ctx context.Context, siyuanURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = siyuanListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": siyuanURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("SiyuanListDocs", "requestURL:", requestURL, "resp", string(respBody))
var siyuanResp ListDocResponse
err = json.Unmarshal(respBody, &siyuanResp)
if err != nil {
return nil, err
}
if !siyuanResp.Success {
return nil, errors.New(siyuanResp.Msg)
}
return &siyuanResp, nil
}
// SiyuanExportDoc 导出 Siyuan 文档
func (c *Client) SiyuanExportDoc(ctx context.Context, uuid, docID, kbId string) (*SiyuanExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = siyuanExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("SiyuanExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp SiyuanExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,154 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
)
const (
wikijsListPath = "/api/docs/wikijs/list"
wikijsExportPath = "/api/docs/wikijs/export"
)
// WikijsListDocsRequest Wikijs 获取文档列表请求
type WikijsListDocsRequest struct {
URL string `json:"url"` // Wikijs 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// WikijsExportDocRequest Wikijs 导出文档请求
type WikijsExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // wikijs-doc-id
}
// WikijsExportDocResponse Wikijs 导出文档响应
type WikijsExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// WikijsExportDocData Wikijs 导出文档数据
type WikijsExportDocData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
FilePath string `json:"file_path"`
}
// WikijsListDocs 获取 Wikijs 文档列表
func (c *Client) WikijsListDocs(ctx context.Context, wikijsURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = wikijsListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": wikijsURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("WikijsListDocs", "requestURL:", requestURL, "resp", string(respBody))
var wikijsResp ListDocResponse
err = json.Unmarshal(respBody, &wikijsResp)
if err != nil {
return nil, err
}
if !wikijsResp.Success {
return nil, errors.New(wikijsResp.Msg)
}
return &wikijsResp, nil
}
// WikijsExportDoc 导出 Wikijs 文档
func (c *Client) WikijsExportDoc(ctx context.Context, uuid, docID, kbId string) (*WikijsExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = wikijsExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("WikijsExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp WikijsExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, errors.New(exportResp.Msg)
}
return &exportResp, nil
}

165
backend/pkg/anydoc/yuque.go Normal file
View File

@@ -0,0 +1,165 @@
package anydoc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
)
const (
yuqueListPath = "/api/docs/yuque/list"
yuqueExportPath = "/api/docs/yuque/export"
)
// YuqueListDocsRequest Yuque 获取文档列表请求
type YuqueListDocsRequest struct {
URL string `json:"url"` // Yuque 配置文件
Filename string `json:"filename"` // 文件名,需要带扩展名
UUID string `json:"uuid"` // 必填的唯一标识符
}
// YuqueListDocsResponse Yuque 获取文档列表响应
type YuqueListDocsResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data YuqueListDocsData `json:"data"`
}
// YuqueListDocsData Yuque 文档列表数据
type YuqueListDocsData struct {
Docs []YuqueDoc `json:"docs"`
}
// YuqueDoc Yuque 文档信息
type YuqueDoc struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
}
// YuqueExportDocRequest Yuque 导出文档请求
type YuqueExportDocRequest struct {
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
DocID string `json:"doc_id"` // yuque-doc-id
}
// YuqueExportDocResponse Yuque 导出文档响应
type YuqueExportDocResponse struct {
Success bool `json:"success"`
Msg string `json:"msg"`
Data string `json:"data"`
}
// YuqueListDocs 获取 Yuque 文档列表
func (c *Client) YuqueListDocs(ctx context.Context, yuqueURL, filename, uuid string) (*ListDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = yuqueListPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"url": yuqueURL,
"filename": filename,
"uuid": uuid,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("YuqueListDocs", "requestURL:", requestURL, "resp", string(respBody))
var yuqueResp ListDocResponse
err = json.Unmarshal(respBody, &yuqueResp)
if err != nil {
return nil, err
}
if !yuqueResp.Success {
return nil, fmt.Errorf("yuque list docs API failed - URL: %s, UUID: %s, Error: %s", yuqueURL, uuid, yuqueResp.Msg)
}
return &yuqueResp, nil
}
// YuqueExportDoc 导出 Yuque 文档
func (c *Client) YuqueExportDoc(ctx context.Context, uuid, docID, kbId string) (*YuqueExportDocResponse, error) {
u, err := url.Parse(crawlerServiceHost)
if err != nil {
return nil, err
}
u.Path = yuqueExportPath
requestURL := u.String()
bodyMap := map[string]interface{}{
"uuid": uuid,
"doc_id": docID,
"uploader": map[string]interface{}{
"type": uploaderTypeHTTP,
"http": map[string]interface{}{
"url": apiUploaderUrl,
},
"dir": fmt.Sprintf("/%s", kbId),
},
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
c.logger.Info("YuqueExportDoc", "requestURL:", requestURL, "resp", string(respBody))
var exportResp YuqueExportDocResponse
err = json.Unmarshal(respBody, &exportResp)
if err != nil {
return nil, err
}
if !exportResp.Success {
return nil, fmt.Errorf("yuque export doc API failed - UUID: %s, DocID: %s, Error: %s", uuid, docID, exportResp.Msg)
}
return &exportResp, nil
}

View File

@@ -0,0 +1,9 @@
package bot
import (
"context"
"github.com/chaitin/panda-wiki/domain"
)
type GetQAFun func(ctx context.Context, msg string, info domain.ConversationInfo, ConversationID string) (chan string, error)

View File

@@ -0,0 +1,502 @@
package dingtalk
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0"
dingtalkoauth2_1_0 "github.com/alibabacloud-go/dingtalk/oauth2_1_0"
util "github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/google/uuid"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
)
type DingTalkClient struct {
ctx context.Context
cancel context.CancelFunc
clientID string
clientSecret string
templateID string // 4d18414c-aabc-4ec8-9e67-4ceefeada72a.schema
oauthClient *dingtalkoauth2_1_0.Client
cardClient *dingtalkcard_1_0.Client
getQA bot.GetQAFun
logger *log.Logger
tokenCache struct {
accessToken string
expireAt time.Time
}
tokenMutex sync.RWMutex
messageMu sync.Mutex
messageSeenAt map[string]messageMark
messageTTL time.Duration
nowFunc func() time.Time
processMessageFn func(ctx context.Context, data *chatbot.BotCallbackDataModel) error
}
type messageMark struct {
seenAt time.Time
inFlight bool
}
func NewDingTalkClient(ctx context.Context, cancel context.CancelFunc, clientId, clientSecret, templateID string, logger *log.Logger, getQA bot.GetQAFun) (*DingTalkClient, error) {
config := &openapi.Config{}
config.Protocol = tea.String("https")
config.RegionId = tea.String("central")
oauthClient, err := dingtalkoauth2_1_0.NewClient(config)
if err != nil {
return nil, fmt.Errorf("failed to create oauth client: %w", err)
}
cardClient, err := dingtalkcard_1_0.NewClient(config)
if err != nil {
return nil, fmt.Errorf("failed to create card client: %w", err)
}
client := &DingTalkClient{
ctx: ctx,
cancel: cancel,
clientID: clientId,
clientSecret: clientSecret,
templateID: templateID,
oauthClient: oauthClient,
cardClient: cardClient,
getQA: getQA,
logger: logger,
messageSeenAt: make(map[string]messageMark),
messageTTL: 5 * time.Minute,
nowFunc: time.Now,
}
client.startMessageCleanup()
return client, nil
}
func (c *DingTalkClient) GetAccessToken() (string, error) {
c.tokenMutex.RLock()
// TODO: use redis cache
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
token := c.tokenCache.accessToken
c.tokenMutex.RUnlock()
return token, nil
}
c.tokenMutex.RUnlock()
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
return c.tokenCache.accessToken, nil
}
request := &dingtalkoauth2_1_0.GetAccessTokenRequest{
AppKey: tea.String(c.clientID),
AppSecret: tea.String(c.clientSecret),
}
response, tryErr := func() (_resp *dingtalkoauth2_1_0.GetAccessTokenResponse, _e error) {
defer func() {
if r := tea.Recover(recover()); r != nil {
_e = r
}
}()
_resp, _err := c.oauthClient.GetAccessToken(request)
if _err != nil {
return nil, _err
}
return _resp, nil
}()
if tryErr != nil {
return "", tryErr
}
accessToken := *response.Body.AccessToken
c.logger.Info("get access token", log.String("access_token", accessToken), log.Int("expire_in", int(*response.Body.ExpireIn)))
c.tokenCache.accessToken = accessToken
c.tokenCache.expireAt = time.Now().Add(time.Duration(*response.Body.ExpireIn-300) * time.Second)
return c.tokenCache.accessToken, nil
}
func (c *DingTalkClient) UpdateAIStreamCard(trackID, content string, isFinalize bool) error {
accessToken, err := c.GetAccessToken()
if err != nil {
return fmt.Errorf("failed to get access token while updating interactive card: %w", err)
}
headers := &dingtalkcard_1_0.StreamingUpdateHeaders{
XAcsDingtalkAccessToken: tea.String(accessToken),
}
request := &dingtalkcard_1_0.StreamingUpdateRequest{
OutTrackId: tea.String(trackID),
Guid: tea.String(uuid.New().String()),
Key: tea.String("content"),
Content: tea.String(content),
IsFull: tea.Bool(true),
IsFinalize: tea.Bool(isFinalize),
IsError: tea.Bool(false),
}
_, err = c.cardClient.StreamingUpdateWithOptions(request, headers, &util.RuntimeOptions{})
if err != nil {
return fmt.Errorf("failed to update card: %w", err)
}
return nil
}
func (c *DingTalkClient) CreateAndDeliverCard(ctx context.Context, trackID string, data *chatbot.BotCallbackDataModel) error {
accessToken, err := c.GetAccessToken()
if err != nil {
return fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
}
createAndDeliverHeaders := &dingtalkcard_1_0.CreateAndDeliverHeaders{}
createAndDeliverHeaders.XAcsDingtalkAccessToken = tea.String(accessToken)
cardDataCardParamMap := map[string]*string{
"content": tea.String(""),
}
cardData := &dingtalkcard_1_0.CreateAndDeliverRequestCardData{
CardParamMap: cardDataCardParamMap,
}
createAndDeliverRequest := &dingtalkcard_1_0.CreateAndDeliverRequest{
CardTemplateId: tea.String(c.templateID),
OutTrackId: tea.String(trackID),
CardData: cardData,
CallbackType: tea.String("STREAM"),
ImGroupOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenSpaceModel{
SupportForward: tea.Bool(true),
},
ImRobotOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenSpaceModel{
SupportForward: tea.Bool(true),
},
UserIdType: tea.Int32(1),
}
switch data.ConversationType {
case "2": // 群聊
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_GROUP", data.ConversationId)
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
createAndDeliverRequest.SetImGroupOpenDeliverModel(
&dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenDeliverModel{
RobotCode: tea.String(c.clientID),
})
case "1": // Im机器人单聊
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_ROBOT", data.SenderStaffId)
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
createAndDeliverRequest.SetImRobotOpenDeliverModel(&dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenDeliverModel{
SpaceType: tea.String("IM_GROUP"),
})
default:
return fmt.Errorf("invalid conversation type: %s", data.ConversationType)
}
_, err = c.cardClient.CreateAndDeliverWithOptions(createAndDeliverRequest, createAndDeliverHeaders, &util.RuntimeOptions{})
if err != nil {
return fmt.Errorf("failed to create and deliver card: %w", err)
}
return nil
}
func (c *DingTalkClient) startMessageCleanup() {
go func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.cleanupExpiredMessages()
}
}
}()
}
func (c *DingTalkClient) cleanupExpiredMessages() {
now := c.nowFunc()
c.messageMu.Lock()
defer c.messageMu.Unlock()
for msgID, mark := range c.messageSeenAt {
if mark.inFlight {
continue
}
if now.Sub(mark.seenAt) > c.messageTTL {
delete(c.messageSeenAt, msgID)
}
}
}
func (c *DingTalkClient) tryMarkMessage(msgID string) bool {
if strings.TrimSpace(msgID) == "" {
return true
}
now := c.nowFunc()
c.messageMu.Lock()
defer c.messageMu.Unlock()
if mark, ok := c.messageSeenAt[msgID]; ok {
if mark.inFlight || now.Sub(mark.seenAt) <= c.messageTTL {
return false
}
}
c.messageSeenAt[msgID] = messageMark{
seenAt: now,
inFlight: true,
}
return true
}
func (c *DingTalkClient) markMessageCompleted(msgID string) {
if strings.TrimSpace(msgID) == "" {
return
}
c.messageMu.Lock()
defer c.messageMu.Unlock()
c.messageSeenAt[msgID] = messageMark{
seenAt: c.nowFunc(),
inFlight: false,
}
}
func (c *DingTalkClient) clearMessageMark(msgID string) {
if strings.TrimSpace(msgID) == "" {
return
}
c.messageMu.Lock()
defer c.messageMu.Unlock()
delete(c.messageSeenAt, msgID)
}
func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
select {
case <-c.ctx.Done():
c.logger.Info("dingtalk bot is disabled, ignoring message", log.String("client_id", c.clientID))
return nil, nil
default:
}
if !c.tryMarkMessage(data.MsgId) {
c.logger.Info("ignore duplicate dingtalk message", log.String("msg_id", data.MsgId))
return []byte(""), nil
}
processor := c.processMessageFn
if processor == nil {
processor = c.processMessage
}
payload := *data
go c.processMessageAsync(c.ctx, &payload, processor)
return []byte(""), nil
}
func (c *DingTalkClient) processMessageAsync(ctx context.Context, data *chatbot.BotCallbackDataModel, processor func(context.Context, *chatbot.BotCallbackDataModel) error) {
defer func() {
if r := recover(); r != nil {
c.clearMessageMark(data.MsgId)
c.logger.Error("process dingtalk message panicked", log.String("msg_id", data.MsgId), log.Any("panic", r))
}
}()
if err := processor(ctx, data); err != nil {
c.clearMessageMark(data.MsgId)
c.logger.Error("process dingtalk message failed", log.String("msg_id", data.MsgId), log.Error(err))
return
}
c.markMessageCompleted(data.MsgId)
}
func (c *DingTalkClient) processMessage(ctx context.Context, data *chatbot.BotCallbackDataModel) error {
question := data.Text.Content
question = strings.TrimSpace(question)
trackID := uuid.New().String()
// conversation_type == 1 表示机器人单聊,==2 表示群聊中@机器人
c.logger.Info("dingtalk client received message", log.String("question", question), log.String("track_id", trackID), log.String("conversation_type", data.ConversationType))
// create and deliver card
if err := c.CreateAndDeliverCard(ctx, trackID, data); err != nil {
c.logger.Error("CreateAndDeliverCard", log.Error(err))
return err
}
initialContent := fmt.Sprintf("**%s**\n\n%s", question, "稍等,让我想一想……")
if err := c.UpdateAIStreamCard(trackID, initialContent, false); err != nil {
c.logger.Error("UpdateInteractiveCard", log.Error(err))
return err
}
// 初始化 默认为空
convInfo := &domain.ConversationInfo{
UserInfo: domain.UserInfo{
From: domain.MessageFromPrivate, // 默认是私聊
},
}
// 之前创建并且发送卡片消息,获取用户基本信息
userinfo, err := c.GetUserInfo(data.SenderStaffId)
if err != nil {
c.logger.Error("GetUserInfo failed", log.Error(err))
} else {
c.logger.Info("GetUserInfo success", log.Any("userinfo", userinfo))
convInfo.UserInfo.UserID = userinfo.Result.Userid
convInfo.UserInfo.NickName = userinfo.Result.Name
convInfo.UserInfo.Avatar = userinfo.Result.Avatar
convInfo.UserInfo.Email = userinfo.Result.Email
}
if data.ConversationType == "2" { // 群聊
convInfo.UserInfo.From = domain.MessageFromGroup
} else { // 单聊
convInfo.UserInfo.From = domain.MessageFromPrivate
}
contentCh, err := c.getQA(ctx, question, *convInfo, "")
if err != nil {
c.logger.Error("dingtalk client failed to get answer", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
return fmt.Errorf("get answer failed: %w; update error card failed: %w", err, updateErr)
}
return nil
}
updateTicker := time.NewTicker(1500 * time.Millisecond)
defer updateTicker.Stop()
ans := fmt.Sprintf("**%s**\n\n", question)
fullContent := fmt.Sprintf("**%s**\n\n", question)
for {
select {
case content, ok := <-contentCh:
if !ok {
if err := c.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
c.logger.Error("UpdateInteractiveCard in contentCh", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
return fmt.Errorf("final update card failed: %w; fallback update failed: %w", err, updateErr)
}
}
return nil
}
fullContent += content
case <-updateTicker.C:
if fullContent == ans {
continue
}
if err := c.UpdateAIStreamCard(trackID, fullContent, false); err != nil {
c.logger.Error("UpdateInteractiveCard in ticker", log.Error(err))
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
c.logger.Error("UpdateInteractiveCard in ticker failed", log.Error(updateErr))
return fmt.Errorf("stream update card failed: %w; fallback update failed: %w", err, updateErr)
}
return nil
}
}
}
}
func (c *DingTalkClient) Start() error {
cli := client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(
c.clientID,
c.clientSecret,
)))
cli.RegisterChatBotCallbackRouter(c.OnChatBotMessageReceived)
if err := cli.Start(c.ctx); err != nil {
return err
}
<-c.ctx.Done()
return nil
}
func (c *DingTalkClient) Stop() {
c.cancel()
}
// 钉钉的用户信息
type UserDetailResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
Result UserDetails `json:"result"`
}
type UserDetails struct {
Unionid string `json:"unionid"`
Userid string `json:"userid"`
Name string `json:"name"`
Avatar string `json:"avatar"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Title string `json:"title"`
Active bool `json:"active"`
Admin bool `json:"admin"`
Boss bool `json:"boss"`
DeptIDList []int64 `json:"dept_id_list"`
JobNumber string `json:"job_number"`
HiredDate int64 `json:"hired_date"`
ManagerUserid string `json:"manager_userid"`
}
// 使用原始的http请求来获取用户的信息 - > 需要设置获取用户的权限功能:企业员工手机号信息和邮箱等个人信息、成员信息读权限
func (c *DingTalkClient) GetUserInfo(userID string) (*UserDetailResponse, error) {
accessToken, err := c.GetAccessToken()
if err != nil {
return nil, fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
}
// 1. 构建URL和请求体
url := "https://oapi.dingtalk.com/topapi/v2/user/get"
payload := map[string]string{"userid": userID, "language": "zh_CN"} // 默认是中文
jsonPayload, _ := json.Marshal(payload)
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
req.Header.Set("Content-Type", "application/json")
query := req.URL.Query()
query.Add("access_token", accessToken)
req.URL.RawQuery = query.Encode()
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
c.logger.Error("Failed to get user info from dingtalk: %v", log.Error(err))
return nil, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
// 获取到用户信息
c.logger.Info("Get user info from dingtalk success", log.Any("resp 原始的消息:", resp))
var result UserDetailResponse
if err := json.Unmarshal(body, &result); err != nil {
c.logger.Error("Failed to unmarshal user info response: %v", log.Error(err))
return nil, err
}
if result.ErrCode != 0 {
c.logger.Error("Failed to get result info", log.Any("ErrCode", result.ErrCode), log.String("ErrMsg", result.ErrMsg))
return nil, fmt.Errorf("result.ErrCode:%d", result.ErrCode)
}
// success
c.logger.Info("Get user info from dingtalk success", log.Any("userinfo:", result))
return &result, nil
}

View File

@@ -0,0 +1,263 @@
package dingtalk
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
pwlog "github.com/chaitin/panda-wiki/log"
)
func newTestLogger() *pwlog.Logger {
return &pwlog.Logger{
Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
}
}
func newTestDingTalkClient(t *testing.T) *DingTalkClient {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
client, err := NewDingTalkClient(
ctx,
cancel,
"client-id",
"client-secret",
"template-id",
newTestLogger(),
nil,
)
require.NoError(t, err)
client.messageTTL = time.Minute
return client
}
func TestTryMarkMessageDeduplicatesWithinTTL(t *testing.T) {
client := newTestDingTalkClient(t)
now := time.Now()
client.nowFunc = func() time.Time {
return now
}
require.True(t, client.tryMarkMessage("msg-1"))
require.False(t, client.tryMarkMessage("msg-1"))
client.markMessageCompleted("msg-1")
require.False(t, client.tryMarkMessage("msg-1"))
now = now.Add(client.messageTTL + time.Second)
require.True(t, client.tryMarkMessage("msg-1"))
}
func TestOnChatBotMessageReceivedIgnoresDuplicateMsgID(t *testing.T) {
client := newTestDingTalkClient(t)
processed := make(chan struct{}, 2)
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
processed <- struct{}{}
return nil
}
data := &chatbot.BotCallbackDataModel{
MsgId: "msg-1",
Text: chatbot.BotCallbackDataTextModel{
Content: "hello",
},
}
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
select {
case <-processed:
case <-time.After(time.Second):
t.Fatal("expected first message to be processed")
}
select {
case <-processed:
t.Fatal("expected duplicate message to be ignored")
case <-time.After(300 * time.Millisecond):
}
}
func TestOnChatBotMessageReceivedReturnsBeforeProcessingCompletes(t *testing.T) {
client := newTestDingTalkClient(t)
started := make(chan struct{})
unblock := make(chan struct{})
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
close(started)
<-unblock
return nil
}
done := make(chan struct{})
go func() {
_, _ = client.OnChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
MsgId: "msg-2",
Text: chatbot.BotCallbackDataTextModel{
Content: "slow question",
},
})
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected callback to return before background processing finishes")
}
select {
case <-started:
case <-time.After(time.Second):
t.Fatal("expected background processing to start")
}
close(unblock)
}
func TestOnChatBotMessageReceivedAllowsRetryAfterProcessingError(t *testing.T) {
client := newTestDingTalkClient(t)
attempts := make(chan struct{}, 4)
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
attempts <- struct{}{}
return assert.AnError
}
data := &chatbot.BotCallbackDataModel{
MsgId: "msg-retry",
Text: chatbot.BotCallbackDataTextModel{
Content: "retry please",
},
}
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
select {
case <-attempts:
case <-time.After(time.Second):
t.Fatal("expected first message to be processed")
}
require.Eventually(t, func() bool {
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, callErr)
select {
case <-attempts:
return true
default:
return false
}
}, time.Second, 20*time.Millisecond)
}
func TestOnChatBotMessageReceivedRecoversBackgroundPanic(t *testing.T) {
client := newTestDingTalkClient(t)
attempts := make(chan struct{}, 4)
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
attempts <- struct{}{}
panic("boom")
}
data := &chatbot.BotCallbackDataModel{
MsgId: "msg-panic",
Text: chatbot.BotCallbackDataTextModel{
Content: "panic please",
},
}
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
select {
case <-attempts:
case <-time.After(time.Second):
t.Fatal("expected background processing to start")
}
require.Eventually(t, func() bool {
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, callErr)
select {
case <-attempts:
return true
default:
return false
}
}, time.Second, 20*time.Millisecond)
}
func TestOnChatBotMessageReceivedKeepsInFlightMessageMarkedPastTTL(t *testing.T) {
client := newTestDingTalkClient(t)
now := time.Now()
client.nowFunc = func() time.Time {
return now
}
processed := make(chan struct{}, 2)
unblock := make(chan struct{})
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
processed <- struct{}{}
<-unblock
return nil
}
data := &chatbot.BotCallbackDataModel{
MsgId: "msg-inflight",
Text: chatbot.BotCallbackDataTextModel{
Content: "long running question",
},
}
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
select {
case <-processed:
case <-time.After(time.Second):
t.Fatal("expected first message to be processed")
}
now = now.Add(client.messageTTL + time.Second)
client.cleanupExpiredMessages()
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
require.NoError(t, err)
assert.Equal(t, []byte(""), resp)
select {
case <-processed:
t.Fatal("expected in-flight duplicate message to be ignored after ttl cleanup")
case <-time.After(300 * time.Millisecond):
}
close(unblock)
}

View File

@@ -0,0 +1,30 @@
package discord
import (
"context"
"testing"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
)
func TestDiscord(t *testing.T) {
cfg, _ := config.NewConfig()
log := log.NewLogger(cfg)
token := "token"
getQA := func(ctx context.Context, msg string, info domain.ConversationInfo, ConversationID string) (chan string, error) {
contentCh := make(chan string, 10)
go func() {
defer close(contentCh)
contentCh <- "hello " + msg
}()
return contentCh, nil
}
c, _ := NewDiscordClient(log, token, getQA)
if err := c.Start(); err != nil {
t.Errorf("Failed to start Discord client: %v", err)
}
select {}
}

View File

@@ -0,0 +1,98 @@
package discord
import (
"context"
"fmt"
"strings"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
"github.com/bwmarrin/discordgo"
)
type DiscordClient struct {
logger *log.Logger
BotToken string
dg *discordgo.Session
getQA bot.GetQAFun
}
func NewDiscordClient(logger *log.Logger, BotToken string, getQA bot.GetQAFun) (*DiscordClient, error) {
dg, err := discordgo.New("Bot " + BotToken)
if err != nil {
return nil, fmt.Errorf("failed to create Discord session: %v", err)
}
return &DiscordClient{
logger: logger.WithModule("bot.discord"),
BotToken: BotToken,
dg: dg,
getQA: getQA,
}, nil
}
func (d *DiscordClient) Start() error {
err := d.dg.Open()
if err != nil {
return fmt.Errorf("failed to open Discord connection: %v", err)
}
d.dg.AddHandler(d.handleMessage)
return nil
}
func (d *DiscordClient) Stop() error {
return d.dg.Close()
}
func (d *DiscordClient) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
if m.Author.ID == s.State.User.ID {
return
}
// 判断群聊单聊
d.logger.Debug("接收到消息", log.String("消息内容", m.Content))
d.logger.Debug("接收到消息", log.String("ChannelID", m.ChannelID))
d.logger.Debug("接收到消息", log.String("GuildID", m.GuildID))
// 只接收@ bot 的消息
preFix := fmt.Sprintf("<@%s>", s.State.User.ID)
if !strings.HasPrefix(m.Content, preFix) {
return
}
content := strings.TrimPrefix(m.Content, preFix)
info := domain.ConversationInfo{
UserInfo: domain.UserInfo{
NickName: m.Author.Username,
Email: m.Author.Email,
UserID: m.Author.ID,
},
}
if m.GuildID != "" {
info.UserInfo.From = domain.MessageFromGroup
} else {
info.UserInfo.From = domain.MessageFromPrivate
}
d.logger.Debug("消息来自", log.String("用户名", m.Author.Username), log.String("ID", m.Author.ID), log.String("内容", content))
d.logger.Debug("消息来自频道", log.String("名称", m.ChannelID))
qaChan, err := d.getQA(context.Background(), content, info, "")
if err != nil {
d.logger.Error("failed to get QA", log.String("error", err.Error()))
return
}
message, err := s.ChannelMessageSend(m.ChannelID, "正在获取答案...")
if err != nil {
d.logger.Error("failed to send message to discord", log.String("error", err.Error()))
return
}
go func() {
buf := strings.Builder{}
for qa := range qaChan {
buf.WriteString(qa)
}
_, err := s.ChannelMessageEdit(message.ChannelID, message.ID, buf.String())
if err != nil {
d.logger.Error("failed to edit message to discord", log.String("error", err.Error()))
}
}()
}

View File

@@ -0,0 +1,299 @@
package feishu
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/google/uuid"
lark "github.com/larksuite/oapi-sdk-go/v3"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkcardkit "github.com/larksuite/oapi-sdk-go/v3/service/cardkit/v1"
larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
)
type FeishuBotLogger struct {
logger *log.Logger
}
func (l *FeishuBotLogger) Info(ctx context.Context, args ...interface{}) {
l.logger.Info("feishu bot", log.Any("args", args))
}
func (l *FeishuBotLogger) Error(ctx context.Context, args ...interface{}) {
l.logger.Error("feishu bot", log.Any("args", args))
}
func (l *FeishuBotLogger) Debug(ctx context.Context, args ...interface{}) {
l.logger.Debug("feishu bot", log.Any("args", args))
}
func (l *FeishuBotLogger) Warn(ctx context.Context, args ...interface{}) {
l.logger.Warn("feishu bot", log.Any("args", args))
}
type FeishuClient struct {
ctx context.Context
cancel context.CancelFunc
clientID string
clientSecret string
logger *log.Logger
client *lark.Client
msgMap sync.Map
getQA bot.GetQAFun
}
func NewFeishuClient(ctx context.Context, cancel context.CancelFunc, clientID, clientSecret string, logger *log.Logger, getQA bot.GetQAFun) *FeishuClient {
client := lark.NewClient(clientID, clientSecret, lark.WithLogger(&FeishuBotLogger{logger: logger}))
c := &FeishuClient{
ctx: ctx,
cancel: cancel,
clientID: clientID,
clientSecret: clientSecret,
client: client,
logger: logger,
getQA: getQA,
}
go func() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.msgMap.Range(func(key, value any) bool {
// remove messageId if it is older than 5 minutes
if time.Now().Unix()-value.(int64) > 5*60 {
c.msgMap.Delete(key)
}
return true
})
}
}
}()
return c
}
var cardDataTemplate = `{"schema":"2.0","header":{"title":{"content":"%s","tag":"plain_text"}},"config":{"streaming_mode":true,"summary":{"content":""}},"body":{"elements":[{"tag":"markdown","content":"%s","element_id":"markdown_1"}]}}`
func (c *FeishuClient) sendQACard(ctx context.Context, receiveIdType string, receiveId string, question string, additionalInfo string) {
// create card
cardData := fmt.Sprintf(cardDataTemplate, question, "稍等,让我想一想...")
req := larkcardkit.NewCreateCardReqBuilder().
Body(larkcardkit.NewCreateCardReqBodyBuilder().
Type(`card_json`).
Data(cardData).
Build()).
Build()
resp, err := c.client.Cardkit.V1.Card.Create(ctx, req)
if err != nil {
c.logger.Error("failed to create card", log.Error(err))
return
}
if !resp.Success() {
c.logger.Error("failed to create card", log.String("request_id", resp.RequestId()), log.Any("code_error", resp.CodeError))
return
}
content, err := json.Marshal(map[string]any{
"type": "card",
"data": map[string]string{
"card_id": *resp.Data.CardId,
},
})
if err != nil {
c.logger.Error("failed to marshal alarm card", log.Error(err))
return
}
// send card to user or group
res, err := c.client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder().
ReceiveIdType(receiveIdType).
Body(larkim.NewCreateMessageReqBodyBuilder().
MsgType("interactive").
ReceiveId(receiveId).
Content(string(content)).
Build()).
Build())
if err != nil {
c.logger.Error("failed to create message", log.Error(err))
return
}
if !res.Success() {
c.logger.Error("failed to create message", log.Int("code", res.Code), log.String("msg", res.Msg), log.String("request_id", res.RequestId()))
return
}
// 打印日志
c.logger.Info("send QA card to user or group", log.String("receive_id_type", receiveIdType), log.String("receive_id", receiveId), log.String("question", question), log.String("additional_info(chat:user_openid/p2p:chat_id)", additionalInfo))
// start processing QA
convInfo := domain.ConversationInfo{
UserInfo: domain.UserInfo{
From: domain.MessageFromPrivate, // 默认是私聊
},
}
if receiveIdType == "open_id" {
// 获取用户的信息只需要获取p2p的对话的类型的用户信息 - p2p对话
userinfo, err := c.GetUserInfo(receiveId)
if err != nil {
c.logger.Error("get user info failed", log.Error(err))
} else {
if userinfo.UserId != nil {
convInfo.UserInfo.UserID = *userinfo.UserId
}
if userinfo.Name != nil {
convInfo.UserInfo.NickName = *userinfo.Name
}
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
}
c.logger.Info("get user info success", log.Any("user_info", userinfo))
}
convInfo.UserInfo.From = domain.MessageFromPrivate // 私聊
} else { // chat_id 中的userid
// 获取群聊的消息,用户如果是在群聊中@机器人,那么就获取的是群聊的消息
userinfo, err := c.GetUserInfo(additionalInfo)
if err != nil {
c.logger.Error("get chat info failed", log.Error(err))
} else {
if userinfo.UserId != nil {
convInfo.UserInfo.UserID = *userinfo.UserId
}
if userinfo.Name != nil {
convInfo.UserInfo.NickName = *userinfo.Name
}
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
}
c.logger.Info("get chat user info success", log.Any("user_info", userinfo))
}
convInfo.UserInfo.From = domain.MessageFromGroup // 群聊
}
answerCh, err := c.getQA(ctx, question, convInfo, "")
if err != nil {
c.logger.Error("get QA failed", log.Error(err))
return
}
answer := ""
seq := 1
for chunk := range answerCh {
seq += 1
answer += chunk
// 部分模型存在输出为空的情况导致飞书报错
if strings.TrimSpace(chunk) == "" {
continue
}
// update card content streaming
updateReq := larkcardkit.NewContentCardElementReqBuilder().
CardId(*resp.Data.CardId).
ElementId(`markdown_1`).
Body(larkcardkit.NewContentCardElementReqBodyBuilder().
Uuid(uuid.New().String()).
Content(answer).
Sequence(seq).
Build()).
Build()
updateResp, err := c.client.Cardkit.V1.CardElement.Content(ctx, updateReq)
if err != nil {
c.logger.Error("failed to update card", log.Error(err))
return
}
if !updateResp.Success() {
c.logger.Error("failed to update card", log.String("request_id", updateResp.RequestId()), log.Any("code_error", updateResp.CodeError))
return
}
}
c.logger.Info("start processing QA", log.String("message_id", *res.Data.MessageId))
}
type Message struct {
Text string `json:"text"`
}
func (c *FeishuClient) Start() error {
eventHandler := dispatcher.NewEventDispatcher("", "").
OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
// ignore duplicate message
if *event.Event.Message.MessageId == "" {
return nil
}
messageId := *event.Event.Message.MessageId
if _, ok := c.msgMap.Load(messageId); ok {
return nil
}
c.msgMap.Store(messageId, time.Now().Unix())
c.logger.Info("received message from feishu bot", log.String("message_id", messageId))
// only handle text type
if *event.Event.Message.MessageType != "text" {
return nil
}
switch *event.Event.Message.ChatType {
case "group":
var message Message
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
c.logger.Error("failed to unmarshal message", log.Error(err))
return nil
}
c.sendQACard(ctx, "chat_id", *event.Event.Message.ChatId, message.Text, *event.Event.Sender.SenderId.OpenId)
case "p2p":
var message Message
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
c.logger.Error("failed to unmarshal message", log.Error(err))
return nil
}
c.sendQACard(ctx, "open_id", *event.Event.Sender.SenderId.OpenId, message.Text, *event.Event.Message.ChatId)
default:
c.logger.Warn("unsupported chat type", log.String("chat_type", *event.Event.Message.ChatType))
}
return nil
})
cli := larkws.NewClient(c.clientID, c.clientSecret,
larkws.WithEventHandler(eventHandler),
larkws.WithLogger(&FeishuBotLogger{logger: c.logger}),
)
// FIXME: goroutine leak in larkws.Start
err := cli.Start(c.ctx)
if err != nil {
return fmt.Errorf("failed to start feishu client: %w", err)
}
return nil
}
// 下面功能都是需要开启飞书对应的权限才可以获取到用户信息 -- 应用权限(否则获取不到对话用户的信息)
// 飞书机器人获取用户信息,只是适用于单个用户
func (c *FeishuClient) GetUserInfo(UserOpenId string) (*larkcontact.User, error) {
// 获取用户信息根据用户的id
req := larkcontact.NewGetUserReqBuilder().UserId(UserOpenId).
UserIdType(`open_id`).DepartmentIdType(`open_department_id`).Build()
// 发起请求,获取用户消息
resp, err := c.client.Contact.User.Get(context.Background(), req)
if err != nil {
c.logger.Error("failed to get user info", log.Error(err))
return nil, err
}
// 失败
if !resp.Success() {
c.logger.Error("failed to get user info, response status not success", log.Any("errcode:", resp.Code))
return nil, fmt.Errorf("failed to get user info, response data not success")
}
return resp.Data.User, nil
}
func (c *FeishuClient) Stop() {
c.cancel()
}

View File

@@ -0,0 +1,346 @@
package lark
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/google/uuid"
lark "github.com/larksuite/oapi-sdk-go/v3"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkcardkit "github.com/larksuite/oapi-sdk-go/v3/service/cardkit/v1"
larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
)
// LarkBotLogger implements Lark SDK logger interface
type LarkBotLogger struct {
logger *log.Logger
}
func (l *LarkBotLogger) Info(ctx context.Context, args ...interface{}) {
l.logger.Info("lark bot", log.Any("args", args))
}
func (l *LarkBotLogger) Error(ctx context.Context, args ...interface{}) {
l.logger.Error("lark bot", log.Any("args", args))
}
func (l *LarkBotLogger) Debug(ctx context.Context, args ...interface{}) {
l.logger.Debug("lark bot", log.Any("args", args))
}
func (l *LarkBotLogger) Warn(ctx context.Context, args ...interface{}) {
l.logger.Warn("lark bot", log.Any("args", args))
}
// LarkClient is a Lark bot client using larksuite SDK (configured for Lark international endpoints)
// Note: Lark uses HTTP callbacks instead of WebSocket for event handling
type LarkClient struct {
ctx context.Context
cancel context.CancelFunc
clientID string
clientSecret string
logger *log.Logger
client *lark.Client
msgMap sync.Map
getQA bot.GetQAFun
eventHandler *dispatcher.EventDispatcher
verifyToken string
encryptKey string
}
// NewLarkClient creates a new Lark bot client
// Lark is the international version of Feishu, using different API endpoints
// Unlike Feishu (China), Lark (International) uses HTTP callbacks instead of WebSocket
func NewLarkClient(ctx context.Context, cancel context.CancelFunc, clientID, clientSecret, verifyToken, encryptKey string, logger *log.Logger, getQA bot.GetQAFun) (*LarkClient, error) {
// Create client with Lark (international) domain
client := lark.NewClient(clientID, clientSecret,
lark.WithLogger(&LarkBotLogger{logger: logger}),
lark.WithOpenBaseUrl("https://open.larksuite.com"), // Lark international endpoint
)
c := &LarkClient{
ctx: ctx,
cancel: cancel,
clientID: clientID,
clientSecret: clientSecret,
client: client,
logger: logger,
getQA: getQA,
verifyToken: verifyToken,
encryptKey: encryptKey,
}
// Setup event handler for HTTP callbacks
c.setupEventHandler()
go func() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.msgMap.Range(func(key, value any) bool {
// remove messageId if it is older than 5 minutes
if time.Now().Unix()-value.(int64) > 5*60 {
c.msgMap.Delete(key)
}
return true
})
}
}
}()
return c, nil
}
// setupEventHandler configures the event dispatcher for handling HTTP callbacks
func (c *LarkClient) setupEventHandler() {
c.eventHandler = dispatcher.NewEventDispatcher(c.verifyToken, c.encryptKey).
OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
if *event.Event.Message.MessageId == "" {
return nil
}
messageId := *event.Event.Message.MessageId
if _, ok := c.msgMap.Load(messageId); ok {
return nil
}
c.msgMap.Store(messageId, time.Now().Unix())
c.logger.Info("received message from lark bot", log.String("message_id", messageId))
if *event.Event.Message.MessageType != "text" {
return nil
}
switch *event.Event.Message.ChatType {
case "group":
var message Message
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
c.logger.Error("failed to unmarshal message", log.Error(err))
return nil
}
// Replace mention placeholders with actual user names
questionText := c.replaceMentions(message.Text, event.Event.Message.Mentions)
go c.sendQACard(c.ctx, "chat_id", *event.Event.Message.ChatId, questionText, *event.Event.Sender.SenderId.OpenId)
case "p2p":
var message Message
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
c.logger.Error("failed to unmarshal message", log.Error(err))
return nil
}
go c.sendQACard(c.ctx, "open_id", *event.Event.Sender.SenderId.OpenId, message.Text, *event.Event.Message.ChatId)
default:
c.logger.Warn("unsupported chat type", log.String("chat_type", *event.Event.Message.ChatType))
}
return nil
})
}
// GetEventHandler returns the event dispatcher for HTTP callback handling
// This should be registered with the HTTP server to handle Lark callbacks
func (c *LarkClient) GetEventHandler() *dispatcher.EventDispatcher {
return c.eventHandler
}
var cardDataTemplate = `{"schema":"2.0","header":{"title":{"content":"%s","tag":"plain_text"}},"config":{"streaming_mode":true,"summary":{"content":""}},"body":{"elements":[{"tag":"markdown","content":"%s","element_id":"markdown_1"}]}}`
func (c *LarkClient) sendQACard(ctx context.Context, receiveIdType string, receiveId string, question string, additionalInfo string) {
// create card
cardData := fmt.Sprintf(cardDataTemplate, question, "稍等,让我想一想...")
req := larkcardkit.NewCreateCardReqBuilder().
Body(larkcardkit.NewCreateCardReqBodyBuilder().
Type(`card_json`).
Data(cardData).
Build()).
Build()
resp, err := c.client.Cardkit.V1.Card.Create(ctx, req)
if err != nil {
c.logger.Error("failed to create card", log.Error(err))
return
}
if !resp.Success() {
c.logger.Error("failed to create card", log.String("request_id", resp.RequestId()), log.Any("code_error", resp.CodeError))
return
}
content, err := json.Marshal(map[string]any{
"type": "card",
"data": map[string]string{
"card_id": *resp.Data.CardId,
},
})
if err != nil {
c.logger.Error("failed to marshal alarm card", log.Error(err))
return
}
// send card to user or group
res, err := c.client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder().
ReceiveIdType(receiveIdType).
Body(larkim.NewCreateMessageReqBodyBuilder().
MsgType("interactive").
ReceiveId(receiveId).
Content(string(content)).
Build()).
Build())
if err != nil {
c.logger.Error("failed to create message", log.Error(err))
return
}
if !res.Success() {
c.logger.Error("failed to create message", log.Int("code", res.Code), log.String("msg", res.Msg), log.String("request_id", res.RequestId()))
return
}
c.logger.Info("send QA card to user or group", log.String("receive_id_type", receiveIdType), log.String("receive_id", receiveId), log.String("question", question), log.String("additional_info", additionalInfo))
// start processing QA
convInfo := domain.ConversationInfo{
UserInfo: domain.UserInfo{
From: domain.MessageFromPrivate,
},
}
if receiveIdType == "open_id" {
userinfo, err := c.GetUserInfo(receiveId)
if err != nil {
c.logger.Error("get user info failed", log.Error(err))
} else {
if userinfo.UserId != nil {
convInfo.UserInfo.UserID = *userinfo.UserId
}
if userinfo.Name != nil {
convInfo.UserInfo.NickName = *userinfo.Name
}
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
}
c.logger.Info("get user info success", log.Any("user_info", userinfo))
}
convInfo.UserInfo.From = domain.MessageFromPrivate
} else {
userinfo, err := c.GetUserInfo(additionalInfo)
if err != nil {
c.logger.Error("get chat info failed", log.Error(err))
} else {
if userinfo.UserId != nil {
convInfo.UserInfo.UserID = *userinfo.UserId
}
if userinfo.Name != nil {
convInfo.UserInfo.NickName = *userinfo.Name
}
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
}
c.logger.Info("get chat user info success", log.Any("user_info", userinfo))
}
convInfo.UserInfo.From = domain.MessageFromGroup
}
answerCh, err := c.getQA(ctx, question, convInfo, "")
if err != nil {
c.logger.Error("lark client failed to get answer", log.Error(err))
return
}
var buf strings.Builder
seq := 0
imageRegex := regexp.MustCompile(`!\[[^\]]*\]\([^)]+\)`)
sendUpdate := func() error {
seq++
answer := imageRegex.ReplaceAllString(buf.String(), "")
updateReq := larkcardkit.NewContentCardElementReqBuilder().
CardId(*resp.Data.CardId).
ElementId(`markdown_1`).
Body(larkcardkit.NewContentCardElementReqBodyBuilder().
Uuid(uuid.New().String()).
Content(answer).
Sequence(seq).
Build()).
Build()
updateResp, err := c.client.Cardkit.V1.CardElement.Content(ctx, updateReq)
if err != nil {
c.logger.Error("failed to update card", log.Error(err))
return err
}
if !updateResp.Success() {
c.logger.Error("failed to update card", log.String("request_id", updateResp.RequestId()), log.Any("code_error", updateResp.CodeError))
return fmt.Errorf("update card failed: %v", updateResp.CodeError)
}
return nil
}
for chunk := range answerCh {
buf.WriteString(chunk)
// drain all currently available chunks
for len(answerCh) > 0 {
buf.WriteString(<-answerCh)
}
if err := sendUpdate(); err != nil {
c.logger.Error("lark client failed to send QA update", log.Error(err), log.Int("sequence", seq))
return
}
}
c.logger.Info("start processing QA", log.String("message_id", *res.Data.MessageId))
}
type Message struct {
Text string `json:"text"`
}
// replaceMentions replaces mention placeholders like @_user_1 with actual user names
func (c *LarkClient) replaceMentions(text string, mentions []*larkim.MentionEvent) string {
if len(mentions) == 0 {
return text
}
result := text
for _, mention := range mentions {
if mention.Key != nil && mention.Name != nil {
// Replace @_user_1, @_user_2, etc. with @ActualUserName
result = strings.ReplaceAll(result, *mention.Key, "@"+*mention.Name)
}
}
return result
}
// Start initializes the Lark bot client
// Note: Unlike Feishu, Lark doesn't use WebSocket. Events are handled via HTTP callbacks.
// The actual HTTP endpoint needs to be registered separately in the HTTP router.
func (c *LarkClient) Start() error {
c.logger.Info("lark bot client initialized (HTTP callback mode)",
log.String("app_id", c.clientID),
log.String("note", "Register HTTP callback endpoint to receive events"))
// For Lark, we don't start a WebSocket connection
// Events will be received via HTTP callbacks handled by GetEventHandler()
// Just keep the context alive
<-c.ctx.Done()
c.logger.Info("lark bot client stopped")
return nil
}
func (c *LarkClient) GetUserInfo(UserOpenId string) (*larkcontact.User, error) {
req := larkcontact.NewGetUserReqBuilder().UserId(UserOpenId).
UserIdType(`open_id`).DepartmentIdType(`open_department_id`).Build()
resp, err := c.client.Contact.User.Get(context.Background(), req)
if err != nil {
c.logger.Error("failed to get user info", log.Error(err))
return nil, err
}
if !resp.Success() {
c.logger.Error("failed to get user info, response status not success", log.Any("errcode:", resp.Code))
return nil, fmt.Errorf("failed to get user info, response data not success")
}
return resp.Data.User, nil
}
func (c *LarkClient) Stop() {
c.cancel()
}

View File

@@ -0,0 +1,9 @@
package utils
import (
"github.com/russross/blackfriday/v2"
)
func Markdown2HTML(md string) string {
return string(blackfriday.Run([]byte(md), blackfriday.WithRenderer(blackfriday.NewHTMLRenderer(blackfriday.HTMLRendererParameters{Flags: blackfriday.UseXHTML | blackfriday.CompletePage}))))
}

View File

@@ -0,0 +1,106 @@
package wechat
import (
"context"
"encoding/xml"
"sync"
"time"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/pg"
)
type WechatConfig struct {
Ctx context.Context
logger *log.Logger
CorpID string
Token string
EncodingAESKey string
kbID string
Secret string
AccessToken string
TokenExpire time.Time
AgentID string
// db
WeRepo *pg.WechatRepository
}
type ReceivedMessage struct {
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
MsgID string `xml:"MsgId"`
}
type ResponseMessage struct {
XMLName xml.Name `xml:"xml"`
ToUserName CDATA `xml:"ToUserName"`
FromUserName CDATA `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType CDATA `xml:"MsgType"`
Content CDATA `xml:"Content"`
}
type CDATA struct {
Value string `xml:",cdata"`
}
type BackendRequest struct {
Question string `json:"question"`
UserID string `json:"user_id"`
}
type BackendResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TextResponse string `json:"test_response"`
} `json:"data"`
}
// UserInfo 用于存储获取到的用户信息
type UserInfo struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
UserID string `json:"userid"`
Name string `json:"name"`
Department []int `json:"department"`
Mobile string `json:"mobile"`
Email string `json:"email"`
Status int `json:"status"`
}
// 获取token的回应的消息
type AccessToken struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
type TokenCache struct {
AccessToken string
TokenExpire time.Time
Mutex sync.Mutex
}
// Map-based token cache keyed by kb & agentID
var tokenCacheMap = make(map[string]*TokenCache)
var tokenCacheMapMutex = sync.Mutex{}
// Generate a key for the token cache based on kb & agentID
func getTokenCacheKey(kbID, agentID string) string {
return kbID + ":" + agentID
}
// media
// Upload file response
type MediaUploadResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MediaType string `json:"type"`
MediaID string `json:"media_id"`
CreatedAt string `json:"created_at"`
}

View File

@@ -0,0 +1,393 @@
package wechat
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/google/uuid"
"github.com/sbzhu/weworkapi_golang/wxbizmsgcrypt"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
)
const wechatMessageMaxBytes = 2000
func NewWechatAppConfig(ctx context.Context, logger *log.Logger, kbId, CorpID, Token, EncodingAESKey, secret, agentID string) (*WechatConfig, error) {
return &WechatConfig{
Ctx: ctx,
logger: logger,
kbID: kbId,
CorpID: CorpID,
Token: Token,
EncodingAESKey: EncodingAESKey,
Secret: secret,
AgentID: agentID,
}, nil
}
func (cfg *WechatConfig) VerifyUrlWechatAPP(signature, timestamp, nonce, echostr string) ([]byte, error) {
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(
cfg.Token,
cfg.EncodingAESKey,
cfg.CorpID,
wxbizmsgcrypt.XmlType,
)
// 验证URL并解密echostr
decryptEchoStr, errCode := wxcpt.VerifyURL(signature, timestamp, nonce, echostr)
if errCode != nil {
return nil, errors.New("server serve fail wechat")
}
// success
return decryptEchoStr, nil
}
func (cfg *WechatConfig) Wechat(msg ReceivedMessage, getQA bot.GetQAFun, userinfo *UserInfo, useTextResponse bool, weChatAppAdvancedSetting *domain.WeChatAppAdvancedSetting) error {
token, err := cfg.GetAccessToken()
if err != nil {
return err
}
if useTextResponse {
err = cfg.ProcessTextMessage(msg, getQA, token, userinfo, weChatAppAdvancedSetting.DisclaimerContent)
if err != nil {
cfg.logger.Error("send to ai failed!", log.Error(err))
return err
}
} else {
if err := cfg.ProcessUrlMessage(msg, getQA, token, userinfo); err != nil {
cfg.logger.Error("send to ai failed!", log.Error(err))
return err
}
}
return nil
}
func (cfg *WechatConfig) ProcessUrlMessage(msg ReceivedMessage, GetQA bot.GetQAFun, token string, userinfo *UserInfo) error {
// 1. get ai channel
id, err := uuid.NewV7()
if err != nil {
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
id = uuid.New()
}
conversationID := id.String()
contentChan, err := GetQA(cfg.Ctx, msg.Content, domain.ConversationInfo{
UserInfo: domain.UserInfo{
UserID: userinfo.UserID,
NickName: userinfo.Name,
From: domain.MessageFromPrivate,
}}, conversationID)
if err != nil {
return err
}
//2. go send to ai and store in map--> get conversation-id
if _, ok := domain.ConversationManager.Load(conversationID); !ok {
state := &domain.ConversationState{
Question: msg.Content,
NotificationChan: make(chan string), // notification channel
IsVisited: false,
}
domain.ConversationManager.Store(conversationID, state)
go cfg.SendQuestionToAI(conversationID, contentChan)
}
baseUrl, err := cfg.WeRepo.GetWechatBaseURL(cfg.Ctx, cfg.kbID)
if err != nil {
return err
}
//3.send url to user
Errcode, Errmsg, err := cfg.SendURLToUser(msg.FromUserName, msg.Content, token, conversationID, baseUrl)
if err != nil {
return err
}
if Errcode != 0 {
return fmt.Errorf("wechat Api failed : %s (code: %d)", Errmsg, Errcode)
}
return nil
}
func (cfg *WechatConfig) ProcessTextMessage(msg ReceivedMessage, GetQA bot.GetQAFun, token string, userinfo *UserInfo, disclaimerContent string) error {
// 1. get ai channel
id, err := uuid.NewV7()
if err != nil {
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
id = uuid.New()
}
conversationID := id.String()
contentChan, err := GetQA(cfg.Ctx, msg.Content, domain.ConversationInfo{
UserInfo: domain.UserInfo{
UserID: userinfo.UserID,
NickName: userinfo.Name,
From: domain.MessageFromPrivate,
}}, conversationID)
if err != nil {
return err
}
var fullResponse string
for content := range contentChan {
fullResponse += content
if len([]byte(fullResponse)) > wechatMessageMaxBytes { // wechat limit 2048 byte
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
return err
}
fullResponse = ""
}
}
if len([]byte(fullResponse+disclaimerContent)) > wechatMessageMaxBytes {
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
return err
}
if _, _, err := cfg.SendResponseToUser(disclaimerContent, msg.FromUserName, token); err != nil {
return err
}
} else {
if disclaimerContent != "" {
fullResponse += fmt.Sprintf("\n%s", disclaimerContent)
}
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
return err
}
}
return nil
}
// SendResponseToUser
func (cfg *WechatConfig) SendURLToUser(touser, question, token, conversationID, baseUrl string) (int, string, error) {
msgData := map[string]interface{}{
"touser": touser,
"msgtype": "textcard",
"agentid": cfg.AgentID,
"textcard": map[string]interface{}{
"title": question,
"description": "<div class = \"highlight\">本回答由 PandaWiki 基于 AI 生成,仅供参考。</div>",
"url": fmt.Sprintf("%s/h5-chat?id=%s&source_type=%s", baseUrl, conversationID, consts.SourceTypeWechatBot),
},
}
jsonData, err := json.Marshal(msgData)
if err != nil {
return 0, "", err
}
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", token)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
}
if err := json.Unmarshal(body, &result); err != nil {
return 0, "", err
}
return result.Errcode, result.Errmsg, nil
}
func (cfg *WechatConfig) SendResponseToUser(response string, touser string, token string) (int, string, error) {
msgData := map[string]interface{}{
"touser": touser,
"msgtype": "markdown",
"agentid": cfg.AgentID,
"markdown": map[string]string{
"content": response,
},
}
jsonData, err := json.Marshal(msgData)
if err != nil {
return 0, "", err
}
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", token)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
}
if err := json.Unmarshal(body, &result); err != nil {
return 0, "", err
}
if result.Errcode != 0 {
return result.Errcode, result.Errmsg, fmt.Errorf("wechat Api failed : %s (code: %d)", result.Errmsg, result.Errcode)
}
return result.Errcode, result.Errmsg, nil
}
// SendResponse
func (cfg *WechatConfig) SendResponse(msg ReceivedMessage, content string) ([]byte, error) {
responseMsg := ResponseMessage{
ToUserName: CDATA{msg.FromUserName},
FromUserName: CDATA{msg.ToUserName},
CreateTime: msg.CreateTime,
MsgType: CDATA{"text"},
Content: CDATA{content},
}
// XML
responseXML, err := xml.Marshal(responseMsg)
if err != nil {
cfg.logger.Error("marshal response failed", log.Error(err))
return nil, err
}
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(cfg.Token, cfg.EncodingAESKey, cfg.CorpID, wxbizmsgcrypt.XmlType)
// response
var encryptMsg []byte
encryptMsg, errCode := wxcpt.EncryptMsg(string(responseXML), "", "")
if errCode != nil {
return nil, errors.New("encryotMsg err")
}
return encryptMsg, nil
}
func (cfg *WechatConfig) GetAccessToken() (string, error) {
// Generate cache key based on app credentials
cacheKey := getTokenCacheKey(cfg.kbID, cfg.AgentID)
// Get or create token cache for this app
tokenCacheMapMutex.Lock()
tokenCache, exists := tokenCacheMap[cacheKey]
if !exists {
tokenCache = &TokenCache{}
tokenCacheMap[cacheKey] = tokenCache
}
tokenCacheMapMutex.Unlock()
// Lock the specific token cache for this app
tokenCache.Mutex.Lock()
defer tokenCache.Mutex.Unlock()
if tokenCache.AccessToken != "" && time.Now().Before(tokenCache.TokenExpire) {
cfg.logger.Debug("access token has existed and is valid")
return tokenCache.AccessToken, nil
}
if cfg.Secret == "" || cfg.CorpID == "" {
return "", errors.New("secret or corpid is not right")
}
// get AccessToken--请求微信客服token
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", cfg.CorpID, cfg.Secret)
resp, err := http.Get(url)
if err != nil {
return "", errors.New("get wechatapp accesstoken failed")
}
defer resp.Body.Close()
var tokenResp AccessToken // 获取到token消息
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", errors.New("json decode wechat resp failed")
}
if tokenResp.Errcode != 0 {
return "", errors.New("get wechat access token failed")
}
// success
cfg.logger.Info("wechatapp get accesstoken success", log.Any("info", tokenResp.AccessToken))
tokenCache.AccessToken = tokenResp.AccessToken
tokenCache.TokenExpire = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second)
return tokenCache.AccessToken, nil
}
func (cfg *WechatConfig) GetUserInfo(username string) (*UserInfo, error) {
accessToken, err := cfg.GetAccessToken()
if err != nil {
return nil, err
}
// 请求获取用户的内容
resp, err := http.Get(fmt.Sprintf(
"https://qyapi.weixin.qq.com/cgi-bin/user/get?access_token=%s&userid=%s",
accessToken, username))
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// cfg.logger.Info("获取用户信息成功", log.Any("body", body))
var userInfo UserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, err
}
if userInfo.Errcode != 0 {
return nil, fmt.Errorf("获取用户信息失败: %d, %s", userInfo.Errcode, userInfo.Errmsg)
}
return &userInfo, nil
}
func (cfg *WechatConfig) UnmarshalMsg(decryptMsg []byte) (*ReceivedMessage, error) {
var msg ReceivedMessage
err := xml.Unmarshal([]byte(decryptMsg), &msg)
return &msg, err
}
// answer set into conversation state buffer
func (cfg *WechatConfig) SendQuestionToAI(conversationID string, wccontent chan string) {
// send message
val, _ := domain.ConversationManager.Load(conversationID)
state := val.(*domain.ConversationState)
for content := range wccontent {
state.Mutex.Lock()
if state.IsVisited {
state.NotificationChan <- content // notify has new data
}
state.Buffer.WriteString(content)
state.Mutex.Unlock()
}
// end sent notification
defer func() {
close(state.NotificationChan)
domain.ConversationManager.Delete(conversationID)
}()
}

View File

@@ -0,0 +1,33 @@
package wechat_official_account
import (
"context"
"github.com/silenceper/wechat/v2/officialaccount/user"
"github.com/chaitin/panda-wiki/pkg/bot"
"github.com/chaitin/panda-wiki/pkg/bot/wechat_service"
"github.com/chaitin/panda-wiki/domain"
)
func Wechat(ctx context.Context, GetQA bot.GetQAFun, userinfo *user.Info, content string) (string, error) {
wccontent, err := GetQA(ctx, content, domain.ConversationInfo{UserInfo: domain.UserInfo{
UserID: userinfo.OpenID, // 用户对话的id
NickName: userinfo.Nickname, //用户微信的昵称
Avatar: userinfo.Headimgurl, // 用户微信的头像
From: domain.MessageFromPrivate,
}}, "")
if err != nil {
return "", err
}
var response string
for v := range wccontent {
response += v
}
response = wechat_service.MarkdowntoText(response)
return response, nil
}

View File

@@ -0,0 +1,188 @@
package wechat_service
import (
"context"
"sync"
"time"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/pg"
)
type WechatServiceConfig struct {
Ctx context.Context
CorpID string
Token string
EncodingAESKey string
kbID string
Secret string
logger *log.Logger
containKeywords []string
equalKeywords []string
logoUrl string
// db
WeRepo *pg.WechatRepository
}
// 存储ai知识库获取的cursor值以客服为标准方便拉取用户的消息
var KfCursors = &sync.Map{}
// 微信客服发送的消息
type WeixinUserAskMsg struct {
ToUserName string `xml:"ToUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Event string `xml:"Event"`
Token string `xml:"Token"`
OpenKfId string `xml:"OpenKfId"`
}
type AccessToken struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
type MsgRequest struct {
Cursor string `json:"cursor"`
Token string `json:"token"`
Limit int `json:"limit"`
VoiceFormat int `json:"voice_format"`
OpenKfid string `json:"open_kfid"`
}
type MsgRet struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
NextCursor string `json:"next_cursor"` // 游标
MsgList []Msg `json:"msg_list"`
HasMore int `json:"has_more"`
}
type Msg struct {
Msgid string `json:"msgid"`
SendTime int64 `json:"send_time"`
Origin int `json:"origin"`
Msgtype string `json:"msgtype"`
Event struct {
EventType string `json:"event_type"`
Scene string `json:"scene"`
OpenKfid string `json:"open_kfid"`
ExternalUserid string `json:"external_userid"`
WelcomeCode string `json:"welcome_code"`
} `json:"event"`
Text struct {
Content string `json:"content"`
} `json:"text"`
OpenKfid string `json:"open_kfid"`
ExternalUserid string `json:"external_userid"`
}
// send msg to user with message
type ReplyMsg struct {
Touser string `json:"touser,omitempty"`
OpenKfid string `json:"open_kfid,omitempty"`
Msgid string `json:"msgid,omitempty"`
Msgtype string `json:"msgtype,omitempty"`
Text struct {
Content string `json:"content,omitempty"`
} `json:"text,omitempty"`
}
// send msg to user with url
type ReplyMsgUrl struct {
Touser string `json:"touser,omitempty"`
OpenKfid string `json:"open_kfid,omitempty"`
Msgid string `json:"msgid,omitempty"`
Msgtype string `json:"msgtype,omitempty"`
Link Link `json:"link,omitempty"`
}
type Link struct {
Title string `json:"title,omitempty"`
Desc string `json:"desc,omitempty"`
Url string `json:"url,omitempty"`
ThumbMediaID string `json:"thumb_media_id,omitempty"`
}
// Upload file response
type MediaUploadResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MediaType string `json:"type"`
MediaID string `json:"media_id"`
CreatedAt string `json:"created_at"`
}
// 获取用户消息应该得到的响应
type WechatCustomerResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
CustomerList []Customer `json:"customer_list"`
InvalidExternalUserIDs []string `json:"invalid_external_userid"`
}
type Customer struct {
ExternalUserID string `json:"external_userid"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Gender int `json:"gender"`
UnionID string `json:"unionid"`
}
type UerInfoRequest struct {
UserID []string `json:"external_userid_list"`
SessionContext int `json:"need_enter_session_context"`
}
// chat status
type Status struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
ServiceState int `json:"service_state"`
ServiceUserId string `json:"servicer_userid"`
}
type HumanList struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
ServicerList []ServicerList `json:"servicer_list"`
}
type ServicerList struct {
UserID string `json:"userid"`
Status int `json:"status"`
}
type TokenCache struct {
AccessToken string
TokenExpire time.Time
Mutex sync.Mutex
}
// Map-based token cache keyed by app credentials
var tokenCacheMap = make(map[string]*TokenCache)
var tokenCacheMapMutex = sync.Mutex{}
// Generate a key for the token cache based on app credentials
func getTokenCacheKey(kbID, secret string) string {
return kbID + ":" + secret
}
type UserImageCache struct {
ImageID string
ImagePath string
ImageExpire time.Time
Mutex sync.Mutex
}
var UImageCache = &UserImageCache{}
type DefaultImageCache struct {
ImageID string
ImageExpire time.Time
Mutex sync.Mutex
}
var DImageCache = &DefaultImageCache{}

View File

@@ -0,0 +1,329 @@
package wechat_service
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"path"
"regexp"
"strings"
"time"
)
// 读取 cursor以客服账号的消息作为key返回对应的cursor值
func getCursor(openKfId string) string {
cursorValue, _ := KfCursors.Load(openKfId)
cursor, _ := cursorValue.(string)
return cursor
}
// 存储 cursor
func setCursor(openKfId, cursor string) {
KfCursors.Store(openKfId, cursor)
}
func CheckSessionState(token, extrenaluserid, kfId string) (int, error) {
var statusrequest struct {
OpenKfId string `json:"open_kfid"`
ExternalUserid string `json:"external_userid"`
}
statusrequest.OpenKfId = kfId
statusrequest.ExternalUserid = extrenaluserid
// 将请求体转换为JSON
jsonBody, err := json.Marshal(statusrequest)
if err != nil {
return 0, err
}
// 获取状态信息
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/service_state/get?access_token=%s", token)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
if err != nil {
return 0, fmt.Errorf("发送请求失败: %v", err)
}
defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("读取响应失败: %v", err)
}
var response Status
if err := json.Unmarshal(body, &response); err != nil {
return 0, fmt.Errorf("解析响应失败: %v", err)
}
// 得到用户的状态
if response.ErrCode != 0 {
return 0, fmt.Errorf("获取会话状态失败: %s", response.ErrMsg)
}
return response.ServiceState, nil
}
func ChangeState(token, extrenaluserId, kfId string, state int, serviceId string) error {
var changestate struct {
OpenKfId string `json:"open_kfid"`
ExternalUserid string `json:"external_userid"`
ServiceState int `json:"service_state"`
ServicerUserId string `json:"servicer_userid"`
}
changestate.OpenKfId = kfId
changestate.ExternalUserid = extrenaluserId
changestate.ServiceState = state
changestate.ServicerUserId = serviceId
jsonBody, err := json.Marshal(changestate)
if err != nil {
return err
}
// 发送请求
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/service_state/trans?access_token=%s", token)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
if err != nil {
return fmt.Errorf("发送请求失败: %v", err)
}
defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("读取响应失败: %v", err)
}
// 解析响应
var response struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MsgCode string `json:"msg_code"`
}
if err := json.Unmarshal(body, &response); err != nil {
return fmt.Errorf("解析响应失败: %v", err)
}
// 得到用户的状态
if response.ErrCode != 0 {
return fmt.Errorf("改变用户状态失败: %s", response.ErrMsg)
}
return nil
}
func GetUserInfo(userid string, accessToken string) (*Customer, error) {
userInfoRequest := UerInfoRequest{
UserID: []string{userid},
SessionContext: 0,
}
// 请求获取用户信息的url
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/customer/batchget?access_token=%s", accessToken)
jsonBody, err := json.Marshal(userInfoRequest)
if err != nil {
return nil, err
}
// post获取用户的消息信息
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var userInfo WechatCustomerResponse
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, err
}
if userInfo.ErrCode != 0 {
return nil, fmt.Errorf("获取用户信息失败: %d, %s", userInfo.ErrCode, userInfo.ErrMsg)
}
return &userInfo.CustomerList[0], nil
}
// get image id
func GetUserImageID(accessToken, filePath string) (string, error) {
UImageCache.Mutex.Lock()
defer UImageCache.Mutex.Unlock()
if UImageCache.ImageID != "" && (UImageCache.ImagePath == filePath) && time.Now().Before(UImageCache.ImageExpire.Add(-5*time.Minute)) {
return UImageCache.ImageID, nil
}
// URL
mediaID, err := UploadMediaFromURL(accessToken, filePath)
if err != nil {
return "", err
}
UImageCache.ImagePath = filePath
UImageCache.ImageID = mediaID
UImageCache.ImageExpire = time.Now().Add(72 * time.Hour) // 3 days
return UImageCache.ImageID, nil
}
// get image id
func GetDefaultImageID(accessToken, ImageBase64 string) (string, error) {
DImageCache.Mutex.Lock()
defer DImageCache.Mutex.Unlock()
if DImageCache.ImageID != "" && time.Now().Before(DImageCache.ImageExpire.Add(-5*time.Minute)) {
return DImageCache.ImageID, nil
}
// Base64编码
mediaID, err := UploadMediaFromBase64(accessToken, ImageBase64)
if err != nil {
return "", err
}
DImageCache.ImageID = mediaID
DImageCache.ImageExpire = time.Now().Add(72 * time.Hour) // 3 days
return DImageCache.ImageID, nil
}
// upload media to wechat server from URL
func UploadMediaFromURL(accessToken, fileURL string) (string, error) {
// 处理URL
resp, err := http.Get(fileURL)
if err != nil {
return "", fmt.Errorf("下载图片失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
}
reader := resp.Body
fileName := "image.png" // 默认文件名
// 从URL中提取文件名
if u, err := url.Parse(fileURL); err == nil && u.Path != "" {
if path.Base(u.Path) != "/" && path.Base(u.Path) != "." {
fileName = path.Base(u.Path)
}
}
return uploadMediaToWechat(accessToken, reader, fileName)
}
// upload media to wechat server from Base64
func UploadMediaFromBase64(accessToken, base64Data string) (string, error) {
// 处理Base64编码的图片
parts := strings.SplitN(base64Data, ",", 2)
if len(parts) != 2 {
return "", fmt.Errorf("无效的Base64图片数据")
}
// 解码Base64数据
decodedData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("解码Base64图片数据失败: %w", err)
}
reader := bytes.NewReader(decodedData)
fileName := "image.png" // const
return uploadMediaToWechat(accessToken, reader, fileName)
}
// upload media to wechat server - common function
func uploadMediaToWechat(accessToken string, reader io.Reader, fileName string) (string, error) {
// 上传文件 req
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("media", fileName)
if err != nil {
return "", err
}
// 将图片数据复制到表单中
_, err = io.Copy(part, reader)
if err != nil {
return "", fmt.Errorf("复制图片数据失败: %w", err)
}
if err := writer.Close(); err != nil {
return "", err
}
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/media/upload?access_token=%s&type=image", accessToken)
req, err := http.NewRequest("POST", url, body)
if err != nil {
return "", err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{}
httpResp, err := client.Do(req)
if err != nil {
return "", err
}
defer httpResp.Body.Close()
var result MediaUploadResponse
if err := json.NewDecoder(httpResp.Body).Decode(&result); err != nil {
return "", err
}
if result.ErrCode != 0 {
return "", fmt.Errorf("上传失败: [%d] %s", result.ErrCode, result.ErrMsg)
}
return result.MediaID, nil
}
func getMsgs(accessToken string, msg *WeixinUserAskMsg) (*MsgRet, error) {
var msgRet MsgRet
// 拉取消息的路由
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/sync_msg?access_token=%s", accessToken)
cursor := getCursor(msg.OpenKfId)
msgBody := MsgRequest{
OpenKfid: msg.OpenKfId,
Token: msg.Token,
Limit: 1000,
VoiceFormat: 0,
Cursor: cursor,
}
jsonBody, _ := json.Marshal(msgBody)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody)) // 得到对应的回复
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 反序列化之后
if err := json.Unmarshal([]byte(string(body)), &msgRet); err != nil {
return nil, err
}
return &msgRet, nil
}
// markdowntotext
func MarkdowntoText(md string) string {
md = regexp.MustCompile(`(?m)^#+\s*(.*)$`).ReplaceAllString(md, "$1")
md = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(md, "$1")
md = regexp.MustCompile(`(?m)^>\s*(.*)$`).ReplaceAllString(md, "【引用】$1")
md = regexp.MustCompile(`(?m)^-{3,}$`).ReplaceAllString(md, "─────────")
md = regexp.MustCompile(`\n{3,}`).ReplaceAllString(md, "\n\n")
md = regexp.MustCompile(`\[\[(\d+)\]\([^)]+\)\]`).ReplaceAllString(md, "[$1]")
md = regexp.MustCompile(`\[(\d+)\]\.\s*\[([^\]]+)\]\([^)]+\)`).ReplaceAllString(md, "[$1]. $2")
md = regexp.MustCompile(`(?m)^【引用】\[(\d+)\].\s*([^\n(]+)\s*\([^)]+\)`).ReplaceAllString(md, "【引用】[$1]. $2")
return strings.TrimSpace(md)
}

View File

@@ -0,0 +1,403 @@
package wechat_service
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"slices"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/sbzhu/weworkapi_golang/wxbizmsgcrypt"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/pkg/bot"
)
func NewWechatServiceConfig(ctx context.Context, logger *log.Logger, KbId, CorpID, Token, EncodingAESKey, secret, logo string, containKeywords, equalKeywords []string) (*WechatServiceConfig, error) {
return &WechatServiceConfig{
Ctx: ctx,
kbID: KbId,
CorpID: CorpID,
Token: Token,
EncodingAESKey: EncodingAESKey,
Secret: secret,
logger: logger,
containKeywords: containKeywords,
equalKeywords: equalKeywords,
logoUrl: logo,
}, nil
}
func (cfg *WechatServiceConfig) VerifyUrlWechatService(signature, timestamp, nonce, echostr string) ([]byte, error) {
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(
cfg.Token,
cfg.EncodingAESKey,
cfg.CorpID,
wxbizmsgcrypt.XmlType,
)
// 验证URL并解密echostr
decryptEchoStr, errCode := wxcpt.VerifyURL(signature, timestamp, nonce, echostr)
if errCode != nil {
return nil, errors.New("server serve fail wechat")
}
// success
return decryptEchoStr, nil
}
func (cfg *WechatServiceConfig) Wechat(msg *WeixinUserAskMsg, getQA bot.GetQAFun) error {
// 获取accesstoken 方便给用户发送消息
token, err := cfg.GetAccessToken()
if err != nil {
return err
}
// 主动拉去用户发送的消息
msgRet, err := getMsgs(token, msg)
if err != nil {
return err
}
if msgRet.NextCursor != "" {
setCursor(msg.OpenKfId, msgRet.NextCursor)
}
err = cfg.Processmessage(msgRet, msg, getQA)
if err != nil {
cfg.logger.Error("send to ai failed!")
return err
}
return nil
}
// forwardToBackend
func (cfg *WechatServiceConfig) Processmessage(msgRet *MsgRet, Kfmsg *WeixinUserAskMsg, GetQA bot.GetQAFun) error {
// err message
cfg.logger.Info("get user message", log.Int("msgRet.Errcode", msgRet.Errcode), log.String("msg.Errmsg", msgRet.Errmsg))
size := len(msgRet.MsgList)
if size < 1 {
return fmt.Errorf("no message received")
}
// 如果是用户刚刚进入会话的事件,那么不需要发送消息给用户
if msgRet.MsgList[size-1].Msgtype == "event" && msgRet.MsgList[size-1].Event.EventType == "enter_session" {
return nil
}
// 每次只是拿去最新的数据
current := msgRet.MsgList[size-1]
userId := current.ExternalUserid
openkfId := current.OpenKfid
content := current.Text.Content
token, _ := cfg.GetAccessToken()
state, err := CheckSessionState(token, userId, openkfId)
if err != nil {
cfg.logger.Error("check session state failed", log.Error(err))
return err
}
if state == 3 { // 人工状态 ---已经是人工,那么就不要需要发消息给用户
cfg.logger.Info("the customer has already in human service")
return nil
}
if len(cfg.equalKeywords) > 0 || len(cfg.containKeywords) > 0 {
if slices.Contains(cfg.equalKeywords, content) || lo.SomeBy(cfg.containKeywords, func(sub string) bool {
return strings.Contains(content, sub)
}) {
// 改变状态为人工接待
// 非人工 ->转人工
humanList, err := cfg.GetKfHumanList(token, openkfId)
if err != nil {
cfg.logger.Error("get human list failed", log.Error(err))
return err
}
// 遍历找到可以接待的员工
for _, servicer := range humanList.ServicerList {
if servicer.Status == 0 { // 可以接待
err := ChangeState(token, userId, openkfId, 3, servicer.UserID)
if err != nil {
cfg.logger.Error("change state to human failed", log.Error(err))
return err
}
cfg.logger.Info("change state to human successful") // 转人工成功
return nil
}
}
// 失败
cfg.logger.Info("no human available")
return cfg.SendResponseToKfTxt(userId, openkfId, "当前没有可用的人工客服", token)
}
}
// 1. first response to user
if err := cfg.SendResponseToKfTxt(userId, openkfId, "正在思考您的问题,请稍等...", token); err != nil {
return err
}
// 获取用户的详细信息
customer, err := GetUserInfo(userId, token)
if err != nil {
cfg.logger.Error("get user info failed", log.Error(err))
}
cfg.logger.Info("customer info", log.Any("customer", customer))
id, err := uuid.NewV7()
if err != nil {
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
id = uuid.New()
}
conversationID := id.String()
wccontent, err := GetQA(cfg.Ctx, content, domain.ConversationInfo{UserInfo: domain.UserInfo{
UserID: customer.ExternalUserID, // 用户对话的id
NickName: customer.Nickname, //用户微信的昵称
Avatar: customer.Avatar, // 用户微信的头像
From: domain.MessageFromPrivate,
}}, conversationID)
if err != nil {
return err
}
//2. get baseurl and image path
info, err := cfg.WeRepo.GetWechatStatic(cfg.Ctx, cfg.kbID, domain.AppTypeWeb)
if err != nil {
return err
}
//2. go send to ai and store in map--> get conversation-id
if _, ok := domain.ConversationManager.Load(conversationID); !ok {
state := &domain.ConversationState{
Question: content,
NotificationChan: make(chan string), // notification channel
IsVisited: false,
}
domain.ConversationManager.Store(conversationID, state)
go cfg.SendQuestionToAI(conversationID, wccontent)
}
// 3. second send url to user
return cfg.SendResponseToKfUrl(userId, openkfId, conversationID, token, content, info.BaseUrl, info.ImagePath)
}
func (cfg *WechatServiceConfig) getImageID(token, image string) (string, error) {
const minioPrefix = "http://panda-wiki-minio:9000"
// 优先使用配置的logoUrl
if cfg.logoUrl != "" {
image = cfg.logoUrl
}
var imageId string
var err error
switch {
case image == "":
case strings.HasPrefix(image, "data:image/"):
imageId, err = GetDefaultImageID(token, image)
default:
imageId, err = GetUserImageID(token, fmt.Sprintf("%s%s", minioPrefix, image))
}
if imageId != "" && err == nil {
return imageId, nil
}
if err != nil {
cfg.logger.Error("failed to get image ID, using default", log.Error(err))
}
return GetDefaultImageID(token, domain.DefaultPandaWikiIconB64)
}
func (cfg *WechatServiceConfig) SendResponseToKfUrl(userId, openkfId, conversationID, token, question, baseUrl, image string) error {
imageId, err := cfg.getImageID(token, image)
if err != nil {
return err
}
if utf8.RuneCountInString(question) > 35 {
question = string([]rune(question)[:35]) + "......"
}
reply := ReplyMsgUrl{
Touser: userId,
OpenKfid: openkfId,
Msgtype: "link",
Link: Link{
Url: fmt.Sprintf("%s/h5-chat?id=%s", baseUrl, conversationID),
Desc: "本回答由 PandaWiki 基于 AI 生成,仅供参考。",
Title: question,
ThumbMediaID: imageId,
},
}
jsonData, err := json.Marshal(reply)
if err != nil {
return fmt.Errorf("json Marshal failed: %w", err)
}
return cfg.SendMessage(jsonData, token)
}
func (cfg *WechatServiceConfig) SendResponseToKfTxt(userId string, openkfId string, response string, token string) error {
// send text data to user
reply := ReplyMsg{
Touser: userId,
OpenKfid: openkfId,
Msgtype: "text",
Text: struct {
Content string `json:"content,omitempty"`
}{Content: response},
}
jsonData, err := json.Marshal(reply)
if err != nil {
return fmt.Errorf("json Marshal failed: %w", err)
}
return cfg.SendMessage(jsonData, token)
}
func (cfg *WechatServiceConfig) SendMessage(jsonData []byte, token string) error {
// 发送消息给客服
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token=%s", token)
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("post to wechatservice failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response body failed: %w", err)
}
var res struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
MsgID string `json:"msgid"`
}
if err := json.Unmarshal(body, &res); err != nil {
cfg.logger.Error("解析响应失败", log.Error(err))
return err
}
if res.ErrCode != 0 {
cfg.logger.Error("发送给微信客服消息失败", log.Any("errcode", res.ErrCode), log.Any("errmsg", res.ErrMsg), log.Any("jsonData", string(jsonData)))
return err
}
// 发送消息给微信客服成功
s := string(body)
cfg.logger.Info("response from wechatservice success", log.Any("body", s))
return nil
}
func (cfg *WechatServiceConfig) GetAccessToken() (string, error) {
// Generate cache key based on app credentials
cacheKey := getTokenCacheKey(cfg.kbID, cfg.Secret)
// Get or create token cache for this app
tokenCacheMapMutex.Lock()
tokenCache, exists := tokenCacheMap[cacheKey]
if !exists {
tokenCache = &TokenCache{}
tokenCacheMap[cacheKey] = tokenCache
}
tokenCacheMapMutex.Unlock()
// Lock the specific token cache for this app
tokenCache.Mutex.Lock()
defer tokenCache.Mutex.Unlock()
if tokenCache.AccessToken != "" && time.Now().Before(tokenCache.TokenExpire) {
cfg.logger.Debug("access token has existed and is valid")
return tokenCache.AccessToken, nil
}
if cfg.Secret == "" || cfg.CorpID == "" {
return "", errors.New("secret or corpid is not right")
}
// get AccessToken--请求微信客服token
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", cfg.CorpID, cfg.Secret)
resp, err := http.Get(url)
if err != nil {
return "", errors.New("get wechatservice accesstoken failed")
}
defer resp.Body.Close()
var tokenResp AccessToken // 获取到token消息
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", errors.New("json decode wechat resp failed")
}
if tokenResp.Errcode != 0 {
return "", errors.New("get wechat access token failed")
}
// success
cfg.logger.Info("wechatservice get accesstoken success", log.Any("info", tokenResp.AccessToken))
tokenCache.AccessToken = tokenResp.AccessToken
tokenCache.TokenExpire = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second)
return tokenCache.AccessToken, nil
}
// 解析微信客服消息
func (cfg *WechatServiceConfig) UnmarshalMsg(decryptMsg []byte) (*WeixinUserAskMsg, error) {
var msg WeixinUserAskMsg
err := xml.Unmarshal([]byte(decryptMsg), &msg)
return &msg, err
}
func (cfg *WechatServiceConfig) GetKfHumanList(token string, KfId string) (*HumanList, error) {
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/servicer/list?access_token=%s&open_kfid=%s", token, KfId)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var servicerResp HumanList
if err := json.Unmarshal(body, &servicerResp); err != nil {
return nil, err
}
if servicerResp.ErrCode != 0 {
return nil, fmt.Errorf("获取客服列表失败: %d, %s", servicerResp.ErrCode, servicerResp.ErrMsg)
}
return &servicerResp, nil
}
// answer set into redis queue and set useful time
func (cfg *WechatServiceConfig) SendQuestionToAI(conversationID string, wccontent chan string) {
// send message
val, _ := domain.ConversationManager.Load(conversationID)
state := val.(*domain.ConversationState)
for content := range wccontent {
state.Mutex.Lock()
if state.IsVisited {
state.NotificationChan <- content // notify has new data
}
state.Buffer.WriteString(content)
state.Mutex.Unlock()
}
// end sent notification
defer func() {
close(state.NotificationChan)
domain.ConversationManager.Delete(conversationID)
}()
}

View File

@@ -0,0 +1,144 @@
package wecom
import (
"context"
"encoding/json"
"github.com/chaitin/panda-wiki/log"
)
// AIBotClient 微信智能机器人
// https://developer.work.weixin.qq.com/document/path/100719
type AIBotClient struct {
ctx context.Context
logger *log.Logger
Token string
EncodingAESKey string
}
type UserReq struct {
Msgid string `json:"msgid"`
Aibotid string `json:"aibotid"`
Chattype string `json:"chattype"`
From struct {
Userid string `json:"userid"`
} `json:"from"`
Msgtype string `json:"msgtype"`
Text struct {
Content string `json:"content"`
} `json:"text"`
Stream struct {
Id string `json:"id"`
} `json:"stream"`
}
type UserResp struct {
Msgtype string `json:"msgtype"`
Stream Stream `json:"stream"`
}
type Stream struct {
Id string `json:"id"`
Finish bool `json:"finish"`
Content string `json:"content"`
MsgItem []struct {
Msgtype string `json:"msgtype"`
Image struct {
Base64 string `json:"base64"`
Md5 string `json:"md5"`
} `json:"image"`
} `json:"msg_item"`
}
func NewAIBotClient(
ctx context.Context,
logger *log.Logger,
Token string,
EncodingAESKey string,
) (*AIBotClient, error) {
return &AIBotClient{
ctx: ctx,
logger: logger,
Token: Token,
EncodingAESKey: EncodingAESKey,
}, nil
}
func (c *AIBotClient) VerifyUrlWecomService(signature, timestamp, nonce, echostr string) (string, error) {
wx, _, err := NewWXBizJsonMsgCrypt(
c.Token,
c.EncodingAESKey,
"",
)
if err != nil {
return "", err
}
code, sReplyEchoStr := wx.VerifyURL(signature, timestamp, nonce, echostr)
if code != 0 {
c.logger.Error("VerifyUrlWecomService failed:", log.Any("code", code))
return "", c.getErrorMessage(code)
}
return sReplyEchoStr, nil
}
func (c *AIBotClient) DecryptUserReq(signature, timestamp, nonce, msg string) (*UserReq, error) {
wx, _, err := NewWXBizJsonMsgCrypt(
c.Token,
c.EncodingAESKey,
"",
)
if err != nil {
return nil, err
}
code, reqMsg := wx.DecryptMsg(msg, signature, timestamp, nonce)
if code != 0 {
return nil, c.getErrorMessage(code)
}
var data UserReq
c.logger.Info("decrypt user req:", log.Any("reqMsg", reqMsg))
err = json.Unmarshal([]byte(reqMsg), &data)
if err != nil {
return nil, err
}
return &data, nil
}
func (c *AIBotClient) MakeStreamResp(nonce, id, content string, isFinish bool) (string, error) {
c.logger.Debug("MakeStreamResp:", log.String("content", content), log.Any("isFinish", isFinish))
wx, _, err := NewWXBizJsonMsgCrypt(
c.Token,
c.EncodingAESKey,
"",
)
if err != nil {
return "", err
}
resp := UserResp{
Msgtype: "stream",
Stream: Stream{
Id: id,
Finish: isFinish,
Content: content,
MsgItem: nil,
},
}
b, err := json.Marshal(resp)
if err != nil {
return "", err
}
code, msg := wx.EncryptMsg(string(b), nonce)
if code != 0 {
c.logger.Error("MakeStreamResp failed:", log.Any("code", code))
return "", c.getErrorMessage(code)
}
return msg, nil
}

View File

@@ -0,0 +1,374 @@
// Package wecom provides cryptographic utilities for WeChat Work (WeCom) message encryption and decryption.
// It implements the WXBizMsgCrypt algorithm for secure message handling with WeChat Work APIs.
package wecom
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"math/big"
"sort"
"strings"
"time"
)
const (
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = 40001
WXBizMsgCrypt_ParseJson_Error = 40002
WXBizMsgCrypt_ComputeSignature_Error = 40003
WXBizMsgCrypt_IllegalAesKey = 40004
WXBizMsgCrypt_EncryptAES_Error = 40005
WXBizMsgCrypt_DecryptAES_Error = 40006
WXBizMsgCrypt_IllegalBuffer = 40007
WXBizMsgCrypt_ValidateCorpid_Error = 40008
WXBizMsgCrypt_ValidateCorpid_Receive_Id = 40009
WXBizMsgCrypt_ValidateCorpid_Mismatch = 40010
)
var wecomErrorMessages = map[int]string{
WXBizMsgCrypt_OK: "success",
WXBizMsgCrypt_ValidateSignature_Error: "signature validation failed",
WXBizMsgCrypt_ParseJson_Error: "invalid JSON format",
WXBizMsgCrypt_ComputeSignature_Error: "signature computation failed",
WXBizMsgCrypt_IllegalAesKey: "illegal AES key",
WXBizMsgCrypt_EncryptAES_Error: "AES encryption failed",
WXBizMsgCrypt_DecryptAES_Error: "AES decryption failed",
WXBizMsgCrypt_IllegalBuffer: "illegal buffer format",
WXBizMsgCrypt_ValidateCorpid_Error: "corp ID validation failed",
WXBizMsgCrypt_ValidateCorpid_Receive_Id: "receive ID validation failed",
WXBizMsgCrypt_ValidateCorpid_Mismatch: "corp ID mismatch",
}
func (c *AIBotClient) getErrorMessage(code int) error {
if msg, ok := wecomErrorMessages[code]; ok {
return fmt.Errorf("wecom error (code %d): %s", code, msg)
}
return fmt.Errorf("unknown wecom error: %d", code)
}
var ErrFormat = errors.New("format error")
// SHA1 负责生成安全签名sha1
type SHA1 struct{}
// GetSHA1 : 对 token, timestamp, nonce, encrypt 排序后 sha1
// 返回 (code, signature)
func (s *SHA1) GetSHA1(token, timestamp, nonce string, encrypt interface{}) (int, string) {
defer func() {
// no panic propagation in this helper; but keep signature simple
}()
encStr := ""
switch v := encrypt.(type) {
case string:
encStr = v
case []byte:
encStr = string(v)
case nil:
encStr = ""
default:
encStr = fmt.Sprint(v)
}
list := []string{token, timestamp, nonce, encStr}
sort.Strings(list)
joined := strings.Join(list, "")
h := sha1.New()
_, err := h.Write([]byte(joined))
if err != nil {
return WXBizMsgCrypt_ComputeSignature_Error, ""
}
return WXBizMsgCrypt_OK, fmt.Sprintf("%x", h.Sum(nil))
}
// JsonParse 提取/生成 json 消息
type JsonParse struct{}
type aesTextResponse struct {
Encrypt string `json:"encrypt"`
MsgSignature string `json:"msgsignature"`
Timestamp string `json:"timestamp"`
Nonce string `json:"nonce"`
}
// Extract 从 json 字符串中提取 encrypt 字段
// 返回 (code, encrypt)
func (jp *JsonParse) Extract(jsonText string) (int, string) {
var m map[string]interface{}
if err := json.Unmarshal([]byte(jsonText), &m); err != nil {
return WXBizMsgCrypt_ParseJson_Error, ""
}
if v, ok := m["encrypt"].(string); ok {
return WXBizMsgCrypt_OK, v
}
return WXBizMsgCrypt_ParseJson_Error, ""
}
// Generate 根据参数生成 json 字符串
func (jp *JsonParse) Generate(encrypt, signature, timestamp, nonce string) string {
resp := aesTextResponse{
Encrypt: encrypt,
MsgSignature: signature,
Timestamp: timestamp,
Nonce: nonce,
}
bs, _ := json.Marshal(resp)
return string(bs)
}
// PKCS7Encoder 提供基于 PKCS7 的填充/去填充
type PKCS7Encoder struct {
BlockSize int // 使用 32 与 Python 示例一致
}
func NewPKCS7Encoder() *PKCS7Encoder {
return &PKCS7Encoder{BlockSize: 32}
}
func (p *PKCS7Encoder) Encode(src []byte) []byte {
if src == nil {
src = []byte{}
}
n := len(src)
amountToPad := p.BlockSize - (n % p.BlockSize)
if amountToPad == 0 {
amountToPad = p.BlockSize
}
pad := byte(amountToPad)
padtext := bytes.Repeat([]byte{pad}, amountToPad)
return append(src, padtext...)
}
func (p *PKCS7Encoder) Decode(decrypted []byte) ([]byte, error) {
if len(decrypted) == 0 {
return nil, nil
}
pad := int(decrypted[len(decrypted)-1])
if pad < 1 || pad > p.BlockSize {
// 同 Python 逻辑:当 pad 值不合理时,视为 0或 error
return decrypted, fmt.Errorf("invalid padding")
}
return decrypted[:len(decrypted)-pad], nil
}
// Prpcrypt 提供 AES 加解密功能
type Prpcrypt struct {
Key []byte
Mode string // not used but kept for parity
}
func NewPrpcrypt(key []byte) *Prpcrypt {
return &Prpcrypt{Key: key, Mode: "CBC"}
}
// Encrypt 对明文加密,返回 (code, base64Ciphertext)
func (pc *Prpcrypt) Encrypt(plainText string, receiveID string) (int, string) {
// 将明文转换为 bytes
txt := []byte(plainText)
// 随机 16 字节数字字符串
rand16, err := getRandom16BytesAsDigits()
if err != nil {
return WXBizMsgCrypt_EncryptAES_Error, ""
}
// 包装: 16 bytes random + 4 bytes network-order(len) + txt + receiveid
buf := bytes.NewBuffer(nil)
buf.Write(rand16)
// len(txt) 网络字节序
lenBuf := make([]byte, 4)
// Python 示例使用 socket.htonl(len(text)),即 network order (big endian)
binary.BigEndian.PutUint32(lenBuf, uint32(len(txt)))
buf.Write(lenBuf)
buf.Write(txt)
buf.Write([]byte(receiveID))
raw := buf.Bytes()
// PKCS7 pad 到 blocksize=32
encoder := NewPKCS7Encoder()
padded := encoder.Encode(raw)
// AES-CBC
block, err := aes.NewCipher(pc.Key)
if err != nil {
return WXBizMsgCrypt_EncryptAES_Error, ""
}
iv := pc.Key[:16]
if len(iv) < 16 {
return WXBizMsgCrypt_IllegalAesKey, ""
}
mode := cipher.NewCBCEncrypter(block, iv)
if len(padded)%block.BlockSize() != 0 {
// 应该已经经过 pad
return WXBizMsgCrypt_EncryptAES_Error, ""
}
ciphertext := make([]byte, len(padded))
mode.CryptBlocks(ciphertext, padded)
enc := base64.StdEncoding.EncodeToString(ciphertext)
return WXBizMsgCrypt_OK, enc
}
// Decrypt 解密 base64 文本,返回 (code, jsonContent)
func (pc *Prpcrypt) Decrypt(base64Cipher string, receiveID string) (int, string) {
cipherData, err := base64.StdEncoding.DecodeString(base64Cipher)
if err != nil {
return WXBizMsgCrypt_DecryptAES_Error, ""
}
block, err := aes.NewCipher(pc.Key)
if err != nil {
return WXBizMsgCrypt_DecryptAES_Error, ""
}
if len(cipherData)%block.BlockSize() != 0 {
return WXBizMsgCrypt_DecryptAES_Error, ""
}
iv := pc.Key[:16]
mode := cipher.NewCBCDecrypter(block, iv)
plain := make([]byte, len(cipherData))
mode.CryptBlocks(plain, cipherData)
// 去 PKCS7 填充 (blocksize=32)
encoder := NewPKCS7Encoder()
unpadded, err := encoder.Decode(plain)
if err != nil {
// Python 里如果 pad 错误会继续尝试并最后返回 IllegalBuffer
// 这里直接返回 IllegalBuffer
return WXBizMsgCrypt_IllegalBuffer, ""
}
// 去掉前 16 字节随机字符串
if len(unpadded) < 16 {
return WXBizMsgCrypt_IllegalBuffer, ""
}
content := unpadded[16:]
if len(content) < 4 {
return WXBizMsgCrypt_IllegalBuffer, ""
}
// 前 4 字节为 network order 的 json length
jsonLen := binary.BigEndian.Uint32(content[:4])
if int(jsonLen) > len(content)-4 {
return WXBizMsgCrypt_IllegalBuffer, ""
}
jsonContent := string(content[4 : 4+jsonLen])
fromReceiveID := string(content[4+jsonLen:])
if fromReceiveID != receiveID {
// receiveid 不匹配
return WXBizMsgCrypt_ValidateCorpid_Error, ""
}
return WXBizMsgCrypt_OK, jsonContent
}
// getRandom16BytesAsDigits 产生一个 16 字节的 ASCII 数字字符串(与 Python 版本行为一致)
func getRandom16BytesAsDigits() ([]byte, error) {
const digits = "0123456789"
out := make([]byte, 16)
for i := 0; i < 16; i++ {
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
if err != nil {
return nil, err
}
out[i] = digits[nBig.Int64()]
}
return out, nil
}
// WXBizJsonMsgCrypt 将整个流程封装:初始化时传入 token, encodingAESKey, receiveID
type WXBizJsonMsgCrypt struct {
Token string
EncodingKey []byte
ReceiveID string
encodingAES string // 原始 sEncodingAESKey
}
// NewWXBizJsonMsgCrypt 构造sToken, sEncodingAESKey, sReceiveID
func NewWXBizJsonMsgCrypt(sToken, sEncodingAESKey, sReceiveID string) (*WXBizJsonMsgCrypt, int, error) {
// Python 里是 base64.b64decode(sEncodingAESKey + "=")
dec, err := base64.StdEncoding.DecodeString(sEncodingAESKey + "=")
if err != nil {
return nil, WXBizMsgCrypt_IllegalAesKey, fmt.Errorf("EncodingAESKey base64 decode fail: %w", err)
}
if len(dec) != 32 {
return nil, WXBizMsgCrypt_IllegalAesKey, fmt.Errorf("EncodingAESKey decoded length must be 32 (got %d)", len(dec))
}
return &WXBizJsonMsgCrypt{
Token: sToken,
EncodingKey: dec,
ReceiveID: sReceiveID,
encodingAES: sEncodingAESKey,
}, WXBizMsgCrypt_OK, nil
}
// VerifyURL 校验并解密 sEchoStr用于首次验证 URL
// 返回 (code, sReplyEchoStr)
func (w *WXBizJsonMsgCrypt) VerifyURL(sMsgSignature, sTimeStamp, sNonce, sEchoStr string) (int, string) {
sha1 := &SHA1{}
ret, signature := sha1.GetSHA1(w.Token, sTimeStamp, sNonce, sEchoStr)
if ret != WXBizMsgCrypt_OK {
return ret, ""
}
if signature != sMsgSignature {
return WXBizMsgCrypt_ValidateSignature_Error, ""
}
pc := NewPrpcrypt(w.EncodingKey)
ret, reply := pc.Decrypt(sEchoStr, w.ReceiveID)
return ret, reply
}
// EncryptMsg 对要回复的消息 sReplyMsgjson 字符串)进行加密并生成外层 JSON 包装
// 返回 (code, generatedJson)
func (w *WXBizJsonMsgCrypt) EncryptMsg(sReplyMsg, sNonce string, timestamp ...string) (int, string) {
pc := NewPrpcrypt(w.EncodingKey)
ret, encrypt := pc.Encrypt(sReplyMsg, w.ReceiveID)
if ret != WXBizMsgCrypt_OK {
return ret, ""
}
// encrypt 是 base64 字符串(已经),确保是字符串
encryptStr := encrypt
ts := ""
if len(timestamp) > 0 && timestamp[0] != "" {
ts = timestamp[0]
} else {
ts = fmt.Sprintf("%d", time.Now().Unix())
}
sha1 := &SHA1{}
ret, signature := sha1.GetSHA1(w.Token, ts, sNonce, encryptStr)
if ret != WXBizMsgCrypt_OK {
return ret, ""
}
jp := &JsonParse{}
jsonStr := jp.Generate(encryptStr, signature, ts, sNonce)
return WXBizMsgCrypt_OK, jsonStr
}
// DecryptMsg 验证签名并解密 POST 的 json 数据包
// sPostData: POST 的 json 数据字符串(包含 encrypt 字段)
// sMsgSignature: URL param msg_signature
// sTimeStamp: timestamp
// sNonce: nonce
// 返回 (code, jsonContent)
func (w *WXBizJsonMsgCrypt) DecryptMsg(sPostData, sMsgSignature, sTimeStamp, sNonce string) (int, string) {
jp := &JsonParse{}
ret, encrypt := jp.Extract(sPostData)
if ret != WXBizMsgCrypt_OK {
return ret, ""
}
sha1 := &SHA1{}
ret, signature := sha1.GetSHA1(w.Token, sTimeStamp, sNonce, encrypt)
if ret != WXBizMsgCrypt_OK {
return ret, ""
}
if signature != sMsgSignature {
return WXBizMsgCrypt_ValidateSignature_Error, ""
}
pc := NewPrpcrypt(w.EncodingKey)
return pc.Decrypt(encrypt, w.ReceiveID)
}

View File

@@ -0,0 +1,17 @@
package captcha
import gocap "github.com/ackcoder/go-cap"
type Captcha struct {
*gocap.Cap
}
func NewCaptcha() *Captcha {
return &Captcha{
Cap: gocap.New(
gocap.WithChallenge(50, 32, 3),
gocap.WithChallengeExpires(60*2),
gocap.WithTokenExpires(60*5),
),
}
}

229
backend/pkg/cas/cas.go Normal file
View File

@@ -0,0 +1,229 @@
package cas
import (
"context"
"crypto/tls"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/chaitin/panda-wiki/log"
)
type Client struct {
logger *log.Logger
ctx context.Context
config *Config
httpClient *http.Client
}
type Config struct {
ServerURL string `json:"server_url"` // CAS服务器URL如 https://cas.example.com/cas
ServiceURL string `json:"service_url"` // 服务回调URL
LoginPath string `json:"login_path"` // 登录路径,默认为 /login
ValidatePath string `json:"validate_path"` // 验证路径,默认根据版本自动选择
Version string `json:"version"` // CAS协议版本: "2" 或 "3"
CASUrl string `json:"cas_url"`
}
type UserInfo struct {
Username string `json:"username"`
Attributes map[string]string `json:"attributes"`
}
// CAS2ServiceResponse CAS2服务验证响应结构
type CAS2ServiceResponse struct {
XMLName xml.Name `xml:"serviceResponse"`
Success *CAS2AuthenticationSuccess `xml:"authenticationSuccess"`
Failure *AuthenticationFailure `xml:"authenticationFailure"`
}
type CAS2AuthenticationSuccess struct {
User string `xml:"user"`
}
// CAS3ServiceResponse CAS3服务验证响应结构
type CAS3ServiceResponse struct {
XMLName xml.Name `xml:"serviceResponse"`
Success *CAS3AuthenticationSuccess `xml:"authenticationSuccess"`
Failure *AuthenticationFailure `xml:"authenticationFailure"`
}
type CAS3AuthenticationSuccess struct {
User string `xml:"user"`
Attributes CAS3Attributes `xml:"attributes"`
}
type AuthenticationFailure struct {
Code string `xml:"code,attr"`
Message string `xml:",chardata"`
}
type CAS3Attributes struct {
Email string `xml:"email"`
Name string `xml:"name"`
AvatarURL string `xml:"avatar_url"`
}
const (
defaultLoginPath = "/login"
defaultValidatePathCAS2 = "/serviceValidate"
defaultValidatePathCAS3 = "/p3/serviceValidate"
callbackPath = "/share/pro/v1/openapi/cas/callback"
)
// NewClient 创建CAS客户端
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
// 设置默认登录路径
if config.LoginPath == "" {
config.LoginPath = defaultLoginPath
}
// 如果版本为空默认使用CAS3
if config.Version == "" {
config.Version = "3"
}
// 根据版本设置默认验证路径
if config.ValidatePath == "" {
switch config.Version {
case "3":
config.ValidatePath = defaultValidatePathCAS3
case "2", "":
config.ValidatePath = defaultValidatePathCAS2
default:
return nil, fmt.Errorf("unsupported CAS version: %s, supported versions are '2' and '3'", config.Version)
}
}
// 构建服务回调URL
if config.ServiceURL != "" {
serviceURL, err := url.Parse(config.ServiceURL)
if err != nil {
return nil, fmt.Errorf("invalid service URL: %w", err)
}
serviceURL.Path = callbackPath
config.ServiceURL = serviceURL.String()
}
return &Client{
ctx: ctx,
logger: logger.WithModule("pkg.cas"),
config: &config,
httpClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
},
}, nil
}
// GetLoginURL 获取CAS登录URL
func (c *Client) GetLoginURL(state string) string {
loginURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.LoginPath
params := url.Values{}
params.Set("service", c.config.ServiceURL+"?state="+state)
return loginURL + "?" + params.Encode()
}
// ValidateTicket 验证CAS票据并获取用户信息
func (c *Client) ValidateTicket(ticket, state string) (*UserInfo, error) {
validateURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.ValidatePath
params := url.Values{}
params.Set("service", c.config.ServiceURL+"?state="+state)
params.Set("ticket", ticket)
fullURL := validateURL + "?" + params.Encode()
c.logger.Info("validating CAS ticket",
log.String("url", fullURL),
log.String("version", c.config.Version))
resp, err := c.httpClient.Get(fullURL)
if err != nil {
return nil, fmt.Errorf("failed to validate ticket: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
c.logger.Info("CAS validation response", log.String("response", string(body)))
// 根据CAS版本解析不同的响应格式
switch c.config.Version {
case "2":
return c.parseCAS2Response(body)
case "3":
return c.parseCAS3Response(body)
default:
return nil, fmt.Errorf("unsupported CAS version: %s", c.config.Version)
}
}
// parseCAS2Response 解析CAS2响应
func (c *Client) parseCAS2Response(body []byte) (*UserInfo, error) {
var serviceResp CAS2ServiceResponse
if err := xml.Unmarshal(body, &serviceResp); err != nil {
return nil, fmt.Errorf("failed to parse CAS2 response: %w", err)
}
if serviceResp.Failure != nil {
return nil, fmt.Errorf("CAS validation failed: %s - %s",
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
}
if serviceResp.Success == nil {
return nil, fmt.Errorf("invalid CAS2 response: no success or failure element")
}
userInfo := &UserInfo{
Username: serviceResp.Success.User,
Attributes: map[string]string{
"name": serviceResp.Success.User, // CAS2通常只返回用户名
},
}
return userInfo, nil
}
// parseCAS3Response 解析CAS3响应
func (c *Client) parseCAS3Response(body []byte) (*UserInfo, error) {
var serviceResp CAS3ServiceResponse
if err := xml.Unmarshal(body, &serviceResp); err != nil {
return nil, fmt.Errorf("failed to parse CAS3 response: %w", err)
}
if serviceResp.Failure != nil {
return nil, fmt.Errorf("CAS validation failed: %s - %s",
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
}
if serviceResp.Success == nil {
return nil, fmt.Errorf("invalid CAS3 response: no success or failure element")
}
userInfo := &UserInfo{
Username: serviceResp.Success.User,
Attributes: map[string]string{
"email": serviceResp.Success.Attributes.Email,
"name": serviceResp.Success.Attributes.Name,
"avatar_url": serviceResp.Success.Attributes.AvatarURL,
},
}
// 如果没有显示名称,使用用户名
if userInfo.Attributes["name"] == "" {
userInfo.Attributes["name"] = userInfo.Username
}
return userInfo, nil
}

View File

@@ -0,0 +1,351 @@
package dingtalk
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0"
dingtalkoauth2_1_0 "github.com/alibabacloud-go/dingtalk/v2/oauth2_1_0"
"github.com/alibabacloud-go/tea/tea"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/cache"
)
const (
callbackPath = "/share/pro/v1/openapi/dingtalk/callback"
userInfoUrl = "https://api.dingtalk.com/v1.0/contact/users/me"
DepartmentListUrl = "https://oapi.dingtalk.com/department/list"
// https://open.dingtalk.com/document/isvapp/queries-the-complete-information-of-a-department-user
UserListUrl = "https://oapi.dingtalk.com/topapi/v2/user/list"
)
type Client struct {
ctx context.Context
logger *log.Logger
httpClient *http.Client
clientID string
clientSecret string
oauthClient *dingtalkoauth2_1_0.Client
cardClient *dingtalkcard_1_0.Client
dingTalkAuthURL string
cache *cache.Cache
}
// UserInfo 用于解析获取用户信息的接口返回
type UserInfo struct {
Nick string `json:"nick"`
UnionID string `json:"unionId"`
OpenID string `json:"openId"`
AvatarURL string `json:"avatarUrl"`
StateCode string `json:"stateCode"`
}
// DepartmentListRsp 用于解析组织信息接口返回
type DepartmentListRsp struct {
Errcode int `json:"errcode"`
Department []struct {
CreateDeptGroup bool `json:"createDeptGroup"`
Name string `json:"name"`
Id int `json:"id"`
AutoAddUser bool `json:"autoAddUser"`
Parentid int `json:"parentid,omitempty"`
} `json:"department"`
Errmsg string `json:"errmsg"`
}
type GetUserListResp struct {
Errcode int `json:"errcode"`
Result struct {
HasMore bool `json:"has_more"`
List []UserDetail `json:"list"`
} `json:"result"`
Errmsg string `json:"errmsg"`
}
type UserDetail struct {
Active bool `json:"active"`
Admin bool `json:"admin"`
Avatar string `json:"avatar"`
Boss bool `json:"boss"`
DeptIdList []int `json:"dept_id_list"`
DeptOrder int64 `json:"dept_order"`
Email string `json:"email"`
ExclusiveAccount bool `json:"exclusive_account"`
HideMobile bool `json:"hide_mobile"`
JobNumber string `json:"job_number"`
Leader bool `json:"leader"`
Mobile string `json:"mobile"`
Name string `json:"name"`
Remark string `json:"remark"`
StateCode string `json:"state_code"`
Telephone string `json:"telephone"`
Title string `json:"title"`
Unionid string `json:"unionid"`
Userid string `json:"userid"`
WorkPlace string `json:"work_place"`
}
func NewDingTalkClient(ctx context.Context, logger *log.Logger, clientId, clientSecret string, cache *cache.Cache) (*Client, error) {
config := &openapi.Config{}
config.Protocol = tea.String("https")
config.RegionId = tea.String("central")
oauthClient, err := dingtalkoauth2_1_0.NewClient(config)
if err != nil {
return nil, fmt.Errorf("failed to create oauth client: %w", err)
}
cardClient, err := dingtalkcard_1_0.NewClient(config)
if err != nil {
return nil, fmt.Errorf("failed to create card client: %w", err)
}
return &Client{
ctx: ctx,
logger: logger.WithModule("pkg.dingtalk"),
httpClient: &http.Client{},
clientID: clientId,
clientSecret: clientSecret,
oauthClient: oauthClient,
cardClient: cardClient,
dingTalkAuthURL: "https://login.dingtalk.com/oauth2/auth",
cache: cache,
}, nil
}
// GenerateAuthURL 生成钉钉授权URL
func (c *Client) GenerateAuthURL(baseUrl string, state string) string {
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
if err != nil {
c.logger.Error("failed to join path", log.Error(err))
return ""
}
params := url.Values{}
params.Add("response_type", "code")
params.Add("client_id", c.clientID)
params.Add("redirect_uri", redirectURI)
params.Add("scope", "openid")
params.Add("state", state)
params.Add("prompt", "consent")
return fmt.Sprintf("%s?%s", c.dingTalkAuthURL, params.Encode())
}
func (c *Client) GetAccessTokenByCode(code string) (string, error) {
request := &dingtalkoauth2_1_0.GetUserTokenRequest{
ClientId: tea.String(c.clientID),
ClientSecret: tea.String(c.clientSecret),
Code: tea.String(code),
GrantType: tea.String("authorization_code"),
}
response, err := c.oauthClient.GetUserToken(request)
if err != nil {
return "", fmt.Errorf("failed to get user access token: %w", err)
}
accessToken := tea.StringValue(response.Body.AccessToken)
return accessToken, nil
}
func (c *Client) GetAccessToken() (string, error) {
ctx := context.Background()
cacheKey := fmt.Sprintf("dingtalk-access-token:%s", c.clientID)
cachedData, err := c.cache.Get(ctx, cacheKey).Result()
if err == nil && cachedData != "" {
return cachedData, nil
}
request := &dingtalkoauth2_1_0.GetAccessTokenRequest{
AppKey: tea.String(c.clientID),
AppSecret: tea.String(c.clientSecret),
}
response, tryErr := func() (_resp *dingtalkoauth2_1_0.GetAccessTokenResponse, _e error) {
defer func() {
if r := tea.Recover(recover()); r != nil {
_e = r
}
}()
_resp, _err := c.oauthClient.GetAccessToken(request)
if _err != nil {
return nil, _err
}
return _resp, nil
}()
if tryErr != nil {
return "", tryErr
}
accessToken := *response.Body.AccessToken
c.logger.Debug("get access token", log.String("access_token", accessToken), log.Int("expire_in", int(*response.Body.ExpireIn)))
if err := c.cache.Set(ctx, cacheKey, accessToken, time.Duration(*response.Body.ExpireIn-300)*time.Second).Err(); err != nil {
c.logger.Warn("failed to set cache", log.Error(err))
}
return accessToken, nil
}
func (c *Client) GetUserInfoByCode(code string) (*UserInfo, error) {
req, err := http.NewRequest("GET", userInfoUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to create GET request: %w", err)
}
accessToken, err := c.GetAccessTokenByCode(code)
if err != nil {
return nil, err
}
// Set request headers
req.Header.Set("x-acs-dingtalk-access-token", accessToken)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send GET request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
}
var userInfo UserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
}
return &userInfo, nil
}
func (c *Client) GetDepartmentList() (*DepartmentListRsp, error) {
accessToken, err := c.GetAccessToken()
if err != nil {
return nil, err
}
params := url.Values{}
params.Add("access_token", accessToken)
requestURL := fmt.Sprintf("%s?%s", DepartmentListUrl, params.Encode())
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
}
c.logger.Debug("DepartmentListUrl:", log.String("body", string(body)))
var departmentListRsp DepartmentListRsp
if err := json.Unmarshal(body, &departmentListRsp); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
}
if departmentListRsp.Errcode != 0 {
return nil, fmt.Errorf("DingTalk API error: errcode=%d errmsg=%s", departmentListRsp.Errcode, departmentListRsp.Errmsg)
}
return &departmentListRsp, nil
}
func (c *Client) GetAllUserList(deptID int) ([]UserDetail, error) {
depth := 0
const maxDepth = 10
userList := make([]UserDetail, 0)
for depth < maxDepth {
resp, err := c.GetUserList(deptID)
if err != nil {
return nil, err
}
if len(resp.Result.List) > 0 {
userList = append(userList, resp.Result.List...)
}
if !resp.Result.HasMore {
break
}
depth++
}
return userList, nil
}
func (c *Client) GetUserList(deptID int) (*GetUserListResp, error) {
accessToken, err := c.GetAccessToken()
if err != nil {
return nil, err
}
params := url.Values{}
params.Add("access_token", accessToken)
requestURL := fmt.Sprintf("%s?%s", UserListUrl, params.Encode())
bodyMap := map[string]interface{}{
"dept_id": deptID,
"size": 100,
"cursor": 0,
}
jsonData, err := json.Marshal(bodyMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
}
c.logger.Debug("GetUserList:", log.String("body", string(body)))
var getUserListResp GetUserListResp
if err := json.Unmarshal(body, &getUserListResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
}
if getUserListResp.Errcode != 0 {
return nil, fmt.Errorf("DingTalk GetUserList error: errcode=%d errcode=%s", getUserListResp.Errcode, getUserListResp.Errmsg)
}
return &getUserListResp, nil
}

View File

@@ -0,0 +1,123 @@
package feishu
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"golang.org/x/oauth2"
"github.com/chaitin/panda-wiki/log"
)
const (
AuthURL = "https://accounts.feishu.cn/open-apis/authen/v1/authorize"
TokenURL = "https://open.feishu.cn/open-apis/authen/v2/oauth/token"
UserInfoURL = "https://open.feishu.cn/open-apis/authen/v1/user_info"
callbackPath = "/share/pro/v1/openapi/feishu/callback"
)
var oauthEndpoint = oauth2.Endpoint{
AuthURL: AuthURL,
TokenURL: TokenURL,
}
// Client 飞书客户端
type Client struct {
context context.Context
oauthConfig *oauth2.Config
logger *log.Logger
}
type Response struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data UserInfo `json:"data"`
}
type UserInfo struct {
Name string `json:"name"`
EnName string `json:"en_name"`
AvatarUrl string `json:"avatar_url"`
AvatarThumb string `json:"avatar_thumb"`
AvatarMiddle string `json:"avatar_middle"`
AvatarBig string `json:"avatar_big"`
OpenId string `json:"open_id"`
UnionId string `json:"union_id"`
Email string `json:"email"`
EnterpriseEmail string `json:"enterprise_email"`
UserId string `json:"user_id"`
Mobile string `json:"mobile"`
TenantKey string `json:"tenant_key"`
EmployeeNo string `json:"employee_no"`
}
func NewClient(ctx context.Context, logger *log.Logger, appID, appSecret, baseUrl string) (*Client, error) {
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
if err != nil {
return nil, err
}
oauthConfig := &oauth2.Config{
ClientID: appID,
ClientSecret: appSecret,
RedirectURL: redirectURI,
Endpoint: oauthEndpoint,
Scopes: []string{},
}
return &Client{
context: ctx,
logger: logger.WithModule("feishu.client"),
oauthConfig: oauthConfig,
}, nil
}
// GenerateAuthURL 生成授权 URL
func (c *Client) GenerateAuthURL(state string, verifier string) string {
return c.oauthConfig.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))
}
// GetAccessToken 通过授权码获取访问令牌
func (c *Client) GetAccessToken(ctx context.Context, code string, codeVerifier string) (*oauth2.Token, error) {
token, err := c.oauthConfig.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
if err != nil {
return nil, fmt.Errorf("oauthConfig.Exchange() failed: %w", err)
}
return token, nil
}
// GetUserInfoByCode 获取用户信息
func (c *Client) GetUserInfoByCode(ctx context.Context, code string, codeVerifier string) (*UserInfo, error) {
token, err := c.oauthConfig.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
if err != nil {
return nil, fmt.Errorf("oauthConfig.Exchange() failed: %w", err)
}
client := c.oauthConfig.Client(ctx, token)
req, err := http.NewRequest("GET", UserInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
defer resp.Body.Close()
var r Response
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
c.logger.Info("GetUserInfoByCode", log.Any("resp", r))
if r.Code != 0 {
return nil, fmt.Errorf("failed to get user info: %s", r.Msg)
}
return &r.Data, nil
}

207
backend/pkg/ldap/ldap.go Normal file
View File

@@ -0,0 +1,207 @@
package ldap
import (
"context"
"fmt"
"strings"
"github.com/go-ldap/ldap/v3"
"github.com/chaitin/panda-wiki/log"
)
type Client struct {
logger *log.Logger
ctx context.Context
config *Config
}
type Config struct {
ServerURL string `json:"server_url"` // LDAP服务器URL如 ldap://openldap.company.com:389
BindDN string `json:"bind_dn"` // 绑定DN如 cn=admin,dc=company,dc=com
BindPassword string `json:"bind_password"` // 绑定密码
UserBaseDN string `json:"user_base_dn"` // 用户基础DN如 ou=People,dc=company,dc=com
UserFilter string `json:"user_filter"` // 用户查询过滤器,如 (&(objectClass=person)(uid=%s))
UserIDAttr string `json:"user_id_attr"` // 用户ID属性默认 uid
UserNameAttr string `json:"user_name_attr"` // 用户名属性,默认 cn
UserEmailAttr string `json:"user_email_attr"` // 用户邮箱属性,默认 mail
}
type UserInfo struct {
ID string `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
DN string `json:"dn"` // Distinguished Name
}
const (
defaultUserIDAttr = "uid"
defaultUserNameAttr = "cn"
defaultUserEmailAttr = "mail"
defaultUserFilter = "(&(objectClass=person)(uid=%s))"
)
// NewClient 创建LDAP客户端
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
// 设置默认值
if config.UserIDAttr == "" {
config.UserIDAttr = defaultUserIDAttr
}
if config.UserNameAttr == "" {
config.UserNameAttr = defaultUserNameAttr
}
if config.UserEmailAttr == "" {
config.UserEmailAttr = defaultUserEmailAttr
}
if config.UserFilter == "" {
config.UserFilter = defaultUserFilter
}
// 验证必需的配置
if config.ServerURL == "" {
return nil, fmt.Errorf("LDAP server URL is required")
}
if config.BindDN == "" {
return nil, fmt.Errorf("bind DN is required")
}
if config.UserBaseDN == "" {
return nil, fmt.Errorf("user base DN is required")
}
return &Client{
ctx: ctx,
logger: logger.WithModule("pkg.ldap"),
config: &config,
}, nil
}
// Authenticate 验证用户凭据并获取用户信息
func (c *Client) Authenticate(username, password string) (*UserInfo, error) {
// 连接到LDAP服务器
conn, err := ldap.DialURL(c.config.ServerURL)
if err != nil {
c.logger.Error("failed to connect to LDAP server", log.Error(err))
return nil, fmt.Errorf("failed to connect to LDAP server: %w", err)
}
defer conn.Close()
// 使用管理员账户绑定
err = conn.Bind(c.config.BindDN, c.config.BindPassword)
if err != nil {
c.logger.Error("failed to bind with admin credentials", log.Error(err))
return nil, fmt.Errorf("failed to bind with admin credentials: %w", err)
}
// 搜索用户
userInfo, err := c.searchUser(conn, username)
if err != nil {
return nil, err
}
// 验证用户密码
err = conn.Bind(userInfo.DN, password)
if err != nil {
c.logger.Error("user authentication failed",
log.String("username", username),
log.String("dn", userInfo.DN),
log.Error(err))
return nil, fmt.Errorf("authentication failed: invalid credentials")
}
c.logger.Info("user authenticated successfully",
log.String("username", username),
log.String("dn", userInfo.DN))
return userInfo, nil
}
// searchUser 搜索用户信息
func (c *Client) searchUser(conn *ldap.Conn, username string) (*UserInfo, error) {
// 构建搜索过滤器
filter := fmt.Sprintf(c.config.UserFilter, username)
// 构建搜索请求
searchRequest := ldap.NewSearchRequest(
c.config.UserBaseDN,
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
0, // 不限制结果数量
0, // 不限制搜索时间
false,
filter,
[]string{c.config.UserIDAttr, c.config.UserNameAttr, c.config.UserEmailAttr},
nil,
)
c.logger.Info("searching for user",
log.String("filter", filter),
log.String("base_dn", c.config.UserBaseDN))
// 执行搜索
searchResult, err := conn.Search(searchRequest)
if err != nil {
c.logger.Error("user search failed", log.Error(err))
return nil, fmt.Errorf("user search failed: %w", err)
}
// 检查搜索结果
if len(searchResult.Entries) == 0 {
c.logger.Warn("user not found", log.String("username", username))
return nil, fmt.Errorf("user not found: %s", username)
}
if len(searchResult.Entries) > 1 {
c.logger.Warn("multiple users found",
log.String("username", username),
log.Int("count", len(searchResult.Entries)))
return nil, fmt.Errorf("multiple users found for username: %s", username)
}
// 解析用户信息
entry := searchResult.Entries[0]
userInfo := &UserInfo{
DN: entry.DN,
ID: c.getAttributeValue(entry, c.config.UserIDAttr),
Username: c.getAttributeValue(entry, c.config.UserNameAttr),
Email: c.getAttributeValue(entry, c.config.UserEmailAttr),
}
// 如果没有获取到用户名使用ID作为用户名
if userInfo.Username == "" {
userInfo.Username = userInfo.ID
}
c.logger.Info("user found",
log.String("dn", userInfo.DN),
log.String("id", userInfo.ID),
log.String("username", userInfo.Username),
log.String("email", userInfo.Email))
return userInfo, nil
}
// getAttributeValue 获取LDAP属性值
func (c *Client) getAttributeValue(entry *ldap.Entry, attrName string) string {
values := entry.GetAttributeValues(attrName)
if len(values) > 0 {
return strings.TrimSpace(values[0])
}
return ""
}
// TestConnection 测试LDAP连接
func (c *Client) TestConnection() error {
conn, err := ldap.DialURL(c.config.ServerURL)
if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
}
defer conn.Close()
err = conn.Bind(c.config.BindDN, c.config.BindPassword)
if err != nil {
return fmt.Errorf("failed to bind with admin credentials: %w", err)
}
c.logger.Info("LDAP connection test successful")
return nil
}

137
backend/pkg/oauth/github.go Normal file
View File

@@ -0,0 +1,137 @@
package oauth
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"golang.org/x/oauth2"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/log"
)
const (
githubAuthorizeURL = "https://github.com/login/oauth/authorize"
githubTokenURL = "https://github.com/login/oauth/access_token"
githubUserInfoURL = "https://api.github.com/user"
githubUserEmailURL = "https://api.github.com/user/emails"
githubCallbackPathPro = "/share/pro/v1/openapi/github/callback"
githubCallbackPath = "/share/v1/openapi/github/callback"
)
func NewGithubClient(ctx context.Context, logger *log.Logger, clientID, clientSecret, redirectURI, proxyURL string) (*Client, error) {
licenseEdition, ok := ctx.Value(consts.ContextKeyEdition).(consts.LicenseEdition)
if !ok {
return nil, fmt.Errorf("failed to retrieve license edition from context")
}
redirectURL, _ := url.Parse(redirectURI)
redirectURL.Path = githubCallbackPath
if licenseEdition > consts.LicenseEditionFree {
redirectURL.Path = githubCallbackPathPro
}
redirectURI = redirectURL.String()
var httpClient *http.Client
if proxyURL != "" {
proxyURLParsed, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
httpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
},
}
logger.Info("GitHub OAuth client configured with proxy", log.String("proxy", proxyURL))
} else {
httpClient = http.DefaultClient
}
config := Config{
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: []string{"user:email"},
AuthorizeURL: githubAuthorizeURL,
TokenURL: githubTokenURL,
UserInfoURL: githubUserInfoURL,
IDField: "id",
NameField: "login",
AvatarField: "avatar_url",
EmailField: "email",
RedirectURI: redirectURI,
}
oauthConfig := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthorizeURL,
TokenURL: config.TokenURL,
},
RedirectURL: redirectURI,
Scopes: config.Scopes,
}
if proxyURL != "" {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
return &Client{
ctx: ctx,
logger: logger.WithModule("pkg.oauth"),
oauth: oauthConfig,
httpClient: httpClient,
config: &config,
}, nil
}
func (c *Client) GetGithubPrimaryEmail(token *oauth2.Token) (string, error) {
var client *http.Client
if c.httpClient != nil {
ctx := context.WithValue(c.ctx, oauth2.HTTPClient, c.httpClient)
client = c.oauth.Client(ctx, token)
} else {
client = c.oauth.Client(c.ctx, token)
}
type Email struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
resp, err := client.Get(githubUserEmailURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
buf, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
c.logger.Info("GetGithubPrimaryEmail:", log.Any("buf", string(buf)))
var emails []Email
if err := json.Unmarshal(buf, &emails); err != nil {
return "", err
}
for _, email := range emails {
if email.Primary && email.Verified {
return email.Email, nil
}
}
return "", errors.New("no primary verified email found")
}

110
backend/pkg/oauth/oauth.go Normal file
View File

@@ -0,0 +1,110 @@
package oauth
import (
"context"
"io"
"net/http"
"net/url"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"
"github.com/chaitin/panda-wiki/log"
)
type Client struct {
logger *log.Logger
ctx context.Context
config *Config
oauth *oauth2.Config
httpClient *http.Client
}
const (
callbackPath = "/share/pro/v1/openapi/oauth/callback"
)
type Config struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
RedirectURI string `json:"redirect_uri,omitempty"`
Scopes []string `json:"scopes,omitempty"`
AuthorizeURL string `json:"authorize_url,omitempty"`
TokenURL string `json:"token_url,omitempty"`
UserInfoURL string `json:"user_info_url,omitempty"`
IDField string `json:"id_field,omitempty"`
NameField string `json:"name_field,omitempty"`
AvatarField string `json:"avatar_field,omitempty"`
EmailField string `json:"email_field,omitempty"`
}
type UserInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarUrl string `json:"avatar_url"`
}
// NewClient 创建OAuth客户端
func NewClient(ctx context.Context, logger *log.Logger, baseUrl string, config Config) (*Client, error) {
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
if err != nil {
return nil, err
}
return &Client{
ctx: ctx,
logger: logger.WithModule("pkg.oauth"),
oauth: &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthorizeURL,
TokenURL: config.TokenURL,
},
RedirectURL: redirectURI,
Scopes: config.Scopes,
},
config: &config,
}, nil
}
func (c *Client) GetAuthorizeURL(state string) string {
return c.oauth.AuthCodeURL(state)
}
func (c *Client) GetUserInfo(code string) (*UserInfo, error) {
token, err := c.oauth.Exchange(c.ctx, code)
if err != nil {
return nil, err
}
client := c.oauth.Client(c.ctx, token)
res, err := client.Get(c.config.UserInfoURL)
if err != nil {
return nil, err
}
defer res.Body.Close()
buf, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
c.logger.Info("oauth GetUserInfo:", log.Any("resp", string(buf)))
jsonString := string(buf)
email := gjson.Get(jsonString, c.config.EmailField).String()
if email == "" && c.config.UserInfoURL == githubUserInfoURL {
email, err = c.GetGithubPrimaryEmail(token)
if err != nil {
c.logger.Warn("GetGithubPrimaryEmail failed", log.Error(err))
}
}
return &UserInfo{
ID: gjson.Get(jsonString, c.config.IDField).String(),
AvatarUrl: gjson.Get(jsonString, c.config.AvatarField).String(),
Name: gjson.Get(jsonString, c.config.NameField).String(),
Email: email,
}, nil
}

View File

@@ -0,0 +1,102 @@
package ratelimit
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/cache"
)
type RateLimiter struct {
logger *log.Logger
cache *cache.Cache
}
func NewRateLimiter(logger *log.Logger, cache *cache.Cache) *RateLimiter {
return &RateLimiter{
logger: logger,
cache: cache,
}
}
const (
LockThreshold1 = 5 // 第一次锁定阈值
LockThreshold2 = 10 // 第二次锁定阈值
LockThreshold3 = 15 // 第三次锁定阈值
AttemptsKeyExpiry = 24 * time.Hour
)
// CheckIPLocked checks if the IP is currently locked
// Returns:
// - bool: whether the IP is locked
// - time.Duration: remaining lockout duration
func (r *RateLimiter) CheckIPLocked(ctx context.Context, ip string) (bool, time.Duration) {
lockKey := fmt.Sprintf("login_lock:%s", ip)
ttl, err := r.cache.TTL(ctx, lockKey).Result()
if err != nil {
r.logger.Error("failed to check lock status", "error", err, "ip", ip)
return false, 0
}
if ttl > 0 {
return true, ttl
}
return false, 0
}
func (r *RateLimiter) LockAttempt(ctx context.Context, ip string) {
attemptsKey := fmt.Sprintf("login_attempts:%s", ip)
lockKey := fmt.Sprintf("login_lock:%s", ip)
attempts, err := r.cache.Incr(ctx, attemptsKey).Result()
if err != nil {
r.logger.Error("failed to increment attempts", "error", err, "ip", ip)
return
}
if err := r.cache.Expire(ctx, attemptsKey, AttemptsKeyExpiry).Err(); err != nil {
r.logger.Error("failed to set expiry on attempts key", "error", err, "ip", ip)
}
var lockDuration time.Duration
if attempts%5 == 0 {
switch {
case attempts == LockThreshold1:
lockDuration = time.Minute
case attempts == LockThreshold2:
lockDuration = 15 * time.Minute
case attempts >= LockThreshold3:
lockDuration = time.Hour
}
if lockDuration > 0 {
if err := r.cache.Set(ctx, lockKey, 1, lockDuration).Err(); err != nil {
r.logger.Error("failed to set lock key", "error", err, "ip", ip)
return
}
r.logger.Info("IP has been locked", "ip", ip, "lockDuration", lockDuration)
}
}
}
// ResetLoginAttempts resets the login attempt counter and lock for an IP
func (r *RateLimiter) ResetLoginAttempts(ctx context.Context, ip string) error {
attemptsKey := fmt.Sprintf("login_attempts:%s", ip)
lockKey := fmt.Sprintf("login_lock:%s", ip)
pipe := r.cache.Pipeline()
pipe.Del(ctx, attemptsKey)
pipe.Del(ctx, lockKey)
_, err := pipe.Exec(ctx)
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to reset login attempts: %w", err)
}
return nil
}

347
backend/pkg/wecom/wecom.go Normal file
View File

@@ -0,0 +1,347 @@
package wecom
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"golang.org/x/oauth2"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/store/cache"
)
const (
// AuthURL api doc https://developer.work.weixin.qq.com/document/path/98152
AuthWebURL = "https://login.work.weixin.qq.com/wwlogin/sso/login"
AuthAPPURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
TokenURL = "https://qyapi.weixin.qq.com/cgi-bin/gettoken"
UserInfoURL = "https://qyapi.weixin.qq.com/cgi-bin/auth/getuserinfo"
UserDetailURL = "https://qyapi.weixin.qq.com/cgi-bin/user/get"
// DepartmentListURL https://developer.work.weixin.qq.com/document/path/90344
DepartmentListURL = "https://qyapi.weixin.qq.com/cgi-bin/department/list"
// UserListUrl https://developer.work.weixin.qq.com/document/path/90337
UserListUrl = "https://qyapi.weixin.qq.com/cgi-bin/user/list"
callbackPath = "/share/pro/v1/openapi/wecom/callback"
)
// Client 企业微信客户端
type Client struct {
context context.Context
cache *cache.Cache
httpClient *http.Client
oauthConfig *oauth2.Config
logger *log.Logger
corpID string
agentID string
}
type TokenResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
type UserInfoResponse struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
UserID string `json:"userid"`
UserTicket string `json:"user_ticket"`
OpenID string `json:"openid"`
ExternalUserid string `json:"external_userid"`
}
type UserDetailResponse struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
Userid string `json:"userid"`
Name string `json:"name"`
Mobile string `json:"mobile"`
Gender string `json:"gender"`
Email string `json:"email"`
Avatar string `json:"avatar"`
OpenUserid string `json:"open_userid"`
}
type DepartmentListResponse struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
Department []struct {
Id int `json:"id"`
Name string `json:"name"`
NameEn string `json:"name_en"`
DepartmentLeader []string `json:"department_leader"`
Parentid int `json:"parentid"`
Order int `json:"order"`
} `json:"department"`
}
type UserListResponse struct {
Errcode int `json:"errcode"`
Errmsg string `json:"errmsg"`
Userlist []struct {
Name string `json:"name"`
Department []int `json:"department"`
Position string `json:"position"`
Status int `json:"status"`
Email string `json:"email"`
Avatar string `json:"avatar"`
Enable int `json:"enable"`
Isleader int `json:"isleader"`
Extattr struct {
Attrs []interface{} `json:"attrs"`
} `json:"extattr"`
HideMobile int `json:"hide_mobile"`
Telephone string `json:"telephone"`
Order []int `json:"order"`
ExternalProfile struct {
ExternalAttr []interface{} `json:"external_attr"`
ExternalCorpName string `json:"external_corp_name"`
} `json:"external_profile"`
MainDepartment int `json:"main_department"`
Alias string `json:"alias"`
IsLeaderInDept []int `json:"is_leader_in_dept"`
Userid string `json:"userid"`
DirectLeader []interface{} `json:"direct_leader"`
} `json:"userlist"`
}
func NewClient(ctx context.Context, logger *log.Logger, corpID, corpSecret, agentID, baseUrl string, cache *cache.Cache, isApp bool) (*Client, error) {
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
if err != nil {
return nil, err
}
authUrl := AuthWebURL
if isApp {
authUrl = AuthAPPURL
}
oauthConfig := &oauth2.Config{
ClientID: fmt.Sprintf("%s-%s", corpID, agentID),
ClientSecret: corpSecret,
RedirectURL: redirectURI,
Endpoint: oauth2.Endpoint{
AuthURL: authUrl,
TokenURL: TokenURL,
},
Scopes: []string{"snsapi_privateinfo"},
}
return &Client{
context: ctx,
httpClient: &http.Client{},
cache: cache,
logger: logger.WithModule("wecom.client"),
oauthConfig: oauthConfig,
corpID: corpID,
agentID: agentID,
}, nil
}
// GenerateAuthURL 生成授权 URL
func (c *Client) GenerateAuthURL(state string) string {
params := url.Values{}
params.Set("appid", c.corpID)
params.Set("redirect_uri", c.oauthConfig.RedirectURL)
params.Set("response_type", "code")
params.Set("scope", "snsapi_privateinfo")
params.Set("login_type", "CorpApp")
params.Set("agentid", c.agentID)
params.Set("state", state)
authUrl := fmt.Sprintf("%s?%s", c.oauthConfig.Endpoint.AuthURL, params.Encode())
if c.oauthConfig.Endpoint.AuthURL == AuthAPPURL {
authUrl += "#wechat_redirect"
}
return authUrl
}
// GetAccessToken 获取企业微信访问令牌
func (c *Client) GetAccessToken(ctx context.Context) (string, error) {
cacheKey := fmt.Sprintf("wecom-access-token:%s", c.oauthConfig.ClientID)
cachedData, err := c.cache.Get(ctx, cacheKey).Result()
if err == nil && cachedData != "" {
return cachedData, nil
}
params := url.Values{}
params.Set("corpid", c.corpID)
params.Set("corpsecret", c.oauthConfig.ClientSecret)
resp, err := c.httpClient.Get(fmt.Sprintf("%s?%s", TokenURL, params.Encode()))
if err != nil {
return "", fmt.Errorf("failed to get access token: %w", err)
}
defer resp.Body.Close()
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if tokenResp.ErrCode != 0 {
return "", fmt.Errorf("failed to get access token: %s", tokenResp.ErrMsg)
}
if err := c.cache.Set(ctx, cacheKey, tokenResp.AccessToken, time.Duration(tokenResp.ExpiresIn-300)*time.Second).Err(); err != nil {
c.logger.Warn("failed to set cache", log.Error(err))
}
return tokenResp.AccessToken, nil
}
// GetUserInfoByCode 通过授权码获取用户信息
func (c *Client) GetUserInfoByCode(ctx context.Context, code string) (*UserDetailResponse, error) {
accessToken, err := c.GetAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
params := url.Values{}
params.Set("access_token", accessToken)
params.Set("code", code)
userInfoURL := fmt.Sprintf("%s?%s", UserInfoURL, params.Encode())
c.logger.Debug("GetUserInfoByCode", log.Any("userInfoURL", userInfoURL))
resp, err := c.httpClient.Get(userInfoURL)
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body: %w", err)
}
defer resp.Body.Close()
c.logger.Debug("GetUserInfoByCode raw resp:", log.Any("raw", string(rawBody)))
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
var userInfoResp UserInfoResponse
if err := json.NewDecoder(resp.Body).Decode(&userInfoResp); err != nil {
return nil, fmt.Errorf("failed to decode user info response: %w", err)
}
c.logger.Debug("GetUserInfoByCode resp:", log.Any("resp", userInfoResp))
if userInfoResp.ErrCode != 0 {
return nil, fmt.Errorf("failed to get user info: %s", userInfoResp.ErrMsg)
}
detailParams := url.Values{}
detailParams.Set("access_token", accessToken)
detailParams.Set("userid", userInfoResp.UserID)
userDetailURL := fmt.Sprintf("%s?%s", UserDetailURL, detailParams.Encode())
detailResp, err := c.httpClient.Get(userDetailURL)
if err != nil {
return nil, fmt.Errorf("failed to get user detail: %w", err)
}
defer detailResp.Body.Close()
var UserDetailResp UserDetailResponse
if err := json.NewDecoder(detailResp.Body).Decode(&UserDetailResp); err != nil {
return nil, fmt.Errorf("failed to decode user detail response: %w", err)
}
c.logger.Debug("GetUserInfoByCode detail info", log.Any("resp", UserDetailResp))
if UserDetailResp.Errcode != 0 {
return nil, fmt.Errorf("failed to get user detail: %s", UserDetailResp.Errmsg)
}
return &UserDetailResp, nil
}
func (c *Client) GetDepartmentList(ctx context.Context) (*DepartmentListResponse, error) {
accessToken, err := c.GetAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
params := url.Values{}
params.Set("access_token", accessToken)
departmentListURL := fmt.Sprintf("%s?%s", DepartmentListURL, params.Encode())
resp, err := c.httpClient.Get(departmentListURL)
if err != nil {
return nil, fmt.Errorf("failed to get department list: %w", err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body: %w", err)
}
c.logger.Debug("GetDepartmentList raw resp:", log.Any("raw", string(rawBody)))
defer resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
var departmentListResponse DepartmentListResponse
if err := json.NewDecoder(resp.Body).Decode(&departmentListResponse); err != nil {
return nil, fmt.Errorf("failed to decode department list response: %w", err)
}
c.logger.Debug("GetDepartmentList resp:", log.Any("resp", departmentListResponse))
if departmentListResponse.Errcode != 0 {
return nil, fmt.Errorf("failed to get user info: %s", departmentListResponse.Errmsg)
}
return &departmentListResponse, nil
}
func (c *Client) GetUserList(ctx context.Context, deptID string) (*UserListResponse, error) {
accessToken, err := c.GetAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
params := url.Values{}
params.Set("access_token", accessToken)
params.Set("department_id", deptID)
userListUrl := fmt.Sprintf("%s?%s", UserListUrl, params.Encode())
resp, err := c.httpClient.Get(userListUrl)
if err != nil {
return nil, fmt.Errorf("failed to get user list: %w", err)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read body: %w", err)
}
c.logger.Debug("GetUserList raw resp:", log.Any("raw", string(rawBody)))
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
var userListResponse UserListResponse
if err := json.NewDecoder(resp.Body).Decode(&userListResponse); err != nil {
return nil, fmt.Errorf("failed to decode user list response: %w", err)
}
c.logger.Debug("GetUserList resp:", log.Any("resp", userListResponse))
if userListResponse.Errcode != 0 {
return nil, fmt.Errorf("failed to get user info: %s", userListResponse.Errmsg)
}
return &userListResponse, nil
}