init push
This commit is contained in:
341
backend/pkg/anydoc/anydoc.go
Normal file
341
backend/pkg/anydoc/anydoc.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/mq"
|
||||
"github.com/chaitin/panda-wiki/mq/types"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
logger *log.Logger
|
||||
mqConsumer mq.MQConsumer
|
||||
taskWaiters map[string]chan *domain.AnydocTaskExportEvent
|
||||
mutex sync.RWMutex
|
||||
subscribed bool
|
||||
subscribeMu sync.Mutex
|
||||
}
|
||||
|
||||
const (
|
||||
apiUploaderUrl = "http://panda-wiki-api:8000/api/v1/file/upload/anydoc"
|
||||
uploaderDir = "/image"
|
||||
crawlerServiceHost = "http://panda-wiki-crawler:8080"
|
||||
SpaceIdCloud = "cloud_disk"
|
||||
getUrlPath = "/api/docs/url/list"
|
||||
UrlExportPath = "/api/docs/url/export"
|
||||
TaskListPath = "/api/tasks/list"
|
||||
)
|
||||
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusPending Status = "pending"
|
||||
StatusInProgress Status = "in_process"
|
||||
StatusCompleted Status = "completed"
|
||||
StatusFailed Status = "failed"
|
||||
)
|
||||
|
||||
type UploaderType uint
|
||||
|
||||
const (
|
||||
uploaderTypeDefault UploaderType = iota
|
||||
uploaderTypeHTTP
|
||||
)
|
||||
|
||||
func NewClient(logger *log.Logger, mqConsumer mq.MQConsumer) (*Client, error) {
|
||||
client := &Client{
|
||||
logger: logger.WithModule("anydoc.client"),
|
||||
httpClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
taskWaiters: make(map[string]chan *domain.AnydocTaskExportEvent),
|
||||
mqConsumer: mqConsumer,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetUrlList(ctx context.Context, targetURL, id string) (*ListDocResponse, error) {
|
||||
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = getUrlPath
|
||||
q := u.Query()
|
||||
q.Set("url", targetURL)
|
||||
q.Set("uuid", id)
|
||||
u.RawQuery = q.Encode()
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Info("scrape url", "requestURL:", requestURL, "resp", string(respBody))
|
||||
var scrapeResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &scrapeResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !scrapeResp.Success {
|
||||
return nil, errors.New(scrapeResp.Msg)
|
||||
}
|
||||
|
||||
return &scrapeResp, nil
|
||||
}
|
||||
|
||||
func (c *Client) UrlExport(ctx context.Context, id, docID, kbId string) (*UrlExportRes, error) {
|
||||
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = UrlExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": id,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Info("UrlExport", "requestURL:", requestURL, "resp", string(respBody))
|
||||
var res UrlExportRes
|
||||
err = json.Unmarshal(respBody, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !res.Success {
|
||||
return nil, errors.New(res.Msg)
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// ensureSubscribed 确保已订阅消息队列,只订阅一次
|
||||
func (c *Client) ensureSubscribed() error {
|
||||
c.subscribeMu.Lock()
|
||||
defer c.subscribeMu.Unlock()
|
||||
|
||||
if c.subscribed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.mqConsumer == nil {
|
||||
return fmt.Errorf("MQ consumer not initialized")
|
||||
}
|
||||
|
||||
err := c.mqConsumer.RegisterHandler(domain.AnydocTaskExportTopic, c.handleTaskExportEvent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register task export handler: %w", err)
|
||||
}
|
||||
|
||||
c.subscribed = true
|
||||
c.logger.Info("successfully subscribed to anydoc task export topic")
|
||||
return nil
|
||||
}
|
||||
|
||||
// TaskWaitForCompletion 通过 NATS 消息队列等待任务完成(推荐方式)
|
||||
func (c *Client) TaskWaitForCompletion(ctx context.Context, taskID string) (*domain.AnydocTaskExportEvent, error) {
|
||||
if c.mqConsumer == nil {
|
||||
return nil, fmt.Errorf("MQ consumer not initialized, use NewClientWithMQ instead")
|
||||
}
|
||||
|
||||
// 延迟订阅:只有在需要时才订阅
|
||||
if err := c.ensureSubscribed(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a channel to wait for the specific task
|
||||
taskChan := make(chan *domain.AnydocTaskExportEvent, 1)
|
||||
|
||||
c.mutex.Lock()
|
||||
c.taskWaiters[taskID] = taskChan
|
||||
c.mutex.Unlock()
|
||||
|
||||
// Cleanup when done
|
||||
defer func() {
|
||||
c.mutex.Lock()
|
||||
delete(c.taskWaiters, taskID)
|
||||
c.mutex.Unlock()
|
||||
close(taskChan)
|
||||
}()
|
||||
|
||||
// Wait for task completion or context cancellation
|
||||
select {
|
||||
case event := <-taskChan:
|
||||
return event, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// TaskListPoll 轮询方式(保留兼容性)
|
||||
func (c *Client) TaskListPoll(ctx context.Context, ids []string) (*TaskRes, error) {
|
||||
depth := 0
|
||||
const maxDepth = 10
|
||||
|
||||
for depth < maxDepth {
|
||||
time.Sleep(1000 * time.Millisecond)
|
||||
resp, err := c.TaskList(ctx, ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Data[0].Status == StatusCompleted {
|
||||
return resp, nil
|
||||
}
|
||||
depth++
|
||||
}
|
||||
return nil, fmt.Errorf("task list poll timeout")
|
||||
}
|
||||
|
||||
// handleTaskExportEvent 处理任务导出完成事件
|
||||
func (c *Client) handleTaskExportEvent(ctx context.Context, msg types.Message) error {
|
||||
var event domain.AnydocTaskExportEvent
|
||||
if err := json.Unmarshal(msg.GetData(), &event); err != nil {
|
||||
c.logger.Error("failed to unmarshal task export event", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("received task export event",
|
||||
"task_id", event.TaskID,
|
||||
"status", event.Status,
|
||||
"doc_id", event.DocID)
|
||||
|
||||
// Notify waiting goroutines
|
||||
c.mutex.RLock()
|
||||
if taskChan, exists := c.taskWaiters[event.TaskID]; exists {
|
||||
select {
|
||||
case taskChan <- &event:
|
||||
default:
|
||||
// Channel is full or closed, ignore
|
||||
}
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) TaskList(ctx context.Context, ids []string) (*TaskRes, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = TaskListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"ids": ids,
|
||||
}
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Info("TaskList url", "requestURL", requestURL, "resp", string(respBody))
|
||||
var res TaskRes
|
||||
err = json.Unmarshal(respBody, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !res.Success {
|
||||
return nil, errors.New(res.Msg)
|
||||
}
|
||||
if len(res.Data) == 0 {
|
||||
return nil, errors.New("data list is empty")
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (c *Client) DownloadDoc(ctx context.Context, filepath string) ([]byte, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = "/api/tasks/download" + filepath
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Info("DownloadDoc", "requestURL:", requestURL, "resp length", len(respBody))
|
||||
return respBody, nil
|
||||
}
|
||||
154
backend/pkg/anydoc/confluence.go
Normal file
154
backend/pkg/anydoc/confluence.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
ConfluenceListPath = "/api/docs/confluence/list"
|
||||
ConfluenceExportPath = "/api/docs/confluence/export"
|
||||
)
|
||||
|
||||
// ConfluenceListDocsRequest Confluence 获取文档列表请求
|
||||
type ConfluenceListDocsRequest struct {
|
||||
URL string `json:"url"` // Confluence 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// ConfluenceExportDocRequest Confluence 导出文档请求
|
||||
type ConfluenceExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // confluence-doc-id
|
||||
}
|
||||
|
||||
// ConfluenceExportDocResponse Confluence 导出文档响应
|
||||
type ConfluenceExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// ConfluenceExportDocData Confluence 导出文档数据
|
||||
type ConfluenceExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// ConfluenceListDocs 获取 Confluence 文档列表
|
||||
func (c *Client) ConfluenceListDocs(ctx context.Context, confluenceURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = ConfluenceListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": confluenceURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("ConfluenceListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var confluenceResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &confluenceResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !confluenceResp.Success {
|
||||
return nil, errors.New(confluenceResp.Msg)
|
||||
}
|
||||
|
||||
return &confluenceResp, nil
|
||||
}
|
||||
|
||||
// ConfluenceExportDoc 导出 Confluence 文档
|
||||
func (c *Client) ConfluenceExportDoc(ctx context.Context, uuid, docID, kbId string) (*ConfluenceExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = ConfluenceExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("ConfluenceExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp ConfluenceExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
70
backend/pkg/anydoc/dingtalk.go
Normal file
70
backend/pkg/anydoc/dingtalk.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
dingtalkListPath = "/api/docs/dingtalk/list"
|
||||
dingtalkExportPath = "/api/docs/dingtalk/export"
|
||||
)
|
||||
|
||||
// DingtalkListDocs 获取 dingtalk 文档列表
|
||||
func (c *Client) DingtalkListDocs(ctx context.Context, uuid string, dingtalkSetting DingtalkSetting) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = dingtalkListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"app_id": dingtalkSetting.AppID,
|
||||
"app_secret": dingtalkSetting.AppSecret,
|
||||
"unionid": dingtalkSetting.UnionID,
|
||||
"space_id": dingtalkSetting.SpaceID,
|
||||
"phone": dingtalkSetting.Phone,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("dingtalkListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var dingtalkResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &dingtalkResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !dingtalkResp.Success {
|
||||
return nil, errors.New(dingtalkResp.Msg)
|
||||
}
|
||||
|
||||
return &dingtalkResp, nil
|
||||
}
|
||||
173
backend/pkg/anydoc/epub.go
Normal file
173
backend/pkg/anydoc/epub.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
epubpListPath = "/api/docs/epubp/list"
|
||||
epubpExportPath = "/api/docs/epubp/export"
|
||||
)
|
||||
|
||||
// EpubpListDocsRequest Epubp 获取文档列表请求
|
||||
type EpubpListDocsRequest struct {
|
||||
URL string `json:"url"` // Epubp 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// EpubpListDocsResponse Epubp 获取文档列表响应
|
||||
type EpubpListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data EpubpListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// EpubpListDocsData Epubp 文档列表数据
|
||||
type EpubpListDocsData struct {
|
||||
Docs []EpubpDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// EpubpDoc Epubp 文档信息
|
||||
type EpubpDoc struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// EpubpExportDocRequest Epubp 导出文档请求
|
||||
type EpubpExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // epubp-doc-id
|
||||
}
|
||||
|
||||
// EpubpExportDocResponse Epubp 导出文档响应
|
||||
type EpubpExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// EpubpExportDocData Epubp 导出文档数据
|
||||
type EpubpExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// EpubpListDocs 获取 Epubp 文档列表
|
||||
func (c *Client) EpubpListDocs(ctx context.Context, epubpURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = epubpListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": epubpURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("EpubpListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var epubpResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &epubpResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !epubpResp.Success {
|
||||
return nil, errors.New(epubpResp.Msg)
|
||||
}
|
||||
|
||||
return &epubpResp, nil
|
||||
}
|
||||
|
||||
// EpubpExportDoc 导出 Epubp 文档
|
||||
func (c *Client) EpubpExportDoc(ctx context.Context, uuid, docID, kbId string) (*EpubpExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = epubpExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("EpubpExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp EpubpExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
175
backend/pkg/anydoc/feishu.go
Normal file
175
backend/pkg/anydoc/feishu.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
feishuListPath = "/api/docs/feishu/list"
|
||||
feishuExportPath = "/api/docs/feishu/export"
|
||||
)
|
||||
|
||||
// FeishuListDocsRequest Feishu 获取文档列表请求
|
||||
type FeishuListDocsRequest struct {
|
||||
URL string `json:"url"` // Feishu 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// FeishuListDocsResponse Feishu 获取文档列表响应
|
||||
type FeishuListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data FeishuListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// FeishuListDocsData Feishu 文档列表数据
|
||||
type FeishuListDocsData struct {
|
||||
Docs []FeishuDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// FeishuDoc Feishu 文档信息
|
||||
type FeishuDoc struct {
|
||||
ID string `json:"id"`
|
||||
FileType string `json:"file_type"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
// FeishuExportDocRequest Feishu 导出文档请求
|
||||
type FeishuExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // feishu-doc-id
|
||||
}
|
||||
|
||||
// FeishuExportDocResponse Feishu 导出文档响应
|
||||
type FeishuExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// FeishuExportDocData Feishu 导出文档数据
|
||||
type FeishuExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// FeishuListDocs 获取 Feishu 文档列表
|
||||
func (c *Client) FeishuListDocs(ctx context.Context, uuid, appId, appSecret, accessToken, spaceId string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = feishuListPath
|
||||
|
||||
q := u.Query()
|
||||
q.Set("uuid", uuid)
|
||||
q.Set("app_id", appId)
|
||||
q.Set("app_secret", appSecret)
|
||||
q.Set("access_token", accessToken)
|
||||
if spaceId != "" {
|
||||
q.Set("space_id", spaceId)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("FeishuListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var feishuResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &feishuResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !feishuResp.Success {
|
||||
return nil, errors.New(feishuResp.Msg)
|
||||
}
|
||||
|
||||
return &feishuResp, nil
|
||||
}
|
||||
|
||||
// FeishuExportDoc 导出 Feishu 文档
|
||||
func (c *Client) FeishuExportDoc(ctx context.Context, uuid, docID, fileType, spaceId, kbId string) (*UrlExportRes, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = feishuExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"file_type": fileType,
|
||||
"space_id": spaceId,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("FeishuDoc", "requestURL:", requestURL, "body", string(jsonData), "resp", string(respBody))
|
||||
|
||||
var exportResp UrlExportRes
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
173
backend/pkg/anydoc/mindoc.go
Normal file
173
backend/pkg/anydoc/mindoc.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
mindocListPath = "/api/docs/mindoc/list"
|
||||
mindocExportPath = "/api/docs/mindoc/export"
|
||||
)
|
||||
|
||||
// MindocListDocsRequest Mindoc 获取文档列表请求
|
||||
type MindocListDocsRequest struct {
|
||||
URL string `json:"url"` // Mindoc 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// MindocListDocsResponse Mindoc 获取文档列表响应
|
||||
type MindocListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data MindocListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// MindocListDocsData Mindoc 文档列表数据
|
||||
type MindocListDocsData struct {
|
||||
Docs []MindocDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// MindocDoc Mindoc 文档信息
|
||||
type MindocDoc struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// MindocExportDocRequest Mindoc 导出文档请求
|
||||
type MindocExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // mindoc-doc-id
|
||||
}
|
||||
|
||||
// MindocExportDocResponse Mindoc 导出文档响应
|
||||
type MindocExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// MindocExportDocData Mindoc 导出文档数据
|
||||
type MindocExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// MindocListDocs 获取 Mindoc 文档列表
|
||||
func (c *Client) MindocListDocs(ctx context.Context, mindocURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = mindocListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": mindocURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("MindocListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var mindocResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &mindocResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !mindocResp.Success {
|
||||
return nil, errors.New(mindocResp.Msg)
|
||||
}
|
||||
|
||||
return &mindocResp, nil
|
||||
}
|
||||
|
||||
// MindocExportDoc 导出 Mindoc 文档
|
||||
func (c *Client) MindocExportDoc(ctx context.Context, uuid, docID, kbId string) (*MindocExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = mindocExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("MindocExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp MindocExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
148
backend/pkg/anydoc/notion.go
Normal file
148
backend/pkg/anydoc/notion.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
notionListPath = "/api/docs/notion/list"
|
||||
notionExportPath = "/api/docs/notion/export"
|
||||
)
|
||||
|
||||
// NotionListDocsResponse Notion 获取文档列表响应
|
||||
type NotionListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data NotionListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// NotionListDocsData Notion 文档列表数据
|
||||
type NotionListDocsData struct {
|
||||
Docs []NotionDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// NotionDoc Notion 文档信息
|
||||
type NotionDoc struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// NotionExportDocResponse Notion 导出文档响应
|
||||
type NotionExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// NotionListDocs 获取 Notion 文档列表
|
||||
func (c *Client) NotionListDocs(ctx context.Context, secret, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = notionListPath
|
||||
|
||||
q := u.Query()
|
||||
q.Set("uuid", uuid)
|
||||
q.Set("secret", secret)
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("NotionListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var notionResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, ¬ionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !notionResp.Success {
|
||||
return nil, errors.New(notionResp.Msg)
|
||||
}
|
||||
|
||||
return ¬ionResp, nil
|
||||
}
|
||||
|
||||
// NotionExportDoc 导出 Notion 文档
|
||||
func (c *Client) NotionExportDoc(ctx context.Context, uuid, docID, kbId string) (*NotionExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = notionExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("NotionExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp NotionExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
16
backend/pkg/anydoc/req.go
Normal file
16
backend/pkg/anydoc/req.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package anydoc
|
||||
|
||||
type FeishuSetting struct {
|
||||
UserAccessToken string `json:"user_access_token"`
|
||||
AppID string `json:"app_id"`
|
||||
AppSecret string `json:"app_secret"`
|
||||
SpaceId string `json:"space_id"`
|
||||
}
|
||||
|
||||
type DingtalkSetting struct {
|
||||
AppID string `json:"app_id"`
|
||||
AppSecret string `json:"app_secret"`
|
||||
SpaceID string `json:"space_id"`
|
||||
UnionID string `json:"unionid"`
|
||||
Phone string `json:"phone"`
|
||||
}
|
||||
63
backend/pkg/anydoc/res.go
Normal file
63
backend/pkg/anydoc/res.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package anydoc
|
||||
|
||||
type GetUrlListResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data GetUrlListData `json:"data"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
TraceId interface{} `json:"trace_id"`
|
||||
}
|
||||
type GetUrlListData struct {
|
||||
Docs []struct {
|
||||
Id string `json:"id"`
|
||||
FileType string `json:"file_type"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
} `json:"docs"`
|
||||
}
|
||||
|
||||
type UrlExportRes struct {
|
||||
Success bool `json:"success"`
|
||||
Data string `json:"data"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
TraceId interface{} `json:"trace_id"`
|
||||
}
|
||||
type TaskRes struct {
|
||||
Success bool `json:"success"`
|
||||
Data []struct {
|
||||
TaskId string `json:"task_id"`
|
||||
PlatformId string `json:"platform_id"`
|
||||
DocId string `json:"doc_id"`
|
||||
Status Status `json:"status"`
|
||||
Err string `json:"err"`
|
||||
Markdown string `json:"markdown"`
|
||||
Json string `json:"json"`
|
||||
} `json:"data"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
|
||||
type ListDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data ListDocsData `json:"data"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
TraceID string `json:"trace_id"`
|
||||
}
|
||||
|
||||
type ListDocsData struct {
|
||||
Docs Child `json:"docs"`
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
ID string `json:"id"`
|
||||
File bool `json:"file"`
|
||||
FileType string `json:"file_type"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
type Child struct {
|
||||
Value Value `json:"value"`
|
||||
Children []Child `json:"children"`
|
||||
}
|
||||
161
backend/pkg/anydoc/rss.go
Normal file
161
backend/pkg/anydoc/rss.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
rssListPath = "/api/docs/rss/list"
|
||||
rssExportPath = "/api/docs/rss/export"
|
||||
)
|
||||
|
||||
// RssListDocsResponse Rss 获取文档列表响应
|
||||
type RssListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data RssListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// RssListDocsData Rss 文档列表数据
|
||||
type RssListDocsData struct {
|
||||
Docs []RssDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// RssDoc Rss 文档信息
|
||||
type RssDoc struct {
|
||||
Id string `json:"id"`
|
||||
FileType string `json:"file_type"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
// RssExportDocRequest Rss 导出文档请求
|
||||
type RssExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // rss-doc-id
|
||||
}
|
||||
|
||||
// RssExportDocResponse Rss 导出文档响应
|
||||
type RssExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// RssExportDocData Rss 导出文档数据
|
||||
type RssExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// RssListDocs 获取 Rss 文档列表
|
||||
func (c *Client) RssListDocs(ctx context.Context, xmlUrl, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = rssListPath
|
||||
|
||||
q := u.Query()
|
||||
q.Set("uuid", uuid)
|
||||
q.Set("url", xmlUrl)
|
||||
u.RawQuery = q.Encode()
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("RssListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var rssResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &rssResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !rssResp.Success {
|
||||
return nil, errors.New(rssResp.Msg)
|
||||
}
|
||||
|
||||
return &rssResp, nil
|
||||
}
|
||||
|
||||
// RssExportDoc 导出 Rss 文档
|
||||
func (c *Client) RssExportDoc(ctx context.Context, uuid, docID, kbId string) (*RssExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = rssExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("RssExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp RssExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
161
backend/pkg/anydoc/sitemap.go
Normal file
161
backend/pkg/anydoc/sitemap.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
sitemapListPath = "/api/docs/sitemap/list"
|
||||
sitemapExportPath = "/api/docs/sitemap/export"
|
||||
)
|
||||
|
||||
// SitemapListDocsResponse Sitemap 获取文档列表响应
|
||||
type SitemapListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data SitemapListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// SitemapListDocsData Sitemap 文档列表数据
|
||||
type SitemapListDocsData struct {
|
||||
Docs []SitemapDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// SitemapDoc Sitemap 文档信息
|
||||
type SitemapDoc struct {
|
||||
Id string `json:"id"`
|
||||
FileType string `json:"file_type"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
// SitemapExportDocRequest Sitemap 导出文档请求
|
||||
type SitemapExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // sitemap-doc-id
|
||||
}
|
||||
|
||||
// SitemapExportDocResponse Sitemap 导出文档响应
|
||||
type SitemapExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// SitemapExportDocData Sitemap 导出文档数据
|
||||
type SitemapExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// SitemapListDocs 获取 Sitemap 文档列表
|
||||
func (c *Client) SitemapListDocs(ctx context.Context, xmlUrl, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = sitemapListPath
|
||||
|
||||
q := u.Query()
|
||||
q.Set("uuid", uuid)
|
||||
q.Set("url", xmlUrl)
|
||||
u.RawQuery = q.Encode()
|
||||
requestURL := u.String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("SitemapListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var sitemapResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &sitemapResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !sitemapResp.Success {
|
||||
return nil, errors.New(sitemapResp.Msg)
|
||||
}
|
||||
|
||||
return &sitemapResp, nil
|
||||
}
|
||||
|
||||
// SitemapExportDoc 导出 Sitemap 文档
|
||||
func (c *Client) SitemapExportDoc(ctx context.Context, uuid, docID, kbId string) (*SitemapExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = sitemapExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("SitemapExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp SitemapExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
173
backend/pkg/anydoc/siyuan.go
Normal file
173
backend/pkg/anydoc/siyuan.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
siyuanListPath = "/api/docs/siyuan/list"
|
||||
siyuanExportPath = "/api/docs/siyuan/export"
|
||||
)
|
||||
|
||||
// SiyuanListDocsRequest Siyuan 获取文档列表请求
|
||||
type SiyuanListDocsRequest struct {
|
||||
URL string `json:"url"` // Siyuan 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// SiyuanListDocsResponse Siyuan 获取文档列表响应
|
||||
type SiyuanListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data SiyuanListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// SiyuanListDocsData Siyuan 文档列表数据
|
||||
type SiyuanListDocsData struct {
|
||||
Docs []SiyuanDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// SiyuanDoc Siyuan 文档信息
|
||||
type SiyuanDoc struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// SiyuanExportDocRequest Siyuan 导出文档请求
|
||||
type SiyuanExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // siyuan-doc-id
|
||||
}
|
||||
|
||||
// SiyuanExportDocResponse Siyuan 导出文档响应
|
||||
type SiyuanExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// SiyuanExportDocData Siyuan 导出文档数据
|
||||
type SiyuanExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// SiyuanListDocs 获取 Siyuan 文档列表
|
||||
func (c *Client) SiyuanListDocs(ctx context.Context, siyuanURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = siyuanListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": siyuanURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("SiyuanListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var siyuanResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &siyuanResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !siyuanResp.Success {
|
||||
return nil, errors.New(siyuanResp.Msg)
|
||||
}
|
||||
|
||||
return &siyuanResp, nil
|
||||
}
|
||||
|
||||
// SiyuanExportDoc 导出 Siyuan 文档
|
||||
func (c *Client) SiyuanExportDoc(ctx context.Context, uuid, docID, kbId string) (*SiyuanExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = siyuanExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("SiyuanExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp SiyuanExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
154
backend/pkg/anydoc/wikijs.go
Normal file
154
backend/pkg/anydoc/wikijs.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
wikijsListPath = "/api/docs/wikijs/list"
|
||||
wikijsExportPath = "/api/docs/wikijs/export"
|
||||
)
|
||||
|
||||
// WikijsListDocsRequest Wikijs 获取文档列表请求
|
||||
type WikijsListDocsRequest struct {
|
||||
URL string `json:"url"` // Wikijs 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// WikijsExportDocRequest Wikijs 导出文档请求
|
||||
type WikijsExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // wikijs-doc-id
|
||||
}
|
||||
|
||||
// WikijsExportDocResponse Wikijs 导出文档响应
|
||||
type WikijsExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// WikijsExportDocData Wikijs 导出文档数据
|
||||
type WikijsExportDocData struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
FilePath string `json:"file_path"`
|
||||
}
|
||||
|
||||
// WikijsListDocs 获取 Wikijs 文档列表
|
||||
func (c *Client) WikijsListDocs(ctx context.Context, wikijsURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = wikijsListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": wikijsURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("WikijsListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var wikijsResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &wikijsResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !wikijsResp.Success {
|
||||
return nil, errors.New(wikijsResp.Msg)
|
||||
}
|
||||
|
||||
return &wikijsResp, nil
|
||||
}
|
||||
|
||||
// WikijsExportDoc 导出 Wikijs 文档
|
||||
func (c *Client) WikijsExportDoc(ctx context.Context, uuid, docID, kbId string) (*WikijsExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = wikijsExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("WikijsExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp WikijsExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, errors.New(exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
165
backend/pkg/anydoc/yuque.go
Normal file
165
backend/pkg/anydoc/yuque.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package anydoc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
yuqueListPath = "/api/docs/yuque/list"
|
||||
yuqueExportPath = "/api/docs/yuque/export"
|
||||
)
|
||||
|
||||
// YuqueListDocsRequest Yuque 获取文档列表请求
|
||||
type YuqueListDocsRequest struct {
|
||||
URL string `json:"url"` // Yuque 配置文件
|
||||
Filename string `json:"filename"` // 文件名,需要带扩展名
|
||||
UUID string `json:"uuid"` // 必填的唯一标识符
|
||||
}
|
||||
|
||||
// YuqueListDocsResponse Yuque 获取文档列表响应
|
||||
type YuqueListDocsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data YuqueListDocsData `json:"data"`
|
||||
}
|
||||
|
||||
// YuqueListDocsData Yuque 文档列表数据
|
||||
type YuqueListDocsData struct {
|
||||
Docs []YuqueDoc `json:"docs"`
|
||||
}
|
||||
|
||||
// YuqueDoc Yuque 文档信息
|
||||
type YuqueDoc struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// YuqueExportDocRequest Yuque 导出文档请求
|
||||
type YuqueExportDocRequest struct {
|
||||
UUID string `json:"uuid"` // 必须与 list 接口使用的 uuid 相同
|
||||
DocID string `json:"doc_id"` // yuque-doc-id
|
||||
}
|
||||
|
||||
// YuqueExportDocResponse Yuque 导出文档响应
|
||||
type YuqueExportDocResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Msg string `json:"msg"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// YuqueListDocs 获取 Yuque 文档列表
|
||||
func (c *Client) YuqueListDocs(ctx context.Context, yuqueURL, filename, uuid string) (*ListDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = yuqueListPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"url": yuqueURL,
|
||||
"filename": filename,
|
||||
"uuid": uuid,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("YuqueListDocs", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var yuqueResp ListDocResponse
|
||||
err = json.Unmarshal(respBody, &yuqueResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !yuqueResp.Success {
|
||||
return nil, fmt.Errorf("yuque list docs API failed - URL: %s, UUID: %s, Error: %s", yuqueURL, uuid, yuqueResp.Msg)
|
||||
}
|
||||
|
||||
return &yuqueResp, nil
|
||||
}
|
||||
|
||||
// YuqueExportDoc 导出 Yuque 文档
|
||||
func (c *Client) YuqueExportDoc(ctx context.Context, uuid, docID, kbId string) (*YuqueExportDocResponse, error) {
|
||||
u, err := url.Parse(crawlerServiceHost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = yuqueExportPath
|
||||
requestURL := u.String()
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"uuid": uuid,
|
||||
"doc_id": docID,
|
||||
"uploader": map[string]interface{}{
|
||||
"type": uploaderTypeHTTP,
|
||||
"http": map[string]interface{}{
|
||||
"url": apiUploaderUrl,
|
||||
},
|
||||
"dir": fmt.Sprintf("/%s", kbId),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("YuqueExportDoc", "requestURL:", requestURL, "resp", string(respBody))
|
||||
|
||||
var exportResp YuqueExportDocResponse
|
||||
err = json.Unmarshal(respBody, &exportResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exportResp.Success {
|
||||
return nil, fmt.Errorf("yuque export doc API failed - UUID: %s, DocID: %s, Error: %s", uuid, docID, exportResp.Msg)
|
||||
}
|
||||
|
||||
return &exportResp, nil
|
||||
}
|
||||
9
backend/pkg/bot/common.go
Normal file
9
backend/pkg/bot/common.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
)
|
||||
|
||||
type GetQAFun func(ctx context.Context, msg string, info domain.ConversationInfo, ConversationID string) (chan string, error)
|
||||
502
backend/pkg/bot/dingtalk/stream.go
Normal file
502
backend/pkg/bot/dingtalk/stream.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
|
||||
dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0"
|
||||
dingtalkoauth2_1_0 "github.com/alibabacloud-go/dingtalk/oauth2_1_0"
|
||||
util "github.com/alibabacloud-go/tea-utils/v2/service"
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
"github.com/google/uuid"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
type DingTalkClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clientID string
|
||||
clientSecret string
|
||||
templateID string // 4d18414c-aabc-4ec8-9e67-4ceefeada72a.schema
|
||||
oauthClient *dingtalkoauth2_1_0.Client
|
||||
cardClient *dingtalkcard_1_0.Client
|
||||
getQA bot.GetQAFun
|
||||
logger *log.Logger
|
||||
tokenCache struct {
|
||||
accessToken string
|
||||
expireAt time.Time
|
||||
}
|
||||
tokenMutex sync.RWMutex
|
||||
messageMu sync.Mutex
|
||||
messageSeenAt map[string]messageMark
|
||||
messageTTL time.Duration
|
||||
nowFunc func() time.Time
|
||||
processMessageFn func(ctx context.Context, data *chatbot.BotCallbackDataModel) error
|
||||
}
|
||||
|
||||
type messageMark struct {
|
||||
seenAt time.Time
|
||||
inFlight bool
|
||||
}
|
||||
|
||||
func NewDingTalkClient(ctx context.Context, cancel context.CancelFunc, clientId, clientSecret, templateID string, logger *log.Logger, getQA bot.GetQAFun) (*DingTalkClient, error) {
|
||||
config := &openapi.Config{}
|
||||
config.Protocol = tea.String("https")
|
||||
config.RegionId = tea.String("central")
|
||||
oauthClient, err := dingtalkoauth2_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create oauth client: %w", err)
|
||||
}
|
||||
cardClient, err := dingtalkcard_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create card client: %w", err)
|
||||
}
|
||||
client := &DingTalkClient{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientID: clientId,
|
||||
clientSecret: clientSecret,
|
||||
templateID: templateID,
|
||||
oauthClient: oauthClient,
|
||||
cardClient: cardClient,
|
||||
getQA: getQA,
|
||||
logger: logger,
|
||||
messageSeenAt: make(map[string]messageMark),
|
||||
messageTTL: 5 * time.Minute,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
client.startMessageCleanup()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetAccessToken() (string, error) {
|
||||
c.tokenMutex.RLock()
|
||||
// TODO: use redis cache
|
||||
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
|
||||
token := c.tokenCache.accessToken
|
||||
c.tokenMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.tokenMutex.RUnlock()
|
||||
|
||||
c.tokenMutex.Lock()
|
||||
defer c.tokenMutex.Unlock()
|
||||
|
||||
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
|
||||
return c.tokenCache.accessToken, nil
|
||||
}
|
||||
|
||||
request := &dingtalkoauth2_1_0.GetAccessTokenRequest{
|
||||
AppKey: tea.String(c.clientID),
|
||||
AppSecret: tea.String(c.clientSecret),
|
||||
}
|
||||
response, tryErr := func() (_resp *dingtalkoauth2_1_0.GetAccessTokenResponse, _e error) {
|
||||
defer func() {
|
||||
if r := tea.Recover(recover()); r != nil {
|
||||
_e = r
|
||||
}
|
||||
}()
|
||||
_resp, _err := c.oauthClient.GetAccessToken(request)
|
||||
if _err != nil {
|
||||
return nil, _err
|
||||
}
|
||||
|
||||
return _resp, nil
|
||||
}()
|
||||
if tryErr != nil {
|
||||
return "", tryErr
|
||||
}
|
||||
accessToken := *response.Body.AccessToken
|
||||
c.logger.Info("get access token", log.String("access_token", accessToken), log.Int("expire_in", int(*response.Body.ExpireIn)))
|
||||
c.tokenCache.accessToken = accessToken
|
||||
c.tokenCache.expireAt = time.Now().Add(time.Duration(*response.Body.ExpireIn-300) * time.Second)
|
||||
|
||||
return c.tokenCache.accessToken, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) UpdateAIStreamCard(trackID, content string, isFinalize bool) error {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get access token while updating interactive card: %w", err)
|
||||
}
|
||||
|
||||
headers := &dingtalkcard_1_0.StreamingUpdateHeaders{
|
||||
XAcsDingtalkAccessToken: tea.String(accessToken),
|
||||
}
|
||||
request := &dingtalkcard_1_0.StreamingUpdateRequest{
|
||||
OutTrackId: tea.String(trackID),
|
||||
Guid: tea.String(uuid.New().String()),
|
||||
Key: tea.String("content"),
|
||||
Content: tea.String(content),
|
||||
IsFull: tea.Bool(true),
|
||||
IsFinalize: tea.Bool(isFinalize),
|
||||
IsError: tea.Bool(false),
|
||||
}
|
||||
_, err = c.cardClient.StreamingUpdateWithOptions(request, headers, &util.RuntimeOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update card: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) CreateAndDeliverCard(ctx context.Context, trackID string, data *chatbot.BotCallbackDataModel) error {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
|
||||
}
|
||||
|
||||
createAndDeliverHeaders := &dingtalkcard_1_0.CreateAndDeliverHeaders{}
|
||||
createAndDeliverHeaders.XAcsDingtalkAccessToken = tea.String(accessToken)
|
||||
|
||||
cardDataCardParamMap := map[string]*string{
|
||||
"content": tea.String(""),
|
||||
}
|
||||
cardData := &dingtalkcard_1_0.CreateAndDeliverRequestCardData{
|
||||
CardParamMap: cardDataCardParamMap,
|
||||
}
|
||||
|
||||
createAndDeliverRequest := &dingtalkcard_1_0.CreateAndDeliverRequest{
|
||||
CardTemplateId: tea.String(c.templateID),
|
||||
OutTrackId: tea.String(trackID),
|
||||
CardData: cardData,
|
||||
CallbackType: tea.String("STREAM"),
|
||||
ImGroupOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenSpaceModel{
|
||||
SupportForward: tea.Bool(true),
|
||||
},
|
||||
ImRobotOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenSpaceModel{
|
||||
SupportForward: tea.Bool(true),
|
||||
},
|
||||
UserIdType: tea.Int32(1),
|
||||
}
|
||||
switch data.ConversationType {
|
||||
case "2": // 群聊
|
||||
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_GROUP", data.ConversationId)
|
||||
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
|
||||
createAndDeliverRequest.SetImGroupOpenDeliverModel(
|
||||
&dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenDeliverModel{
|
||||
RobotCode: tea.String(c.clientID),
|
||||
})
|
||||
case "1": // Im机器人单聊
|
||||
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_ROBOT", data.SenderStaffId)
|
||||
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
|
||||
createAndDeliverRequest.SetImRobotOpenDeliverModel(&dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenDeliverModel{
|
||||
SpaceType: tea.String("IM_GROUP"),
|
||||
})
|
||||
default:
|
||||
return fmt.Errorf("invalid conversation type: %s", data.ConversationType)
|
||||
}
|
||||
|
||||
_, err = c.cardClient.CreateAndDeliverWithOptions(createAndDeliverRequest, createAndDeliverHeaders, &util.RuntimeOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create and deliver card: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) startMessageCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.cleanupExpiredMessages()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) cleanupExpiredMessages() {
|
||||
now := c.nowFunc()
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
for msgID, mark := range c.messageSeenAt {
|
||||
if mark.inFlight {
|
||||
continue
|
||||
}
|
||||
if now.Sub(mark.seenAt) > c.messageTTL {
|
||||
delete(c.messageSeenAt, msgID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) tryMarkMessage(msgID string) bool {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
now := c.nowFunc()
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
if mark, ok := c.messageSeenAt[msgID]; ok {
|
||||
if mark.inFlight || now.Sub(mark.seenAt) <= c.messageTTL {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
c.messageSeenAt[msgID] = messageMark{
|
||||
seenAt: now,
|
||||
inFlight: true,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) markMessageCompleted(msgID string) {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
c.messageSeenAt[msgID] = messageMark{
|
||||
seenAt: c.nowFunc(),
|
||||
inFlight: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) clearMessageMark(msgID string) {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
delete(c.messageSeenAt, msgID)
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
c.logger.Info("dingtalk bot is disabled, ignoring message", log.String("client_id", c.clientID))
|
||||
return nil, nil
|
||||
default:
|
||||
}
|
||||
|
||||
if !c.tryMarkMessage(data.MsgId) {
|
||||
c.logger.Info("ignore duplicate dingtalk message", log.String("msg_id", data.MsgId))
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
processor := c.processMessageFn
|
||||
if processor == nil {
|
||||
processor = c.processMessage
|
||||
}
|
||||
|
||||
payload := *data
|
||||
go c.processMessageAsync(c.ctx, &payload, processor)
|
||||
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) processMessageAsync(ctx context.Context, data *chatbot.BotCallbackDataModel, processor func(context.Context, *chatbot.BotCallbackDataModel) error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
c.clearMessageMark(data.MsgId)
|
||||
c.logger.Error("process dingtalk message panicked", log.String("msg_id", data.MsgId), log.Any("panic", r))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := processor(ctx, data); err != nil {
|
||||
c.clearMessageMark(data.MsgId)
|
||||
c.logger.Error("process dingtalk message failed", log.String("msg_id", data.MsgId), log.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.markMessageCompleted(data.MsgId)
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) processMessage(ctx context.Context, data *chatbot.BotCallbackDataModel) error {
|
||||
question := data.Text.Content
|
||||
question = strings.TrimSpace(question)
|
||||
trackID := uuid.New().String()
|
||||
// conversation_type == 1 表示机器人单聊,==2 表示群聊中@机器人
|
||||
c.logger.Info("dingtalk client received message", log.String("question", question), log.String("track_id", trackID), log.String("conversation_type", data.ConversationType))
|
||||
// create and deliver card
|
||||
if err := c.CreateAndDeliverCard(ctx, trackID, data); err != nil {
|
||||
c.logger.Error("CreateAndDeliverCard", log.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
initialContent := fmt.Sprintf("**%s**\n\n%s", question, "稍等,让我想一想……")
|
||||
|
||||
if err := c.UpdateAIStreamCard(trackID, initialContent, false); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard", log.Error(err))
|
||||
return err
|
||||
}
|
||||
// 初始化 默认为空
|
||||
convInfo := &domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
From: domain.MessageFromPrivate, // 默认是私聊
|
||||
},
|
||||
}
|
||||
// 之前创建并且发送卡片消息,获取用户基本信息
|
||||
userinfo, err := c.GetUserInfo(data.SenderStaffId)
|
||||
if err != nil {
|
||||
c.logger.Error("GetUserInfo failed", log.Error(err))
|
||||
} else {
|
||||
c.logger.Info("GetUserInfo success", log.Any("userinfo", userinfo))
|
||||
convInfo.UserInfo.UserID = userinfo.Result.Userid
|
||||
convInfo.UserInfo.NickName = userinfo.Result.Name
|
||||
convInfo.UserInfo.Avatar = userinfo.Result.Avatar
|
||||
convInfo.UserInfo.Email = userinfo.Result.Email
|
||||
}
|
||||
if data.ConversationType == "2" { // 群聊
|
||||
convInfo.UserInfo.From = domain.MessageFromGroup
|
||||
} else { // 单聊
|
||||
convInfo.UserInfo.From = domain.MessageFromPrivate
|
||||
}
|
||||
|
||||
contentCh, err := c.getQA(ctx, question, *convInfo, "")
|
||||
if err != nil {
|
||||
c.logger.Error("dingtalk client failed to get answer", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
|
||||
return fmt.Errorf("get answer failed: %w; update error card failed: %w", err, updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
updateTicker := time.NewTicker(1500 * time.Millisecond)
|
||||
defer updateTicker.Stop()
|
||||
|
||||
ans := fmt.Sprintf("**%s**\n\n", question)
|
||||
fullContent := fmt.Sprintf("**%s**\n\n", question)
|
||||
for {
|
||||
select {
|
||||
case content, ok := <-contentCh:
|
||||
if !ok {
|
||||
if err := c.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
|
||||
return fmt.Errorf("final update card failed: %w; fallback update failed: %w", err, updateErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
fullContent += content
|
||||
case <-updateTicker.C:
|
||||
if fullContent == ans {
|
||||
continue
|
||||
}
|
||||
if err := c.UpdateAIStreamCard(trackID, fullContent, false); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in ticker", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in ticker failed", log.Error(updateErr))
|
||||
return fmt.Errorf("stream update card failed: %w; fallback update failed: %w", err, updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) Start() error {
|
||||
cli := client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(
|
||||
c.clientID,
|
||||
c.clientSecret,
|
||||
)))
|
||||
cli.RegisterChatBotCallbackRouter(c.OnChatBotMessageReceived)
|
||||
if err := cli.Start(c.ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-c.ctx.Done()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) Stop() {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// 钉钉的用户信息
|
||||
type UserDetailResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Result UserDetails `json:"result"`
|
||||
}
|
||||
|
||||
type UserDetails struct {
|
||||
Unionid string `json:"unionid"`
|
||||
Userid string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Avatar string `json:"avatar"`
|
||||
Mobile string `json:"mobile"`
|
||||
Email string `json:"email"`
|
||||
Title string `json:"title"`
|
||||
Active bool `json:"active"`
|
||||
Admin bool `json:"admin"`
|
||||
Boss bool `json:"boss"`
|
||||
DeptIDList []int64 `json:"dept_id_list"`
|
||||
JobNumber string `json:"job_number"`
|
||||
HiredDate int64 `json:"hired_date"`
|
||||
ManagerUserid string `json:"manager_userid"`
|
||||
}
|
||||
|
||||
// 使用原始的http请求来获取用户的信息 - > 需要设置获取用户的权限功能:企业员工手机号信息和邮箱等个人信息、成员信息读权限
|
||||
func (c *DingTalkClient) GetUserInfo(userID string) (*UserDetailResponse, error) {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
|
||||
}
|
||||
// 1. 构建URL和请求体
|
||||
url := "https://oapi.dingtalk.com/topapi/v2/user/get"
|
||||
payload := map[string]string{"userid": userID, "language": "zh_CN"} // 默认是中文
|
||||
jsonPayload, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
query := req.URL.Query()
|
||||
query.Add("access_token", accessToken)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to get user info from dingtalk: %v", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 获取到用户信息
|
||||
c.logger.Info("Get user info from dingtalk success", log.Any("resp 原始的消息:", resp))
|
||||
|
||||
var result UserDetailResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
c.logger.Error("Failed to unmarshal user info response: %v", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
c.logger.Error("Failed to get result info", log.Any("ErrCode", result.ErrCode), log.String("ErrMsg", result.ErrMsg))
|
||||
return nil, fmt.Errorf("result.ErrCode:%d", result.ErrCode)
|
||||
}
|
||||
// success
|
||||
c.logger.Info("Get user info from dingtalk success", log.Any("userinfo:", result))
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
263
backend/pkg/bot/dingtalk/stream_test.go
Normal file
263
backend/pkg/bot/dingtalk/stream_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
pwlog "github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
func newTestLogger() *pwlog.Logger {
|
||||
return &pwlog.Logger{
|
||||
Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestDingTalkClient(t *testing.T) *DingTalkClient {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := NewDingTalkClient(
|
||||
ctx,
|
||||
cancel,
|
||||
"client-id",
|
||||
"client-secret",
|
||||
"template-id",
|
||||
newTestLogger(),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
client.messageTTL = time.Minute
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestTryMarkMessageDeduplicatesWithinTTL(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
now := time.Now()
|
||||
client.nowFunc = func() time.Time {
|
||||
return now
|
||||
}
|
||||
|
||||
require.True(t, client.tryMarkMessage("msg-1"))
|
||||
require.False(t, client.tryMarkMessage("msg-1"))
|
||||
|
||||
client.markMessageCompleted("msg-1")
|
||||
require.False(t, client.tryMarkMessage("msg-1"))
|
||||
|
||||
now = now.Add(client.messageTTL + time.Second)
|
||||
|
||||
require.True(t, client.tryMarkMessage("msg-1"))
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedIgnoresDuplicateMsgID(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
processed := make(chan struct{}, 2)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
processed <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-1",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
t.Fatal("expected duplicate message to be ignored")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedReturnsBeforeProcessingCompletes(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
started := make(chan struct{})
|
||||
unblock := make(chan struct{})
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
close(started)
|
||||
<-unblock
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, _ = client.OnChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-2",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "slow question",
|
||||
},
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected callback to return before background processing finishes")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected background processing to start")
|
||||
}
|
||||
|
||||
close(unblock)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedAllowsRetryAfterProcessingError(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
attempts := make(chan struct{}, 4)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
attempts <- struct{}{}
|
||||
return assert.AnError
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-retry",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "retry please",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, callErr)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedRecoversBackgroundPanic(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
attempts := make(chan struct{}, 4)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
attempts <- struct{}{}
|
||||
panic("boom")
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-panic",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "panic please",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected background processing to start")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, callErr)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedKeepsInFlightMessageMarkedPastTTL(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
now := time.Now()
|
||||
client.nowFunc = func() time.Time {
|
||||
return now
|
||||
}
|
||||
|
||||
processed := make(chan struct{}, 2)
|
||||
unblock := make(chan struct{})
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
processed <- struct{}{}
|
||||
<-unblock
|
||||
return nil
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-inflight",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "long running question",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
now = now.Add(client.messageTTL + time.Second)
|
||||
client.cleanupExpiredMessages()
|
||||
|
||||
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
t.Fatal("expected in-flight duplicate message to be ignored after ttl cleanup")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(unblock)
|
||||
}
|
||||
30
backend/pkg/bot/discord/discord_test.go
Normal file
30
backend/pkg/bot/discord/discord_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/chaitin/panda-wiki/config"
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
func TestDiscord(t *testing.T) {
|
||||
cfg, _ := config.NewConfig()
|
||||
log := log.NewLogger(cfg)
|
||||
token := "token"
|
||||
getQA := func(ctx context.Context, msg string, info domain.ConversationInfo, ConversationID string) (chan string, error) {
|
||||
contentCh := make(chan string, 10)
|
||||
go func() {
|
||||
defer close(contentCh)
|
||||
contentCh <- "hello " + msg
|
||||
}()
|
||||
return contentCh, nil
|
||||
}
|
||||
c, _ := NewDiscordClient(log, token, getQA)
|
||||
if err := c.Start(); err != nil {
|
||||
t.Errorf("Failed to start Discord client: %v", err)
|
||||
}
|
||||
|
||||
select {}
|
||||
}
|
||||
98
backend/pkg/bot/discord/stream.go
Normal file
98
backend/pkg/bot/discord/stream.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
type DiscordClient struct {
|
||||
logger *log.Logger
|
||||
BotToken string
|
||||
dg *discordgo.Session
|
||||
getQA bot.GetQAFun
|
||||
}
|
||||
|
||||
func NewDiscordClient(logger *log.Logger, BotToken string, getQA bot.GetQAFun) (*DiscordClient, error) {
|
||||
dg, err := discordgo.New("Bot " + BotToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Discord session: %v", err)
|
||||
}
|
||||
return &DiscordClient{
|
||||
logger: logger.WithModule("bot.discord"),
|
||||
BotToken: BotToken,
|
||||
dg: dg,
|
||||
getQA: getQA,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DiscordClient) Start() error {
|
||||
err := d.dg.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open Discord connection: %v", err)
|
||||
}
|
||||
d.dg.AddHandler(d.handleMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordClient) Stop() error {
|
||||
return d.dg.Close()
|
||||
}
|
||||
|
||||
func (d *DiscordClient) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
// 判断群聊单聊
|
||||
d.logger.Debug("接收到消息", log.String("消息内容", m.Content))
|
||||
d.logger.Debug("接收到消息", log.String("ChannelID", m.ChannelID))
|
||||
d.logger.Debug("接收到消息", log.String("GuildID", m.GuildID))
|
||||
// 只接收@ bot 的消息
|
||||
preFix := fmt.Sprintf("<@%s>", s.State.User.ID)
|
||||
if !strings.HasPrefix(m.Content, preFix) {
|
||||
return
|
||||
}
|
||||
content := strings.TrimPrefix(m.Content, preFix)
|
||||
info := domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
NickName: m.Author.Username,
|
||||
Email: m.Author.Email,
|
||||
UserID: m.Author.ID,
|
||||
},
|
||||
}
|
||||
if m.GuildID != "" {
|
||||
info.UserInfo.From = domain.MessageFromGroup
|
||||
} else {
|
||||
info.UserInfo.From = domain.MessageFromPrivate
|
||||
}
|
||||
|
||||
d.logger.Debug("消息来自", log.String("用户名", m.Author.Username), log.String("ID", m.Author.ID), log.String("内容", content))
|
||||
d.logger.Debug("消息来自频道", log.String("名称", m.ChannelID))
|
||||
qaChan, err := d.getQA(context.Background(), content, info, "")
|
||||
if err != nil {
|
||||
d.logger.Error("failed to get QA", log.String("error", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
message, err := s.ChannelMessageSend(m.ChannelID, "正在获取答案...")
|
||||
if err != nil {
|
||||
d.logger.Error("failed to send message to discord", log.String("error", err.Error()))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
buf := strings.Builder{}
|
||||
for qa := range qaChan {
|
||||
buf.WriteString(qa)
|
||||
}
|
||||
_, err := s.ChannelMessageEdit(message.ChannelID, message.ID, buf.String())
|
||||
if err != nil {
|
||||
d.logger.Error("failed to edit message to discord", log.String("error", err.Error()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
299
backend/pkg/bot/feishu/stream.go
Normal file
299
backend/pkg/bot/feishu/stream.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkcardkit "github.com/larksuite/oapi-sdk-go/v3/service/cardkit/v1"
|
||||
larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
type FeishuBotLogger struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (l *FeishuBotLogger) Info(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Info("feishu bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *FeishuBotLogger) Error(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Error("feishu bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *FeishuBotLogger) Debug(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Debug("feishu bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *FeishuBotLogger) Warn(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Warn("feishu bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
type FeishuClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clientID string
|
||||
clientSecret string
|
||||
logger *log.Logger
|
||||
client *lark.Client
|
||||
msgMap sync.Map
|
||||
getQA bot.GetQAFun
|
||||
}
|
||||
|
||||
func NewFeishuClient(ctx context.Context, cancel context.CancelFunc, clientID, clientSecret string, logger *log.Logger, getQA bot.GetQAFun) *FeishuClient {
|
||||
client := lark.NewClient(clientID, clientSecret, lark.WithLogger(&FeishuBotLogger{logger: logger}))
|
||||
|
||||
c := &FeishuClient{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
client: client,
|
||||
logger: logger,
|
||||
getQA: getQA,
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.msgMap.Range(func(key, value any) bool {
|
||||
// remove messageId if it is older than 5 minutes
|
||||
if time.Now().Unix()-value.(int64) > 5*60 {
|
||||
c.msgMap.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
return c
|
||||
}
|
||||
|
||||
var cardDataTemplate = `{"schema":"2.0","header":{"title":{"content":"%s","tag":"plain_text"}},"config":{"streaming_mode":true,"summary":{"content":""}},"body":{"elements":[{"tag":"markdown","content":"%s","element_id":"markdown_1"}]}}`
|
||||
|
||||
func (c *FeishuClient) sendQACard(ctx context.Context, receiveIdType string, receiveId string, question string, additionalInfo string) {
|
||||
// create card
|
||||
cardData := fmt.Sprintf(cardDataTemplate, question, "稍等,让我想一想...")
|
||||
req := larkcardkit.NewCreateCardReqBuilder().
|
||||
Body(larkcardkit.NewCreateCardReqBodyBuilder().
|
||||
Type(`card_json`).
|
||||
Data(cardData).
|
||||
Build()).
|
||||
Build()
|
||||
resp, err := c.client.Cardkit.V1.Card.Create(ctx, req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create card", log.Error(err))
|
||||
return
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.logger.Error("failed to create card", log.String("request_id", resp.RequestId()), log.Any("code_error", resp.CodeError))
|
||||
return
|
||||
}
|
||||
content, err := json.Marshal(map[string]any{
|
||||
"type": "card",
|
||||
"data": map[string]string{
|
||||
"card_id": *resp.Data.CardId,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("failed to marshal alarm card", log.Error(err))
|
||||
return
|
||||
}
|
||||
// send card to user or group
|
||||
res, err := c.client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(receiveIdType).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
MsgType("interactive").
|
||||
ReceiveId(receiveId).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build())
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create message", log.Error(err))
|
||||
return
|
||||
}
|
||||
if !res.Success() {
|
||||
c.logger.Error("failed to create message", log.Int("code", res.Code), log.String("msg", res.Msg), log.String("request_id", res.RequestId()))
|
||||
return
|
||||
}
|
||||
// 打印日志
|
||||
c.logger.Info("send QA card to user or group", log.String("receive_id_type", receiveIdType), log.String("receive_id", receiveId), log.String("question", question), log.String("additional_info(chat:user_openid/p2p:chat_id)", additionalInfo))
|
||||
|
||||
// start processing QA
|
||||
convInfo := domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
From: domain.MessageFromPrivate, // 默认是私聊
|
||||
},
|
||||
}
|
||||
if receiveIdType == "open_id" {
|
||||
// 获取用户的信息,只需要获取p2p的对话的类型的用户信息 - p2p对话
|
||||
userinfo, err := c.GetUserInfo(receiveId)
|
||||
if err != nil {
|
||||
c.logger.Error("get user info failed", log.Error(err))
|
||||
} else {
|
||||
if userinfo.UserId != nil {
|
||||
convInfo.UserInfo.UserID = *userinfo.UserId
|
||||
}
|
||||
if userinfo.Name != nil {
|
||||
convInfo.UserInfo.NickName = *userinfo.Name
|
||||
}
|
||||
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
|
||||
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
|
||||
}
|
||||
c.logger.Info("get user info success", log.Any("user_info", userinfo))
|
||||
}
|
||||
convInfo.UserInfo.From = domain.MessageFromPrivate // 私聊
|
||||
} else { // chat_id 中的userid
|
||||
// 获取群聊的消息,用户如果是在群聊中@机器人,那么就获取的是群聊的消息
|
||||
userinfo, err := c.GetUserInfo(additionalInfo)
|
||||
if err != nil {
|
||||
c.logger.Error("get chat info failed", log.Error(err))
|
||||
} else {
|
||||
if userinfo.UserId != nil {
|
||||
convInfo.UserInfo.UserID = *userinfo.UserId
|
||||
}
|
||||
if userinfo.Name != nil {
|
||||
convInfo.UserInfo.NickName = *userinfo.Name
|
||||
}
|
||||
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
|
||||
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
|
||||
}
|
||||
c.logger.Info("get chat user info success", log.Any("user_info", userinfo))
|
||||
}
|
||||
convInfo.UserInfo.From = domain.MessageFromGroup // 群聊
|
||||
}
|
||||
|
||||
answerCh, err := c.getQA(ctx, question, convInfo, "")
|
||||
if err != nil {
|
||||
c.logger.Error("get QA failed", log.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
answer := ""
|
||||
seq := 1
|
||||
for chunk := range answerCh {
|
||||
seq += 1
|
||||
answer += chunk
|
||||
// 部分模型存在输出为空的情况导致飞书报错
|
||||
if strings.TrimSpace(chunk) == "" {
|
||||
continue
|
||||
}
|
||||
// update card content streaming
|
||||
updateReq := larkcardkit.NewContentCardElementReqBuilder().
|
||||
CardId(*resp.Data.CardId).
|
||||
ElementId(`markdown_1`).
|
||||
Body(larkcardkit.NewContentCardElementReqBodyBuilder().
|
||||
Uuid(uuid.New().String()).
|
||||
Content(answer).
|
||||
Sequence(seq).
|
||||
Build()).
|
||||
Build()
|
||||
updateResp, err := c.client.Cardkit.V1.CardElement.Content(ctx, updateReq)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to update card", log.Error(err))
|
||||
return
|
||||
}
|
||||
if !updateResp.Success() {
|
||||
c.logger.Error("failed to update card", log.String("request_id", updateResp.RequestId()), log.Any("code_error", updateResp.CodeError))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.logger.Info("start processing QA", log.String("message_id", *res.Data.MessageId))
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
func (c *FeishuClient) Start() error {
|
||||
eventHandler := dispatcher.NewEventDispatcher("", "").
|
||||
OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
// ignore duplicate message
|
||||
if *event.Event.Message.MessageId == "" {
|
||||
return nil
|
||||
}
|
||||
messageId := *event.Event.Message.MessageId
|
||||
if _, ok := c.msgMap.Load(messageId); ok {
|
||||
return nil
|
||||
}
|
||||
c.msgMap.Store(messageId, time.Now().Unix())
|
||||
c.logger.Info("received message from feishu bot", log.String("message_id", messageId))
|
||||
// only handle text type
|
||||
if *event.Event.Message.MessageType != "text" {
|
||||
return nil
|
||||
}
|
||||
switch *event.Event.Message.ChatType {
|
||||
case "group":
|
||||
var message Message
|
||||
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
|
||||
c.logger.Error("failed to unmarshal message", log.Error(err))
|
||||
return nil
|
||||
}
|
||||
c.sendQACard(ctx, "chat_id", *event.Event.Message.ChatId, message.Text, *event.Event.Sender.SenderId.OpenId)
|
||||
case "p2p":
|
||||
var message Message
|
||||
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
|
||||
c.logger.Error("failed to unmarshal message", log.Error(err))
|
||||
return nil
|
||||
}
|
||||
c.sendQACard(ctx, "open_id", *event.Event.Sender.SenderId.OpenId, message.Text, *event.Event.Message.ChatId)
|
||||
default:
|
||||
c.logger.Warn("unsupported chat type", log.String("chat_type", *event.Event.Message.ChatType))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
cli := larkws.NewClient(c.clientID, c.clientSecret,
|
||||
larkws.WithEventHandler(eventHandler),
|
||||
larkws.WithLogger(&FeishuBotLogger{logger: c.logger}),
|
||||
)
|
||||
// FIXME: goroutine leak in larkws.Start
|
||||
err := cli.Start(c.ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start feishu client: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 下面功能都是需要开启飞书对应的权限才可以获取到用户信息 -- 应用权限(否则获取不到对话用户的信息)
|
||||
|
||||
// 飞书机器人获取用户信息,只是适用于单个用户
|
||||
func (c *FeishuClient) GetUserInfo(UserOpenId string) (*larkcontact.User, error) {
|
||||
// 获取用户信息,根据用户的id
|
||||
req := larkcontact.NewGetUserReqBuilder().UserId(UserOpenId).
|
||||
UserIdType(`open_id`).DepartmentIdType(`open_department_id`).Build()
|
||||
// 发起请求,获取用户消息
|
||||
resp, err := c.client.Contact.User.Get(context.Background(), req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to get user info", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失败
|
||||
if !resp.Success() {
|
||||
c.logger.Error("failed to get user info, response status not success", log.Any("errcode:", resp.Code))
|
||||
return nil, fmt.Errorf("failed to get user info, response data not success")
|
||||
}
|
||||
|
||||
return resp.Data.User, nil
|
||||
}
|
||||
|
||||
func (c *FeishuClient) Stop() {
|
||||
c.cancel()
|
||||
}
|
||||
346
backend/pkg/bot/lark/client.go
Normal file
346
backend/pkg/bot/lark/client.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package lark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkcardkit "github.com/larksuite/oapi-sdk-go/v3/service/cardkit/v1"
|
||||
larkcontact "github.com/larksuite/oapi-sdk-go/v3/service/contact/v3"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
// LarkBotLogger implements Lark SDK logger interface
|
||||
type LarkBotLogger struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func (l *LarkBotLogger) Info(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Info("lark bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *LarkBotLogger) Error(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Error("lark bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *LarkBotLogger) Debug(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Debug("lark bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
func (l *LarkBotLogger) Warn(ctx context.Context, args ...interface{}) {
|
||||
l.logger.Warn("lark bot", log.Any("args", args))
|
||||
}
|
||||
|
||||
// LarkClient is a Lark bot client using larksuite SDK (configured for Lark international endpoints)
|
||||
// Note: Lark uses HTTP callbacks instead of WebSocket for event handling
|
||||
type LarkClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clientID string
|
||||
clientSecret string
|
||||
logger *log.Logger
|
||||
client *lark.Client
|
||||
msgMap sync.Map
|
||||
getQA bot.GetQAFun
|
||||
eventHandler *dispatcher.EventDispatcher
|
||||
verifyToken string
|
||||
encryptKey string
|
||||
}
|
||||
|
||||
// NewLarkClient creates a new Lark bot client
|
||||
// Lark is the international version of Feishu, using different API endpoints
|
||||
// Unlike Feishu (China), Lark (International) uses HTTP callbacks instead of WebSocket
|
||||
func NewLarkClient(ctx context.Context, cancel context.CancelFunc, clientID, clientSecret, verifyToken, encryptKey string, logger *log.Logger, getQA bot.GetQAFun) (*LarkClient, error) {
|
||||
// Create client with Lark (international) domain
|
||||
client := lark.NewClient(clientID, clientSecret,
|
||||
lark.WithLogger(&LarkBotLogger{logger: logger}),
|
||||
lark.WithOpenBaseUrl("https://open.larksuite.com"), // Lark international endpoint
|
||||
)
|
||||
|
||||
c := &LarkClient{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
client: client,
|
||||
logger: logger,
|
||||
getQA: getQA,
|
||||
verifyToken: verifyToken,
|
||||
encryptKey: encryptKey,
|
||||
}
|
||||
|
||||
// Setup event handler for HTTP callbacks
|
||||
c.setupEventHandler()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.msgMap.Range(func(key, value any) bool {
|
||||
// remove messageId if it is older than 5 minutes
|
||||
if time.Now().Unix()-value.(int64) > 5*60 {
|
||||
c.msgMap.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// setupEventHandler configures the event dispatcher for handling HTTP callbacks
|
||||
func (c *LarkClient) setupEventHandler() {
|
||||
c.eventHandler = dispatcher.NewEventDispatcher(c.verifyToken, c.encryptKey).
|
||||
OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
if *event.Event.Message.MessageId == "" {
|
||||
return nil
|
||||
}
|
||||
messageId := *event.Event.Message.MessageId
|
||||
if _, ok := c.msgMap.Load(messageId); ok {
|
||||
return nil
|
||||
}
|
||||
c.msgMap.Store(messageId, time.Now().Unix())
|
||||
c.logger.Info("received message from lark bot", log.String("message_id", messageId))
|
||||
if *event.Event.Message.MessageType != "text" {
|
||||
return nil
|
||||
}
|
||||
switch *event.Event.Message.ChatType {
|
||||
case "group":
|
||||
var message Message
|
||||
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
|
||||
c.logger.Error("failed to unmarshal message", log.Error(err))
|
||||
return nil
|
||||
}
|
||||
// Replace mention placeholders with actual user names
|
||||
questionText := c.replaceMentions(message.Text, event.Event.Message.Mentions)
|
||||
go c.sendQACard(c.ctx, "chat_id", *event.Event.Message.ChatId, questionText, *event.Event.Sender.SenderId.OpenId)
|
||||
case "p2p":
|
||||
var message Message
|
||||
if err := json.Unmarshal([]byte(*event.Event.Message.Content), &message); err != nil {
|
||||
c.logger.Error("failed to unmarshal message", log.Error(err))
|
||||
return nil
|
||||
}
|
||||
go c.sendQACard(c.ctx, "open_id", *event.Event.Sender.SenderId.OpenId, message.Text, *event.Event.Message.ChatId)
|
||||
default:
|
||||
c.logger.Warn("unsupported chat type", log.String("chat_type", *event.Event.Message.ChatType))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetEventHandler returns the event dispatcher for HTTP callback handling
|
||||
// This should be registered with the HTTP server to handle Lark callbacks
|
||||
func (c *LarkClient) GetEventHandler() *dispatcher.EventDispatcher {
|
||||
return c.eventHandler
|
||||
}
|
||||
|
||||
var cardDataTemplate = `{"schema":"2.0","header":{"title":{"content":"%s","tag":"plain_text"}},"config":{"streaming_mode":true,"summary":{"content":""}},"body":{"elements":[{"tag":"markdown","content":"%s","element_id":"markdown_1"}]}}`
|
||||
|
||||
func (c *LarkClient) sendQACard(ctx context.Context, receiveIdType string, receiveId string, question string, additionalInfo string) {
|
||||
// create card
|
||||
cardData := fmt.Sprintf(cardDataTemplate, question, "稍等,让我想一想...")
|
||||
req := larkcardkit.NewCreateCardReqBuilder().
|
||||
Body(larkcardkit.NewCreateCardReqBodyBuilder().
|
||||
Type(`card_json`).
|
||||
Data(cardData).
|
||||
Build()).
|
||||
Build()
|
||||
resp, err := c.client.Cardkit.V1.Card.Create(ctx, req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create card", log.Error(err))
|
||||
return
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.logger.Error("failed to create card", log.String("request_id", resp.RequestId()), log.Any("code_error", resp.CodeError))
|
||||
return
|
||||
}
|
||||
content, err := json.Marshal(map[string]any{
|
||||
"type": "card",
|
||||
"data": map[string]string{
|
||||
"card_id": *resp.Data.CardId,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("failed to marshal alarm card", log.Error(err))
|
||||
return
|
||||
}
|
||||
// send card to user or group
|
||||
res, err := c.client.Im.Message.Create(ctx, larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(receiveIdType).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
MsgType("interactive").
|
||||
ReceiveId(receiveId).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build())
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create message", log.Error(err))
|
||||
return
|
||||
}
|
||||
if !res.Success() {
|
||||
c.logger.Error("failed to create message", log.Int("code", res.Code), log.String("msg", res.Msg), log.String("request_id", res.RequestId()))
|
||||
return
|
||||
}
|
||||
c.logger.Info("send QA card to user or group", log.String("receive_id_type", receiveIdType), log.String("receive_id", receiveId), log.String("question", question), log.String("additional_info", additionalInfo))
|
||||
|
||||
// start processing QA
|
||||
convInfo := domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
From: domain.MessageFromPrivate,
|
||||
},
|
||||
}
|
||||
if receiveIdType == "open_id" {
|
||||
userinfo, err := c.GetUserInfo(receiveId)
|
||||
if err != nil {
|
||||
c.logger.Error("get user info failed", log.Error(err))
|
||||
} else {
|
||||
if userinfo.UserId != nil {
|
||||
convInfo.UserInfo.UserID = *userinfo.UserId
|
||||
}
|
||||
if userinfo.Name != nil {
|
||||
convInfo.UserInfo.NickName = *userinfo.Name
|
||||
}
|
||||
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
|
||||
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
|
||||
}
|
||||
c.logger.Info("get user info success", log.Any("user_info", userinfo))
|
||||
}
|
||||
convInfo.UserInfo.From = domain.MessageFromPrivate
|
||||
} else {
|
||||
userinfo, err := c.GetUserInfo(additionalInfo)
|
||||
if err != nil {
|
||||
c.logger.Error("get chat info failed", log.Error(err))
|
||||
} else {
|
||||
if userinfo.UserId != nil {
|
||||
convInfo.UserInfo.UserID = *userinfo.UserId
|
||||
}
|
||||
if userinfo.Name != nil {
|
||||
convInfo.UserInfo.NickName = *userinfo.Name
|
||||
}
|
||||
if userinfo.Avatar != nil && userinfo.Avatar.AvatarOrigin != nil {
|
||||
convInfo.UserInfo.Avatar = *userinfo.Avatar.AvatarOrigin
|
||||
}
|
||||
c.logger.Info("get chat user info success", log.Any("user_info", userinfo))
|
||||
}
|
||||
convInfo.UserInfo.From = domain.MessageFromGroup
|
||||
}
|
||||
|
||||
answerCh, err := c.getQA(ctx, question, convInfo, "")
|
||||
if err != nil {
|
||||
c.logger.Error("lark client failed to get answer", log.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
seq := 0
|
||||
imageRegex := regexp.MustCompile(`!\[[^\]]*\]\([^)]+\)`)
|
||||
sendUpdate := func() error {
|
||||
seq++
|
||||
answer := imageRegex.ReplaceAllString(buf.String(), "")
|
||||
updateReq := larkcardkit.NewContentCardElementReqBuilder().
|
||||
CardId(*resp.Data.CardId).
|
||||
ElementId(`markdown_1`).
|
||||
Body(larkcardkit.NewContentCardElementReqBodyBuilder().
|
||||
Uuid(uuid.New().String()).
|
||||
Content(answer).
|
||||
Sequence(seq).
|
||||
Build()).
|
||||
Build()
|
||||
updateResp, err := c.client.Cardkit.V1.CardElement.Content(ctx, updateReq)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to update card", log.Error(err))
|
||||
return err
|
||||
}
|
||||
if !updateResp.Success() {
|
||||
c.logger.Error("failed to update card", log.String("request_id", updateResp.RequestId()), log.Any("code_error", updateResp.CodeError))
|
||||
return fmt.Errorf("update card failed: %v", updateResp.CodeError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for chunk := range answerCh {
|
||||
buf.WriteString(chunk)
|
||||
// drain all currently available chunks
|
||||
for len(answerCh) > 0 {
|
||||
buf.WriteString(<-answerCh)
|
||||
}
|
||||
if err := sendUpdate(); err != nil {
|
||||
c.logger.Error("lark client failed to send QA update", log.Error(err), log.Int("sequence", seq))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.logger.Info("start processing QA", log.String("message_id", *res.Data.MessageId))
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// replaceMentions replaces mention placeholders like @_user_1 with actual user names
|
||||
func (c *LarkClient) replaceMentions(text string, mentions []*larkim.MentionEvent) string {
|
||||
if len(mentions) == 0 {
|
||||
return text
|
||||
}
|
||||
|
||||
result := text
|
||||
for _, mention := range mentions {
|
||||
if mention.Key != nil && mention.Name != nil {
|
||||
// Replace @_user_1, @_user_2, etc. with @ActualUserName
|
||||
result = strings.ReplaceAll(result, *mention.Key, "@"+*mention.Name)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Start initializes the Lark bot client
|
||||
// Note: Unlike Feishu, Lark doesn't use WebSocket. Events are handled via HTTP callbacks.
|
||||
// The actual HTTP endpoint needs to be registered separately in the HTTP router.
|
||||
func (c *LarkClient) Start() error {
|
||||
c.logger.Info("lark bot client initialized (HTTP callback mode)",
|
||||
log.String("app_id", c.clientID),
|
||||
log.String("note", "Register HTTP callback endpoint to receive events"))
|
||||
|
||||
// For Lark, we don't start a WebSocket connection
|
||||
// Events will be received via HTTP callbacks handled by GetEventHandler()
|
||||
// Just keep the context alive
|
||||
<-c.ctx.Done()
|
||||
c.logger.Info("lark bot client stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LarkClient) GetUserInfo(UserOpenId string) (*larkcontact.User, error) {
|
||||
req := larkcontact.NewGetUserReqBuilder().UserId(UserOpenId).
|
||||
UserIdType(`open_id`).DepartmentIdType(`open_department_id`).Build()
|
||||
resp, err := c.client.Contact.User.Get(context.Background(), req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to get user info", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
c.logger.Error("failed to get user info, response status not success", log.Any("errcode:", resp.Code))
|
||||
return nil, fmt.Errorf("failed to get user info, response data not success")
|
||||
}
|
||||
|
||||
return resp.Data.User, nil
|
||||
}
|
||||
|
||||
func (c *LarkClient) Stop() {
|
||||
c.cancel()
|
||||
}
|
||||
9
backend/pkg/bot/utils/utils.go
Normal file
9
backend/pkg/bot/utils/utils.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/russross/blackfriday/v2"
|
||||
)
|
||||
|
||||
func Markdown2HTML(md string) string {
|
||||
return string(blackfriday.Run([]byte(md), blackfriday.WithRenderer(blackfriday.NewHTMLRenderer(blackfriday.HTMLRendererParameters{Flags: blackfriday.UseXHTML | blackfriday.CompletePage}))))
|
||||
}
|
||||
106
backend/pkg/bot/wechat/domain.go
Normal file
106
backend/pkg/bot/wechat/domain.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package wechat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/repo/pg"
|
||||
)
|
||||
|
||||
type WechatConfig struct {
|
||||
Ctx context.Context
|
||||
logger *log.Logger
|
||||
CorpID string
|
||||
Token string
|
||||
EncodingAESKey string
|
||||
kbID string
|
||||
Secret string
|
||||
AccessToken string
|
||||
TokenExpire time.Time
|
||||
AgentID string
|
||||
// db
|
||||
WeRepo *pg.WechatRepository
|
||||
}
|
||||
|
||||
type ReceivedMessage struct {
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
MsgID string `xml:"MsgId"`
|
||||
}
|
||||
|
||||
type ResponseMessage struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName CDATA `xml:"ToUserName"`
|
||||
FromUserName CDATA `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType CDATA `xml:"MsgType"`
|
||||
Content CDATA `xml:"Content"`
|
||||
}
|
||||
|
||||
type CDATA struct {
|
||||
Value string `xml:",cdata"`
|
||||
}
|
||||
|
||||
type BackendRequest struct {
|
||||
Question string `json:"question"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type BackendResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TextResponse string `json:"test_response"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// UserInfo 用于存储获取到的用户信息
|
||||
type UserInfo struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
UserID string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Department []int `json:"department"`
|
||||
Mobile string `json:"mobile"`
|
||||
Email string `json:"email"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// 获取token的回应的消息
|
||||
type AccessToken struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type TokenCache struct {
|
||||
AccessToken string
|
||||
TokenExpire time.Time
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Map-based token cache keyed by kb & agentID
|
||||
var tokenCacheMap = make(map[string]*TokenCache)
|
||||
var tokenCacheMapMutex = sync.Mutex{}
|
||||
|
||||
// Generate a key for the token cache based on kb & agentID
|
||||
func getTokenCacheKey(kbID, agentID string) string {
|
||||
return kbID + ":" + agentID
|
||||
}
|
||||
|
||||
// media
|
||||
// Upload file response
|
||||
type MediaUploadResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MediaType string `json:"type"`
|
||||
MediaID string `json:"media_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
393
backend/pkg/bot/wechat/wechat.go
Normal file
393
backend/pkg/bot/wechat/wechat.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package wechat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sbzhu/weworkapi_golang/wxbizmsgcrypt"
|
||||
|
||||
"github.com/chaitin/panda-wiki/consts"
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
const wechatMessageMaxBytes = 2000
|
||||
|
||||
func NewWechatAppConfig(ctx context.Context, logger *log.Logger, kbId, CorpID, Token, EncodingAESKey, secret, agentID string) (*WechatConfig, error) {
|
||||
return &WechatConfig{
|
||||
Ctx: ctx,
|
||||
logger: logger,
|
||||
kbID: kbId,
|
||||
CorpID: CorpID,
|
||||
Token: Token,
|
||||
EncodingAESKey: EncodingAESKey,
|
||||
Secret: secret,
|
||||
AgentID: agentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) VerifyUrlWechatAPP(signature, timestamp, nonce, echostr string) ([]byte, error) {
|
||||
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(
|
||||
cfg.Token,
|
||||
cfg.EncodingAESKey,
|
||||
cfg.CorpID,
|
||||
wxbizmsgcrypt.XmlType,
|
||||
)
|
||||
|
||||
// 验证URL并解密echostr
|
||||
decryptEchoStr, errCode := wxcpt.VerifyURL(signature, timestamp, nonce, echostr)
|
||||
if errCode != nil {
|
||||
return nil, errors.New("server serve fail wechat")
|
||||
}
|
||||
// success
|
||||
return decryptEchoStr, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) Wechat(msg ReceivedMessage, getQA bot.GetQAFun, userinfo *UserInfo, useTextResponse bool, weChatAppAdvancedSetting *domain.WeChatAppAdvancedSetting) error {
|
||||
|
||||
token, err := cfg.GetAccessToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useTextResponse {
|
||||
err = cfg.ProcessTextMessage(msg, getQA, token, userinfo, weChatAppAdvancedSetting.DisclaimerContent)
|
||||
if err != nil {
|
||||
cfg.logger.Error("send to ai failed!", log.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := cfg.ProcessUrlMessage(msg, getQA, token, userinfo); err != nil {
|
||||
cfg.logger.Error("send to ai failed!", log.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) ProcessUrlMessage(msg ReceivedMessage, GetQA bot.GetQAFun, token string, userinfo *UserInfo) error {
|
||||
// 1. get ai channel
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
|
||||
id = uuid.New()
|
||||
}
|
||||
conversationID := id.String()
|
||||
|
||||
contentChan, err := GetQA(cfg.Ctx, msg.Content, domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
UserID: userinfo.UserID,
|
||||
NickName: userinfo.Name,
|
||||
From: domain.MessageFromPrivate,
|
||||
}}, conversationID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//2. go send to ai and store in map--> get conversation-id
|
||||
if _, ok := domain.ConversationManager.Load(conversationID); !ok {
|
||||
state := &domain.ConversationState{
|
||||
Question: msg.Content,
|
||||
NotificationChan: make(chan string), // notification channel
|
||||
IsVisited: false,
|
||||
}
|
||||
domain.ConversationManager.Store(conversationID, state)
|
||||
|
||||
go cfg.SendQuestionToAI(conversationID, contentChan)
|
||||
}
|
||||
|
||||
baseUrl, err := cfg.WeRepo.GetWechatBaseURL(cfg.Ctx, cfg.kbID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//3.send url to user
|
||||
Errcode, Errmsg, err := cfg.SendURLToUser(msg.FromUserName, msg.Content, token, conversationID, baseUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if Errcode != 0 {
|
||||
return fmt.Errorf("wechat Api failed : %s (code: %d)", Errmsg, Errcode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) ProcessTextMessage(msg ReceivedMessage, GetQA bot.GetQAFun, token string, userinfo *UserInfo, disclaimerContent string) error {
|
||||
// 1. get ai channel
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
|
||||
id = uuid.New()
|
||||
}
|
||||
conversationID := id.String()
|
||||
|
||||
contentChan, err := GetQA(cfg.Ctx, msg.Content, domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
UserID: userinfo.UserID,
|
||||
NickName: userinfo.Name,
|
||||
From: domain.MessageFromPrivate,
|
||||
}}, conversationID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var fullResponse string
|
||||
for content := range contentChan {
|
||||
fullResponse += content
|
||||
if len([]byte(fullResponse)) > wechatMessageMaxBytes { // wechat limit 2048 byte
|
||||
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
|
||||
return err
|
||||
}
|
||||
fullResponse = ""
|
||||
}
|
||||
}
|
||||
if len([]byte(fullResponse+disclaimerContent)) > wechatMessageMaxBytes {
|
||||
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err := cfg.SendResponseToUser(disclaimerContent, msg.FromUserName, token); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if disclaimerContent != "" {
|
||||
fullResponse += fmt.Sprintf("\n%s", disclaimerContent)
|
||||
}
|
||||
if _, _, err := cfg.SendResponseToUser(fullResponse, msg.FromUserName, token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendResponseToUser
|
||||
func (cfg *WechatConfig) SendURLToUser(touser, question, token, conversationID, baseUrl string) (int, string, error) {
|
||||
msgData := map[string]interface{}{
|
||||
"touser": touser,
|
||||
"msgtype": "textcard",
|
||||
"agentid": cfg.AgentID,
|
||||
"textcard": map[string]interface{}{
|
||||
"title": question,
|
||||
"description": "<div class = \"highlight\">本回答由 PandaWiki 基于 AI 生成,仅供参考。</div>",
|
||||
"url": fmt.Sprintf("%s/h5-chat?id=%s&source_type=%s", baseUrl, conversationID, consts.SourceTypeWechatBot),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msgData)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", token)
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
||||
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var result struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return result.Errcode, result.Errmsg, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) SendResponseToUser(response string, touser string, token string) (int, string, error) {
|
||||
|
||||
msgData := map[string]interface{}{
|
||||
"touser": touser,
|
||||
"msgtype": "markdown",
|
||||
"agentid": cfg.AgentID,
|
||||
"markdown": map[string]string{
|
||||
"content": response,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msgData)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", token)
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
||||
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var result struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if result.Errcode != 0 {
|
||||
return result.Errcode, result.Errmsg, fmt.Errorf("wechat Api failed : %s (code: %d)", result.Errmsg, result.Errcode)
|
||||
}
|
||||
return result.Errcode, result.Errmsg, nil
|
||||
}
|
||||
|
||||
// SendResponse
|
||||
func (cfg *WechatConfig) SendResponse(msg ReceivedMessage, content string) ([]byte, error) {
|
||||
|
||||
responseMsg := ResponseMessage{
|
||||
ToUserName: CDATA{msg.FromUserName},
|
||||
FromUserName: CDATA{msg.ToUserName},
|
||||
CreateTime: msg.CreateTime,
|
||||
MsgType: CDATA{"text"},
|
||||
Content: CDATA{content},
|
||||
}
|
||||
|
||||
// XML
|
||||
responseXML, err := xml.Marshal(responseMsg)
|
||||
if err != nil {
|
||||
cfg.logger.Error("marshal response failed", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(cfg.Token, cfg.EncodingAESKey, cfg.CorpID, wxbizmsgcrypt.XmlType)
|
||||
|
||||
// response
|
||||
var encryptMsg []byte
|
||||
encryptMsg, errCode := wxcpt.EncryptMsg(string(responseXML), "", "")
|
||||
if errCode != nil {
|
||||
return nil, errors.New("encryotMsg err")
|
||||
}
|
||||
|
||||
return encryptMsg, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) GetAccessToken() (string, error) {
|
||||
// Generate cache key based on app credentials
|
||||
cacheKey := getTokenCacheKey(cfg.kbID, cfg.AgentID)
|
||||
|
||||
// Get or create token cache for this app
|
||||
tokenCacheMapMutex.Lock()
|
||||
tokenCache, exists := tokenCacheMap[cacheKey]
|
||||
if !exists {
|
||||
tokenCache = &TokenCache{}
|
||||
tokenCacheMap[cacheKey] = tokenCache
|
||||
}
|
||||
tokenCacheMapMutex.Unlock()
|
||||
|
||||
// Lock the specific token cache for this app
|
||||
tokenCache.Mutex.Lock()
|
||||
defer tokenCache.Mutex.Unlock()
|
||||
|
||||
if tokenCache.AccessToken != "" && time.Now().Before(tokenCache.TokenExpire) {
|
||||
cfg.logger.Debug("access token has existed and is valid")
|
||||
return tokenCache.AccessToken, nil
|
||||
}
|
||||
|
||||
if cfg.Secret == "" || cfg.CorpID == "" {
|
||||
return "", errors.New("secret or corpid is not right")
|
||||
}
|
||||
|
||||
// get AccessToken--请求微信客服token
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", cfg.CorpID, cfg.Secret)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", errors.New("get wechatapp accesstoken failed")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokenResp AccessToken // 获取到token消息
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", errors.New("json decode wechat resp failed")
|
||||
}
|
||||
|
||||
if tokenResp.Errcode != 0 {
|
||||
return "", errors.New("get wechat access token failed")
|
||||
}
|
||||
|
||||
// success
|
||||
cfg.logger.Info("wechatapp get accesstoken success", log.Any("info", tokenResp.AccessToken))
|
||||
|
||||
tokenCache.AccessToken = tokenResp.AccessToken
|
||||
tokenCache.TokenExpire = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second)
|
||||
|
||||
return tokenCache.AccessToken, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) GetUserInfo(username string) (*UserInfo, error) {
|
||||
|
||||
accessToken, err := cfg.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 请求获取用户的内容
|
||||
resp, err := http.Get(fmt.Sprintf(
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/user/get?access_token=%s&userid=%s",
|
||||
accessToken, username))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// cfg.logger.Info("获取用户信息成功", log.Any("body", body))
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if userInfo.Errcode != 0 {
|
||||
return nil, fmt.Errorf("获取用户信息失败: %d, %s", userInfo.Errcode, userInfo.Errmsg)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatConfig) UnmarshalMsg(decryptMsg []byte) (*ReceivedMessage, error) {
|
||||
var msg ReceivedMessage
|
||||
err := xml.Unmarshal([]byte(decryptMsg), &msg)
|
||||
return &msg, err
|
||||
}
|
||||
|
||||
// answer set into conversation state buffer
|
||||
func (cfg *WechatConfig) SendQuestionToAI(conversationID string, wccontent chan string) {
|
||||
// send message
|
||||
val, _ := domain.ConversationManager.Load(conversationID)
|
||||
state := val.(*domain.ConversationState)
|
||||
for content := range wccontent {
|
||||
state.Mutex.Lock()
|
||||
if state.IsVisited {
|
||||
state.NotificationChan <- content // notify has new data
|
||||
}
|
||||
state.Buffer.WriteString(content)
|
||||
state.Mutex.Unlock()
|
||||
}
|
||||
// end sent notification
|
||||
defer func() {
|
||||
close(state.NotificationChan)
|
||||
domain.ConversationManager.Delete(conversationID)
|
||||
}()
|
||||
}
|
||||
33
backend/pkg/bot/wechat_official_account/official_account.go
Normal file
33
backend/pkg/bot/wechat_official_account/official_account.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package wechat_official_account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/silenceper/wechat/v2/officialaccount/user"
|
||||
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot/wechat_service"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
)
|
||||
|
||||
func Wechat(ctx context.Context, GetQA bot.GetQAFun, userinfo *user.Info, content string) (string, error) {
|
||||
|
||||
wccontent, err := GetQA(ctx, content, domain.ConversationInfo{UserInfo: domain.UserInfo{
|
||||
UserID: userinfo.OpenID, // 用户对话的id
|
||||
NickName: userinfo.Nickname, //用户微信的昵称
|
||||
Avatar: userinfo.Headimgurl, // 用户微信的头像
|
||||
From: domain.MessageFromPrivate,
|
||||
}}, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var response string
|
||||
for v := range wccontent {
|
||||
response += v
|
||||
}
|
||||
response = wechat_service.MarkdowntoText(response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
188
backend/pkg/bot/wechat_service/domain.go
Normal file
188
backend/pkg/bot/wechat_service/domain.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package wechat_service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/repo/pg"
|
||||
)
|
||||
|
||||
type WechatServiceConfig struct {
|
||||
Ctx context.Context
|
||||
CorpID string
|
||||
Token string
|
||||
EncodingAESKey string
|
||||
kbID string
|
||||
Secret string
|
||||
logger *log.Logger
|
||||
containKeywords []string
|
||||
equalKeywords []string
|
||||
logoUrl string
|
||||
// db
|
||||
WeRepo *pg.WechatRepository
|
||||
}
|
||||
|
||||
// 存储ai知识库获取的cursor值以客服为标准,方便拉取用户的消息
|
||||
var KfCursors = &sync.Map{}
|
||||
|
||||
// 微信客服发送的消息
|
||||
type WeixinUserAskMsg struct {
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Event string `xml:"Event"`
|
||||
Token string `xml:"Token"`
|
||||
OpenKfId string `xml:"OpenKfId"`
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type MsgRequest struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Token string `json:"token"`
|
||||
Limit int `json:"limit"`
|
||||
VoiceFormat int `json:"voice_format"`
|
||||
OpenKfid string `json:"open_kfid"`
|
||||
}
|
||||
|
||||
type MsgRet struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
NextCursor string `json:"next_cursor"` // 游标
|
||||
MsgList []Msg `json:"msg_list"`
|
||||
HasMore int `json:"has_more"`
|
||||
}
|
||||
|
||||
type Msg struct {
|
||||
Msgid string `json:"msgid"`
|
||||
SendTime int64 `json:"send_time"`
|
||||
Origin int `json:"origin"`
|
||||
Msgtype string `json:"msgtype"`
|
||||
Event struct {
|
||||
EventType string `json:"event_type"`
|
||||
Scene string `json:"scene"`
|
||||
OpenKfid string `json:"open_kfid"`
|
||||
ExternalUserid string `json:"external_userid"`
|
||||
WelcomeCode string `json:"welcome_code"`
|
||||
} `json:"event"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
OpenKfid string `json:"open_kfid"`
|
||||
ExternalUserid string `json:"external_userid"`
|
||||
}
|
||||
|
||||
// send msg to user with message
|
||||
type ReplyMsg struct {
|
||||
Touser string `json:"touser,omitempty"`
|
||||
OpenKfid string `json:"open_kfid,omitempty"`
|
||||
Msgid string `json:"msgid,omitempty"`
|
||||
Msgtype string `json:"msgtype,omitempty"`
|
||||
Text struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
} `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// send msg to user with url
|
||||
type ReplyMsgUrl struct {
|
||||
Touser string `json:"touser,omitempty"`
|
||||
OpenKfid string `json:"open_kfid,omitempty"`
|
||||
Msgid string `json:"msgid,omitempty"`
|
||||
Msgtype string `json:"msgtype,omitempty"`
|
||||
Link Link `json:"link,omitempty"`
|
||||
}
|
||||
|
||||
type Link struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Desc string `json:"desc,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
ThumbMediaID string `json:"thumb_media_id,omitempty"`
|
||||
}
|
||||
|
||||
// Upload file response
|
||||
type MediaUploadResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MediaType string `json:"type"`
|
||||
MediaID string `json:"media_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// 获取用户消息应该得到的响应
|
||||
type WechatCustomerResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
CustomerList []Customer `json:"customer_list"`
|
||||
InvalidExternalUserIDs []string `json:"invalid_external_userid"`
|
||||
}
|
||||
|
||||
type Customer struct {
|
||||
ExternalUserID string `json:"external_userid"`
|
||||
Nickname string `json:"nickname"`
|
||||
Avatar string `json:"avatar"`
|
||||
Gender int `json:"gender"`
|
||||
UnionID string `json:"unionid"`
|
||||
}
|
||||
|
||||
type UerInfoRequest struct {
|
||||
UserID []string `json:"external_userid_list"`
|
||||
SessionContext int `json:"need_enter_session_context"`
|
||||
}
|
||||
|
||||
// chat status
|
||||
type Status struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
ServiceState int `json:"service_state"`
|
||||
ServiceUserId string `json:"servicer_userid"`
|
||||
}
|
||||
|
||||
type HumanList struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
ServicerList []ServicerList `json:"servicer_list"`
|
||||
}
|
||||
|
||||
type ServicerList struct {
|
||||
UserID string `json:"userid"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type TokenCache struct {
|
||||
AccessToken string
|
||||
TokenExpire time.Time
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
// Map-based token cache keyed by app credentials
|
||||
var tokenCacheMap = make(map[string]*TokenCache)
|
||||
var tokenCacheMapMutex = sync.Mutex{}
|
||||
|
||||
// Generate a key for the token cache based on app credentials
|
||||
func getTokenCacheKey(kbID, secret string) string {
|
||||
return kbID + ":" + secret
|
||||
}
|
||||
|
||||
type UserImageCache struct {
|
||||
ImageID string
|
||||
ImagePath string
|
||||
ImageExpire time.Time
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
var UImageCache = &UserImageCache{}
|
||||
|
||||
type DefaultImageCache struct {
|
||||
ImageID string
|
||||
ImageExpire time.Time
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
var DImageCache = &DefaultImageCache{}
|
||||
329
backend/pkg/bot/wechat_service/tools.go
Normal file
329
backend/pkg/bot/wechat_service/tools.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package wechat_service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 读取 cursor,以客服账号的消息作为key,返回对应的cursor值
|
||||
func getCursor(openKfId string) string {
|
||||
cursorValue, _ := KfCursors.Load(openKfId)
|
||||
cursor, _ := cursorValue.(string)
|
||||
return cursor
|
||||
}
|
||||
|
||||
// 存储 cursor
|
||||
func setCursor(openKfId, cursor string) {
|
||||
KfCursors.Store(openKfId, cursor)
|
||||
}
|
||||
|
||||
func CheckSessionState(token, extrenaluserid, kfId string) (int, error) {
|
||||
var statusrequest struct {
|
||||
OpenKfId string `json:"open_kfid"`
|
||||
ExternalUserid string `json:"external_userid"`
|
||||
}
|
||||
statusrequest.OpenKfId = kfId
|
||||
statusrequest.ExternalUserid = extrenaluserid
|
||||
// 将请求体转换为JSON
|
||||
jsonBody, err := json.Marshal(statusrequest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// 获取状态信息
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/service_state/get?access_token=%s", token)
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("发送请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
var response Status
|
||||
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return 0, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
// 得到用户的状态
|
||||
if response.ErrCode != 0 {
|
||||
return 0, fmt.Errorf("获取会话状态失败: %s", response.ErrMsg)
|
||||
}
|
||||
return response.ServiceState, nil
|
||||
}
|
||||
|
||||
func ChangeState(token, extrenaluserId, kfId string, state int, serviceId string) error {
|
||||
var changestate struct {
|
||||
OpenKfId string `json:"open_kfid"`
|
||||
ExternalUserid string `json:"external_userid"`
|
||||
ServiceState int `json:"service_state"`
|
||||
ServicerUserId string `json:"servicer_userid"`
|
||||
}
|
||||
changestate.OpenKfId = kfId
|
||||
changestate.ExternalUserid = extrenaluserId
|
||||
changestate.ServiceState = state
|
||||
changestate.ServicerUserId = serviceId
|
||||
jsonBody, err := json.Marshal(changestate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 发送请求
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/service_state/trans?access_token=%s", token)
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
// 解析响应
|
||||
var response struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MsgCode string `json:"msg_code"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
// 得到用户的状态
|
||||
if response.ErrCode != 0 {
|
||||
return fmt.Errorf("改变用户状态失败: %s", response.ErrMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserInfo(userid string, accessToken string) (*Customer, error) {
|
||||
userInfoRequest := UerInfoRequest{
|
||||
UserID: []string{userid},
|
||||
SessionContext: 0,
|
||||
}
|
||||
// 请求获取用户信息的url
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/customer/batchget?access_token=%s", accessToken)
|
||||
|
||||
jsonBody, err := json.Marshal(userInfoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// post获取用户的消息信息
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userInfo WechatCustomerResponse
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if userInfo.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("获取用户信息失败: %d, %s", userInfo.ErrCode, userInfo.ErrMsg)
|
||||
}
|
||||
|
||||
return &userInfo.CustomerList[0], nil
|
||||
}
|
||||
|
||||
// get image id
|
||||
func GetUserImageID(accessToken, filePath string) (string, error) {
|
||||
UImageCache.Mutex.Lock()
|
||||
defer UImageCache.Mutex.Unlock()
|
||||
|
||||
if UImageCache.ImageID != "" && (UImageCache.ImagePath == filePath) && time.Now().Before(UImageCache.ImageExpire.Add(-5*time.Minute)) {
|
||||
return UImageCache.ImageID, nil
|
||||
}
|
||||
|
||||
// URL
|
||||
mediaID, err := UploadMediaFromURL(accessToken, filePath)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
UImageCache.ImagePath = filePath
|
||||
UImageCache.ImageID = mediaID
|
||||
UImageCache.ImageExpire = time.Now().Add(72 * time.Hour) // 3 days
|
||||
return UImageCache.ImageID, nil
|
||||
}
|
||||
|
||||
// get image id
|
||||
func GetDefaultImageID(accessToken, ImageBase64 string) (string, error) {
|
||||
DImageCache.Mutex.Lock()
|
||||
defer DImageCache.Mutex.Unlock()
|
||||
|
||||
if DImageCache.ImageID != "" && time.Now().Before(DImageCache.ImageExpire.Add(-5*time.Minute)) {
|
||||
return DImageCache.ImageID, nil
|
||||
}
|
||||
|
||||
// Base64编码
|
||||
mediaID, err := UploadMediaFromBase64(accessToken, ImageBase64)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
DImageCache.ImageID = mediaID
|
||||
DImageCache.ImageExpire = time.Now().Add(72 * time.Hour) // 3 days
|
||||
return DImageCache.ImageID, nil
|
||||
}
|
||||
|
||||
// upload media to wechat server from URL
|
||||
func UploadMediaFromURL(accessToken, fileURL string) (string, error) {
|
||||
// 处理URL
|
||||
resp, err := http.Get(fileURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("下载图片失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
reader := resp.Body
|
||||
fileName := "image.png" // 默认文件名
|
||||
|
||||
// 从URL中提取文件名
|
||||
if u, err := url.Parse(fileURL); err == nil && u.Path != "" {
|
||||
if path.Base(u.Path) != "/" && path.Base(u.Path) != "." {
|
||||
fileName = path.Base(u.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return uploadMediaToWechat(accessToken, reader, fileName)
|
||||
}
|
||||
|
||||
// upload media to wechat server from Base64
|
||||
func UploadMediaFromBase64(accessToken, base64Data string) (string, error) {
|
||||
// 处理Base64编码的图片
|
||||
parts := strings.SplitN(base64Data, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("无效的Base64图片数据")
|
||||
}
|
||||
|
||||
// 解码Base64数据
|
||||
decodedData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码Base64图片数据失败: %w", err)
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(decodedData)
|
||||
fileName := "image.png" // const
|
||||
|
||||
return uploadMediaToWechat(accessToken, reader, fileName)
|
||||
}
|
||||
|
||||
// upload media to wechat server - common function
|
||||
func uploadMediaToWechat(accessToken string, reader io.Reader, fileName string) (string, error) {
|
||||
// 上传文件 req
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
part, err := writer.CreateFormFile("media", fileName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将图片数据复制到表单中
|
||||
_, err = io.Copy(part, reader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("复制图片数据失败: %w", err)
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/media/upload?access_token=%s&type=image", accessToken)
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{}
|
||||
httpResp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
var result MediaUploadResponse
|
||||
if err := json.NewDecoder(httpResp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("上传失败: [%d] %s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
return result.MediaID, nil
|
||||
}
|
||||
|
||||
func getMsgs(accessToken string, msg *WeixinUserAskMsg) (*MsgRet, error) {
|
||||
var msgRet MsgRet
|
||||
// 拉取消息的路由
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/sync_msg?access_token=%s", accessToken)
|
||||
cursor := getCursor(msg.OpenKfId)
|
||||
|
||||
msgBody := MsgRequest{
|
||||
OpenKfid: msg.OpenKfId,
|
||||
Token: msg.Token,
|
||||
Limit: 1000,
|
||||
VoiceFormat: 0,
|
||||
Cursor: cursor,
|
||||
}
|
||||
|
||||
jsonBody, _ := json.Marshal(msgBody)
|
||||
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody)) // 得到对应的回复
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 反序列化之后
|
||||
if err := json.Unmarshal([]byte(string(body)), &msgRet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msgRet, nil
|
||||
}
|
||||
|
||||
// markdowntotext
|
||||
func MarkdowntoText(md string) string {
|
||||
md = regexp.MustCompile(`(?m)^#+\s*(.*)$`).ReplaceAllString(md, "$1")
|
||||
md = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(md, "$1")
|
||||
md = regexp.MustCompile(`(?m)^>\s*(.*)$`).ReplaceAllString(md, "【引用】$1")
|
||||
md = regexp.MustCompile(`(?m)^-{3,}$`).ReplaceAllString(md, "─────────")
|
||||
md = regexp.MustCompile(`\n{3,}`).ReplaceAllString(md, "\n\n")
|
||||
md = regexp.MustCompile(`\[\[(\d+)\]\([^)]+\)\]`).ReplaceAllString(md, "[$1]")
|
||||
md = regexp.MustCompile(`\[(\d+)\]\.\s*\[([^\]]+)\]\([^)]+\)`).ReplaceAllString(md, "[$1]. $2")
|
||||
md = regexp.MustCompile(`(?m)^【引用】\[(\d+)\].\s*([^\n(]+)\s*\([^)]+\)`).ReplaceAllString(md, "【引用】[$1]. $2")
|
||||
return strings.TrimSpace(md)
|
||||
}
|
||||
403
backend/pkg/bot/wechat_service/wechat.go
Normal file
403
backend/pkg/bot/wechat_service/wechat.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package wechat_service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/samber/lo"
|
||||
"github.com/sbzhu/weworkapi_golang/wxbizmsgcrypt"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
func NewWechatServiceConfig(ctx context.Context, logger *log.Logger, KbId, CorpID, Token, EncodingAESKey, secret, logo string, containKeywords, equalKeywords []string) (*WechatServiceConfig, error) {
|
||||
return &WechatServiceConfig{
|
||||
Ctx: ctx,
|
||||
kbID: KbId,
|
||||
CorpID: CorpID,
|
||||
Token: Token,
|
||||
EncodingAESKey: EncodingAESKey,
|
||||
Secret: secret,
|
||||
logger: logger,
|
||||
containKeywords: containKeywords,
|
||||
equalKeywords: equalKeywords,
|
||||
logoUrl: logo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) VerifyUrlWechatService(signature, timestamp, nonce, echostr string) ([]byte, error) {
|
||||
wxcpt := wxbizmsgcrypt.NewWXBizMsgCrypt(
|
||||
cfg.Token,
|
||||
cfg.EncodingAESKey,
|
||||
cfg.CorpID,
|
||||
wxbizmsgcrypt.XmlType,
|
||||
)
|
||||
|
||||
// 验证URL并解密echostr
|
||||
decryptEchoStr, errCode := wxcpt.VerifyURL(signature, timestamp, nonce, echostr)
|
||||
if errCode != nil {
|
||||
return nil, errors.New("server serve fail wechat")
|
||||
}
|
||||
// success
|
||||
return decryptEchoStr, nil
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) Wechat(msg *WeixinUserAskMsg, getQA bot.GetQAFun) error {
|
||||
// 获取accesstoken 方便给用户发送消息
|
||||
token, err := cfg.GetAccessToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 主动拉去用户发送的消息
|
||||
msgRet, err := getMsgs(token, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msgRet.NextCursor != "" {
|
||||
setCursor(msg.OpenKfId, msgRet.NextCursor)
|
||||
}
|
||||
|
||||
err = cfg.Processmessage(msgRet, msg, getQA)
|
||||
if err != nil {
|
||||
cfg.logger.Error("send to ai failed!")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// forwardToBackend
|
||||
func (cfg *WechatServiceConfig) Processmessage(msgRet *MsgRet, Kfmsg *WeixinUserAskMsg, GetQA bot.GetQAFun) error {
|
||||
// err message
|
||||
cfg.logger.Info("get user message", log.Int("msgRet.Errcode", msgRet.Errcode), log.String("msg.Errmsg", msgRet.Errmsg))
|
||||
|
||||
size := len(msgRet.MsgList)
|
||||
if size < 1 {
|
||||
return fmt.Errorf("no message received")
|
||||
}
|
||||
// 如果是用户刚刚进入会话的事件,那么不需要发送消息给用户
|
||||
if msgRet.MsgList[size-1].Msgtype == "event" && msgRet.MsgList[size-1].Event.EventType == "enter_session" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 每次只是拿去最新的数据
|
||||
current := msgRet.MsgList[size-1]
|
||||
userId := current.ExternalUserid
|
||||
openkfId := current.OpenKfid
|
||||
content := current.Text.Content
|
||||
|
||||
token, _ := cfg.GetAccessToken()
|
||||
|
||||
state, err := CheckSessionState(token, userId, openkfId)
|
||||
if err != nil {
|
||||
cfg.logger.Error("check session state failed", log.Error(err))
|
||||
return err
|
||||
}
|
||||
if state == 3 { // 人工状态 ---已经是人工,那么就不要需要发消息给用户
|
||||
cfg.logger.Info("the customer has already in human service")
|
||||
return nil
|
||||
}
|
||||
if len(cfg.equalKeywords) > 0 || len(cfg.containKeywords) > 0 {
|
||||
if slices.Contains(cfg.equalKeywords, content) || lo.SomeBy(cfg.containKeywords, func(sub string) bool {
|
||||
return strings.Contains(content, sub)
|
||||
}) {
|
||||
// 改变状态为人工接待
|
||||
// 非人工 ->转人工
|
||||
humanList, err := cfg.GetKfHumanList(token, openkfId)
|
||||
if err != nil {
|
||||
cfg.logger.Error("get human list failed", log.Error(err))
|
||||
return err
|
||||
}
|
||||
// 遍历找到可以接待的员工
|
||||
for _, servicer := range humanList.ServicerList {
|
||||
if servicer.Status == 0 { // 可以接待
|
||||
err := ChangeState(token, userId, openkfId, 3, servicer.UserID)
|
||||
if err != nil {
|
||||
cfg.logger.Error("change state to human failed", log.Error(err))
|
||||
return err
|
||||
}
|
||||
cfg.logger.Info("change state to human successful") // 转人工成功
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// 失败
|
||||
cfg.logger.Info("no human available")
|
||||
return cfg.SendResponseToKfTxt(userId, openkfId, "当前没有可用的人工客服", token)
|
||||
}
|
||||
}
|
||||
|
||||
// 1. first response to user
|
||||
if err := cfg.SendResponseToKfTxt(userId, openkfId, "正在思考您的问题,请稍等...", token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取用户的详细信息
|
||||
customer, err := GetUserInfo(userId, token)
|
||||
if err != nil {
|
||||
cfg.logger.Error("get user info failed", log.Error(err))
|
||||
}
|
||||
cfg.logger.Info("customer info", log.Any("customer", customer))
|
||||
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
cfg.logger.Error("failed to generate conversation uuid", log.Error(err))
|
||||
id = uuid.New()
|
||||
}
|
||||
conversationID := id.String()
|
||||
wccontent, err := GetQA(cfg.Ctx, content, domain.ConversationInfo{UserInfo: domain.UserInfo{
|
||||
UserID: customer.ExternalUserID, // 用户对话的id
|
||||
NickName: customer.Nickname, //用户微信的昵称
|
||||
Avatar: customer.Avatar, // 用户微信的头像
|
||||
From: domain.MessageFromPrivate,
|
||||
}}, conversationID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//2. get baseurl and image path
|
||||
info, err := cfg.WeRepo.GetWechatStatic(cfg.Ctx, cfg.kbID, domain.AppTypeWeb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//2. go send to ai and store in map--> get conversation-id
|
||||
if _, ok := domain.ConversationManager.Load(conversationID); !ok {
|
||||
state := &domain.ConversationState{
|
||||
Question: content,
|
||||
NotificationChan: make(chan string), // notification channel
|
||||
IsVisited: false,
|
||||
}
|
||||
domain.ConversationManager.Store(conversationID, state)
|
||||
|
||||
go cfg.SendQuestionToAI(conversationID, wccontent)
|
||||
}
|
||||
// 3. second send url to user
|
||||
return cfg.SendResponseToKfUrl(userId, openkfId, conversationID, token, content, info.BaseUrl, info.ImagePath)
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) getImageID(token, image string) (string, error) {
|
||||
const minioPrefix = "http://panda-wiki-minio:9000"
|
||||
|
||||
// 优先使用配置的logoUrl
|
||||
if cfg.logoUrl != "" {
|
||||
image = cfg.logoUrl
|
||||
}
|
||||
|
||||
var imageId string
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case image == "":
|
||||
case strings.HasPrefix(image, "data:image/"):
|
||||
imageId, err = GetDefaultImageID(token, image)
|
||||
default:
|
||||
imageId, err = GetUserImageID(token, fmt.Sprintf("%s%s", minioPrefix, image))
|
||||
}
|
||||
|
||||
if imageId != "" && err == nil {
|
||||
return imageId, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cfg.logger.Error("failed to get image ID, using default", log.Error(err))
|
||||
}
|
||||
|
||||
return GetDefaultImageID(token, domain.DefaultPandaWikiIconB64)
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) SendResponseToKfUrl(userId, openkfId, conversationID, token, question, baseUrl, image string) error {
|
||||
imageId, err := cfg.getImageID(token, image)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(question) > 35 {
|
||||
question = string([]rune(question)[:35]) + "......"
|
||||
}
|
||||
|
||||
reply := ReplyMsgUrl{
|
||||
Touser: userId,
|
||||
OpenKfid: openkfId,
|
||||
Msgtype: "link",
|
||||
Link: Link{
|
||||
Url: fmt.Sprintf("%s/h5-chat?id=%s", baseUrl, conversationID),
|
||||
Desc: "本回答由 PandaWiki 基于 AI 生成,仅供参考。",
|
||||
Title: question,
|
||||
ThumbMediaID: imageId,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json Marshal failed: %w", err)
|
||||
}
|
||||
return cfg.SendMessage(jsonData, token)
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) SendResponseToKfTxt(userId string, openkfId string, response string, token string) error {
|
||||
// send text data to user
|
||||
reply := ReplyMsg{
|
||||
Touser: userId,
|
||||
OpenKfid: openkfId,
|
||||
Msgtype: "text",
|
||||
Text: struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
}{Content: response},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json Marshal failed: %w", err)
|
||||
}
|
||||
return cfg.SendMessage(jsonData, token)
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) SendMessage(jsonData []byte, token string) error {
|
||||
// 发送消息给客服
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token=%s", token)
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("post to wechatservice failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read response body failed: %w", err)
|
||||
}
|
||||
|
||||
var res struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
MsgID string `json:"msgid"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &res); err != nil {
|
||||
cfg.logger.Error("解析响应失败", log.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if res.ErrCode != 0 {
|
||||
cfg.logger.Error("发送给微信客服消息失败", log.Any("errcode", res.ErrCode), log.Any("errmsg", res.ErrMsg), log.Any("jsonData", string(jsonData)))
|
||||
return err
|
||||
}
|
||||
// 发送消息给微信客服成功
|
||||
s := string(body)
|
||||
cfg.logger.Info("response from wechatservice success", log.Any("body", s))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) GetAccessToken() (string, error) {
|
||||
// Generate cache key based on app credentials
|
||||
cacheKey := getTokenCacheKey(cfg.kbID, cfg.Secret)
|
||||
|
||||
// Get or create token cache for this app
|
||||
tokenCacheMapMutex.Lock()
|
||||
tokenCache, exists := tokenCacheMap[cacheKey]
|
||||
if !exists {
|
||||
tokenCache = &TokenCache{}
|
||||
tokenCacheMap[cacheKey] = tokenCache
|
||||
}
|
||||
tokenCacheMapMutex.Unlock()
|
||||
|
||||
// Lock the specific token cache for this app
|
||||
tokenCache.Mutex.Lock()
|
||||
defer tokenCache.Mutex.Unlock()
|
||||
|
||||
if tokenCache.AccessToken != "" && time.Now().Before(tokenCache.TokenExpire) {
|
||||
cfg.logger.Debug("access token has existed and is valid")
|
||||
return tokenCache.AccessToken, nil
|
||||
}
|
||||
|
||||
if cfg.Secret == "" || cfg.CorpID == "" {
|
||||
return "", errors.New("secret or corpid is not right")
|
||||
}
|
||||
|
||||
// get AccessToken--请求微信客服token
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", cfg.CorpID, cfg.Secret)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", errors.New("get wechatservice accesstoken failed")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokenResp AccessToken // 获取到token消息
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", errors.New("json decode wechat resp failed")
|
||||
}
|
||||
|
||||
if tokenResp.Errcode != 0 {
|
||||
return "", errors.New("get wechat access token failed")
|
||||
}
|
||||
|
||||
// success
|
||||
cfg.logger.Info("wechatservice get accesstoken success", log.Any("info", tokenResp.AccessToken))
|
||||
|
||||
tokenCache.AccessToken = tokenResp.AccessToken
|
||||
tokenCache.TokenExpire = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second)
|
||||
|
||||
return tokenCache.AccessToken, nil
|
||||
}
|
||||
|
||||
// 解析微信客服消息
|
||||
func (cfg *WechatServiceConfig) UnmarshalMsg(decryptMsg []byte) (*WeixinUserAskMsg, error) {
|
||||
var msg WeixinUserAskMsg
|
||||
err := xml.Unmarshal([]byte(decryptMsg), &msg)
|
||||
return &msg, err
|
||||
}
|
||||
|
||||
func (cfg *WechatServiceConfig) GetKfHumanList(token string, KfId string) (*HumanList, error) {
|
||||
url := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/kf/servicer/list?access_token=%s&open_kfid=%s", token, KfId)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var servicerResp HumanList
|
||||
if err := json.Unmarshal(body, &servicerResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if servicerResp.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("获取客服列表失败: %d, %s", servicerResp.ErrCode, servicerResp.ErrMsg)
|
||||
}
|
||||
|
||||
return &servicerResp, nil
|
||||
}
|
||||
|
||||
// answer set into redis queue and set useful time
|
||||
func (cfg *WechatServiceConfig) SendQuestionToAI(conversationID string, wccontent chan string) {
|
||||
// send message
|
||||
val, _ := domain.ConversationManager.Load(conversationID)
|
||||
state := val.(*domain.ConversationState)
|
||||
for content := range wccontent {
|
||||
state.Mutex.Lock()
|
||||
if state.IsVisited {
|
||||
state.NotificationChan <- content // notify has new data
|
||||
}
|
||||
state.Buffer.WriteString(content)
|
||||
state.Mutex.Unlock()
|
||||
}
|
||||
// end sent notification
|
||||
defer func() {
|
||||
close(state.NotificationChan)
|
||||
domain.ConversationManager.Delete(conversationID)
|
||||
}()
|
||||
}
|
||||
144
backend/pkg/bot/wecom/ai_bot.go
Normal file
144
backend/pkg/bot/wecom/ai_bot.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
// AIBotClient 微信智能机器人
|
||||
// https://developer.work.weixin.qq.com/document/path/100719
|
||||
type AIBotClient struct {
|
||||
ctx context.Context
|
||||
logger *log.Logger
|
||||
Token string
|
||||
EncodingAESKey string
|
||||
}
|
||||
|
||||
type UserReq struct {
|
||||
Msgid string `json:"msgid"`
|
||||
Aibotid string `json:"aibotid"`
|
||||
Chattype string `json:"chattype"`
|
||||
From struct {
|
||||
Userid string `json:"userid"`
|
||||
} `json:"from"`
|
||||
Msgtype string `json:"msgtype"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
Stream struct {
|
||||
Id string `json:"id"`
|
||||
} `json:"stream"`
|
||||
}
|
||||
type UserResp struct {
|
||||
Msgtype string `json:"msgtype"`
|
||||
Stream Stream `json:"stream"`
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
Id string `json:"id"`
|
||||
Finish bool `json:"finish"`
|
||||
Content string `json:"content"`
|
||||
MsgItem []struct {
|
||||
Msgtype string `json:"msgtype"`
|
||||
Image struct {
|
||||
Base64 string `json:"base64"`
|
||||
Md5 string `json:"md5"`
|
||||
} `json:"image"`
|
||||
} `json:"msg_item"`
|
||||
}
|
||||
|
||||
func NewAIBotClient(
|
||||
ctx context.Context,
|
||||
logger *log.Logger,
|
||||
Token string,
|
||||
EncodingAESKey string,
|
||||
) (*AIBotClient, error) {
|
||||
return &AIBotClient{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
Token: Token,
|
||||
EncodingAESKey: EncodingAESKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *AIBotClient) VerifyUrlWecomService(signature, timestamp, nonce, echostr string) (string, error) {
|
||||
wx, _, err := NewWXBizJsonMsgCrypt(
|
||||
c.Token,
|
||||
c.EncodingAESKey,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code, sReplyEchoStr := wx.VerifyURL(signature, timestamp, nonce, echostr)
|
||||
if code != 0 {
|
||||
c.logger.Error("VerifyUrlWecomService failed:", log.Any("code", code))
|
||||
return "", c.getErrorMessage(code)
|
||||
}
|
||||
|
||||
return sReplyEchoStr, nil
|
||||
}
|
||||
|
||||
func (c *AIBotClient) DecryptUserReq(signature, timestamp, nonce, msg string) (*UserReq, error) {
|
||||
|
||||
wx, _, err := NewWXBizJsonMsgCrypt(
|
||||
c.Token,
|
||||
c.EncodingAESKey,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code, reqMsg := wx.DecryptMsg(msg, signature, timestamp, nonce)
|
||||
if code != 0 {
|
||||
return nil, c.getErrorMessage(code)
|
||||
}
|
||||
var data UserReq
|
||||
|
||||
c.logger.Info("decrypt user req:", log.Any("reqMsg", reqMsg))
|
||||
err = json.Unmarshal([]byte(reqMsg), &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *AIBotClient) MakeStreamResp(nonce, id, content string, isFinish bool) (string, error) {
|
||||
c.logger.Debug("MakeStreamResp:", log.String("content", content), log.Any("isFinish", isFinish))
|
||||
wx, _, err := NewWXBizJsonMsgCrypt(
|
||||
c.Token,
|
||||
c.EncodingAESKey,
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp := UserResp{
|
||||
Msgtype: "stream",
|
||||
Stream: Stream{
|
||||
Id: id,
|
||||
Finish: isFinish,
|
||||
Content: content,
|
||||
MsgItem: nil,
|
||||
},
|
||||
}
|
||||
|
||||
b, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code, msg := wx.EncryptMsg(string(b), nonce)
|
||||
if code != 0 {
|
||||
c.logger.Error("MakeStreamResp failed:", log.Any("code", code))
|
||||
return "", c.getErrorMessage(code)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
374
backend/pkg/bot/wecom/crypt.go
Normal file
374
backend/pkg/bot/wecom/crypt.go
Normal file
@@ -0,0 +1,374 @@
|
||||
// Package wecom provides cryptographic utilities for WeChat Work (WeCom) message encryption and decryption.
|
||||
// It implements the WXBizMsgCrypt algorithm for secure message handling with WeChat Work APIs.
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
WXBizMsgCrypt_OK = 0
|
||||
WXBizMsgCrypt_ValidateSignature_Error = 40001
|
||||
WXBizMsgCrypt_ParseJson_Error = 40002
|
||||
WXBizMsgCrypt_ComputeSignature_Error = 40003
|
||||
WXBizMsgCrypt_IllegalAesKey = 40004
|
||||
WXBizMsgCrypt_EncryptAES_Error = 40005
|
||||
WXBizMsgCrypt_DecryptAES_Error = 40006
|
||||
WXBizMsgCrypt_IllegalBuffer = 40007
|
||||
WXBizMsgCrypt_ValidateCorpid_Error = 40008
|
||||
WXBizMsgCrypt_ValidateCorpid_Receive_Id = 40009
|
||||
WXBizMsgCrypt_ValidateCorpid_Mismatch = 40010
|
||||
)
|
||||
|
||||
var wecomErrorMessages = map[int]string{
|
||||
WXBizMsgCrypt_OK: "success",
|
||||
WXBizMsgCrypt_ValidateSignature_Error: "signature validation failed",
|
||||
WXBizMsgCrypt_ParseJson_Error: "invalid JSON format",
|
||||
WXBizMsgCrypt_ComputeSignature_Error: "signature computation failed",
|
||||
WXBizMsgCrypt_IllegalAesKey: "illegal AES key",
|
||||
WXBizMsgCrypt_EncryptAES_Error: "AES encryption failed",
|
||||
WXBizMsgCrypt_DecryptAES_Error: "AES decryption failed",
|
||||
WXBizMsgCrypt_IllegalBuffer: "illegal buffer format",
|
||||
WXBizMsgCrypt_ValidateCorpid_Error: "corp ID validation failed",
|
||||
WXBizMsgCrypt_ValidateCorpid_Receive_Id: "receive ID validation failed",
|
||||
WXBizMsgCrypt_ValidateCorpid_Mismatch: "corp ID mismatch",
|
||||
}
|
||||
|
||||
func (c *AIBotClient) getErrorMessage(code int) error {
|
||||
if msg, ok := wecomErrorMessages[code]; ok {
|
||||
return fmt.Errorf("wecom error (code %d): %s", code, msg)
|
||||
}
|
||||
return fmt.Errorf("unknown wecom error: %d", code)
|
||||
}
|
||||
|
||||
var ErrFormat = errors.New("format error")
|
||||
|
||||
// SHA1 负责生成安全签名(sha1)
|
||||
type SHA1 struct{}
|
||||
|
||||
// GetSHA1 : 对 token, timestamp, nonce, encrypt 排序后 sha1
|
||||
// 返回 (code, signature)
|
||||
func (s *SHA1) GetSHA1(token, timestamp, nonce string, encrypt interface{}) (int, string) {
|
||||
defer func() {
|
||||
// no panic propagation in this helper; but keep signature simple
|
||||
}()
|
||||
encStr := ""
|
||||
switch v := encrypt.(type) {
|
||||
case string:
|
||||
encStr = v
|
||||
case []byte:
|
||||
encStr = string(v)
|
||||
case nil:
|
||||
encStr = ""
|
||||
default:
|
||||
encStr = fmt.Sprint(v)
|
||||
}
|
||||
list := []string{token, timestamp, nonce, encStr}
|
||||
sort.Strings(list)
|
||||
joined := strings.Join(list, "")
|
||||
h := sha1.New()
|
||||
_, err := h.Write([]byte(joined))
|
||||
if err != nil {
|
||||
return WXBizMsgCrypt_ComputeSignature_Error, ""
|
||||
}
|
||||
return WXBizMsgCrypt_OK, fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// JsonParse 提取/生成 json 消息
|
||||
type JsonParse struct{}
|
||||
|
||||
type aesTextResponse struct {
|
||||
Encrypt string `json:"encrypt"`
|
||||
MsgSignature string `json:"msgsignature"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// Extract 从 json 字符串中提取 encrypt 字段
|
||||
// 返回 (code, encrypt)
|
||||
func (jp *JsonParse) Extract(jsonText string) (int, string) {
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonText), &m); err != nil {
|
||||
return WXBizMsgCrypt_ParseJson_Error, ""
|
||||
}
|
||||
if v, ok := m["encrypt"].(string); ok {
|
||||
return WXBizMsgCrypt_OK, v
|
||||
}
|
||||
return WXBizMsgCrypt_ParseJson_Error, ""
|
||||
}
|
||||
|
||||
// Generate 根据参数生成 json 字符串
|
||||
func (jp *JsonParse) Generate(encrypt, signature, timestamp, nonce string) string {
|
||||
resp := aesTextResponse{
|
||||
Encrypt: encrypt,
|
||||
MsgSignature: signature,
|
||||
Timestamp: timestamp,
|
||||
Nonce: nonce,
|
||||
}
|
||||
bs, _ := json.Marshal(resp)
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
// PKCS7Encoder 提供基于 PKCS7 的填充/去填充
|
||||
type PKCS7Encoder struct {
|
||||
BlockSize int // 使用 32 与 Python 示例一致
|
||||
}
|
||||
|
||||
func NewPKCS7Encoder() *PKCS7Encoder {
|
||||
return &PKCS7Encoder{BlockSize: 32}
|
||||
}
|
||||
|
||||
func (p *PKCS7Encoder) Encode(src []byte) []byte {
|
||||
if src == nil {
|
||||
src = []byte{}
|
||||
}
|
||||
n := len(src)
|
||||
amountToPad := p.BlockSize - (n % p.BlockSize)
|
||||
if amountToPad == 0 {
|
||||
amountToPad = p.BlockSize
|
||||
}
|
||||
pad := byte(amountToPad)
|
||||
padtext := bytes.Repeat([]byte{pad}, amountToPad)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
func (p *PKCS7Encoder) Decode(decrypted []byte) ([]byte, error) {
|
||||
if len(decrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
pad := int(decrypted[len(decrypted)-1])
|
||||
if pad < 1 || pad > p.BlockSize {
|
||||
// 同 Python 逻辑:当 pad 值不合理时,视为 0(或 error)
|
||||
return decrypted, fmt.Errorf("invalid padding")
|
||||
}
|
||||
return decrypted[:len(decrypted)-pad], nil
|
||||
}
|
||||
|
||||
// Prpcrypt 提供 AES 加解密功能
|
||||
type Prpcrypt struct {
|
||||
Key []byte
|
||||
Mode string // not used but kept for parity
|
||||
}
|
||||
|
||||
func NewPrpcrypt(key []byte) *Prpcrypt {
|
||||
return &Prpcrypt{Key: key, Mode: "CBC"}
|
||||
}
|
||||
|
||||
// Encrypt 对明文加密,返回 (code, base64Ciphertext)
|
||||
func (pc *Prpcrypt) Encrypt(plainText string, receiveID string) (int, string) {
|
||||
// 将明文转换为 bytes
|
||||
txt := []byte(plainText)
|
||||
|
||||
// 随机 16 字节数字字符串
|
||||
rand16, err := getRandom16BytesAsDigits()
|
||||
if err != nil {
|
||||
return WXBizMsgCrypt_EncryptAES_Error, ""
|
||||
}
|
||||
|
||||
// 包装: 16 bytes random + 4 bytes network-order(len) + txt + receiveid
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write(rand16)
|
||||
|
||||
// len(txt) 网络字节序
|
||||
lenBuf := make([]byte, 4)
|
||||
// Python 示例使用 socket.htonl(len(text)),即 network order (big endian)
|
||||
binary.BigEndian.PutUint32(lenBuf, uint32(len(txt)))
|
||||
buf.Write(lenBuf)
|
||||
buf.Write(txt)
|
||||
buf.Write([]byte(receiveID))
|
||||
raw := buf.Bytes()
|
||||
|
||||
// PKCS7 pad 到 blocksize=32
|
||||
encoder := NewPKCS7Encoder()
|
||||
padded := encoder.Encode(raw)
|
||||
|
||||
// AES-CBC
|
||||
block, err := aes.NewCipher(pc.Key)
|
||||
if err != nil {
|
||||
return WXBizMsgCrypt_EncryptAES_Error, ""
|
||||
}
|
||||
iv := pc.Key[:16]
|
||||
if len(iv) < 16 {
|
||||
return WXBizMsgCrypt_IllegalAesKey, ""
|
||||
}
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
if len(padded)%block.BlockSize() != 0 {
|
||||
// 应该已经经过 pad
|
||||
return WXBizMsgCrypt_EncryptAES_Error, ""
|
||||
}
|
||||
ciphertext := make([]byte, len(padded))
|
||||
mode.CryptBlocks(ciphertext, padded)
|
||||
|
||||
enc := base64.StdEncoding.EncodeToString(ciphertext)
|
||||
return WXBizMsgCrypt_OK, enc
|
||||
}
|
||||
|
||||
// Decrypt 解密 base64 文本,返回 (code, jsonContent)
|
||||
func (pc *Prpcrypt) Decrypt(base64Cipher string, receiveID string) (int, string) {
|
||||
cipherData, err := base64.StdEncoding.DecodeString(base64Cipher)
|
||||
if err != nil {
|
||||
return WXBizMsgCrypt_DecryptAES_Error, ""
|
||||
}
|
||||
block, err := aes.NewCipher(pc.Key)
|
||||
if err != nil {
|
||||
return WXBizMsgCrypt_DecryptAES_Error, ""
|
||||
}
|
||||
if len(cipherData)%block.BlockSize() != 0 {
|
||||
return WXBizMsgCrypt_DecryptAES_Error, ""
|
||||
}
|
||||
iv := pc.Key[:16]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
plain := make([]byte, len(cipherData))
|
||||
mode.CryptBlocks(plain, cipherData)
|
||||
|
||||
// 去 PKCS7 填充 (blocksize=32)
|
||||
encoder := NewPKCS7Encoder()
|
||||
unpadded, err := encoder.Decode(plain)
|
||||
if err != nil {
|
||||
// Python 里如果 pad 错误会继续尝试并最后返回 IllegalBuffer
|
||||
// 这里直接返回 IllegalBuffer
|
||||
return WXBizMsgCrypt_IllegalBuffer, ""
|
||||
}
|
||||
|
||||
// 去掉前 16 字节随机字符串
|
||||
if len(unpadded) < 16 {
|
||||
return WXBizMsgCrypt_IllegalBuffer, ""
|
||||
}
|
||||
content := unpadded[16:]
|
||||
|
||||
if len(content) < 4 {
|
||||
return WXBizMsgCrypt_IllegalBuffer, ""
|
||||
}
|
||||
// 前 4 字节为 network order 的 json length
|
||||
jsonLen := binary.BigEndian.Uint32(content[:4])
|
||||
if int(jsonLen) > len(content)-4 {
|
||||
return WXBizMsgCrypt_IllegalBuffer, ""
|
||||
}
|
||||
jsonContent := string(content[4 : 4+jsonLen])
|
||||
fromReceiveID := string(content[4+jsonLen:])
|
||||
if fromReceiveID != receiveID {
|
||||
// receiveid 不匹配
|
||||
return WXBizMsgCrypt_ValidateCorpid_Error, ""
|
||||
}
|
||||
return WXBizMsgCrypt_OK, jsonContent
|
||||
}
|
||||
|
||||
// getRandom16BytesAsDigits 产生一个 16 字节的 ASCII 数字字符串(与 Python 版本行为一致)
|
||||
func getRandom16BytesAsDigits() ([]byte, error) {
|
||||
const digits = "0123456789"
|
||||
out := make([]byte, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[i] = digits[nBig.Int64()]
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// WXBizJsonMsgCrypt 将整个流程封装:初始化时传入 token, encodingAESKey, receiveID
|
||||
type WXBizJsonMsgCrypt struct {
|
||||
Token string
|
||||
EncodingKey []byte
|
||||
ReceiveID string
|
||||
encodingAES string // 原始 sEncodingAESKey
|
||||
}
|
||||
|
||||
// NewWXBizJsonMsgCrypt 构造:sToken, sEncodingAESKey, sReceiveID
|
||||
func NewWXBizJsonMsgCrypt(sToken, sEncodingAESKey, sReceiveID string) (*WXBizJsonMsgCrypt, int, error) {
|
||||
// Python 里是 base64.b64decode(sEncodingAESKey + "=")
|
||||
dec, err := base64.StdEncoding.DecodeString(sEncodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, WXBizMsgCrypt_IllegalAesKey, fmt.Errorf("EncodingAESKey base64 decode fail: %w", err)
|
||||
}
|
||||
if len(dec) != 32 {
|
||||
return nil, WXBizMsgCrypt_IllegalAesKey, fmt.Errorf("EncodingAESKey decoded length must be 32 (got %d)", len(dec))
|
||||
}
|
||||
return &WXBizJsonMsgCrypt{
|
||||
Token: sToken,
|
||||
EncodingKey: dec,
|
||||
ReceiveID: sReceiveID,
|
||||
encodingAES: sEncodingAESKey,
|
||||
}, WXBizMsgCrypt_OK, nil
|
||||
}
|
||||
|
||||
// VerifyURL 校验并解密 sEchoStr(用于首次验证 URL)
|
||||
// 返回 (code, sReplyEchoStr)
|
||||
func (w *WXBizJsonMsgCrypt) VerifyURL(sMsgSignature, sTimeStamp, sNonce, sEchoStr string) (int, string) {
|
||||
sha1 := &SHA1{}
|
||||
ret, signature := sha1.GetSHA1(w.Token, sTimeStamp, sNonce, sEchoStr)
|
||||
if ret != WXBizMsgCrypt_OK {
|
||||
return ret, ""
|
||||
}
|
||||
if signature != sMsgSignature {
|
||||
return WXBizMsgCrypt_ValidateSignature_Error, ""
|
||||
}
|
||||
pc := NewPrpcrypt(w.EncodingKey)
|
||||
ret, reply := pc.Decrypt(sEchoStr, w.ReceiveID)
|
||||
return ret, reply
|
||||
}
|
||||
|
||||
// EncryptMsg 对要回复的消息 sReplyMsg(json 字符串)进行加密并生成外层 JSON 包装
|
||||
// 返回 (code, generatedJson)
|
||||
func (w *WXBizJsonMsgCrypt) EncryptMsg(sReplyMsg, sNonce string, timestamp ...string) (int, string) {
|
||||
pc := NewPrpcrypt(w.EncodingKey)
|
||||
ret, encrypt := pc.Encrypt(sReplyMsg, w.ReceiveID)
|
||||
if ret != WXBizMsgCrypt_OK {
|
||||
return ret, ""
|
||||
}
|
||||
// encrypt 是 base64 字符串(已经),确保是字符串
|
||||
encryptStr := encrypt
|
||||
|
||||
ts := ""
|
||||
if len(timestamp) > 0 && timestamp[0] != "" {
|
||||
ts = timestamp[0]
|
||||
} else {
|
||||
ts = fmt.Sprintf("%d", time.Now().Unix())
|
||||
}
|
||||
sha1 := &SHA1{}
|
||||
ret, signature := sha1.GetSHA1(w.Token, ts, sNonce, encryptStr)
|
||||
if ret != WXBizMsgCrypt_OK {
|
||||
return ret, ""
|
||||
}
|
||||
jp := &JsonParse{}
|
||||
jsonStr := jp.Generate(encryptStr, signature, ts, sNonce)
|
||||
return WXBizMsgCrypt_OK, jsonStr
|
||||
}
|
||||
|
||||
// DecryptMsg 验证签名并解密 POST 的 json 数据包
|
||||
// sPostData: POST 的 json 数据字符串(包含 encrypt 字段)
|
||||
// sMsgSignature: URL param msg_signature
|
||||
// sTimeStamp: timestamp
|
||||
// sNonce: nonce
|
||||
// 返回 (code, jsonContent)
|
||||
func (w *WXBizJsonMsgCrypt) DecryptMsg(sPostData, sMsgSignature, sTimeStamp, sNonce string) (int, string) {
|
||||
jp := &JsonParse{}
|
||||
ret, encrypt := jp.Extract(sPostData)
|
||||
if ret != WXBizMsgCrypt_OK {
|
||||
return ret, ""
|
||||
}
|
||||
sha1 := &SHA1{}
|
||||
ret, signature := sha1.GetSHA1(w.Token, sTimeStamp, sNonce, encrypt)
|
||||
if ret != WXBizMsgCrypt_OK {
|
||||
return ret, ""
|
||||
}
|
||||
if signature != sMsgSignature {
|
||||
return WXBizMsgCrypt_ValidateSignature_Error, ""
|
||||
}
|
||||
pc := NewPrpcrypt(w.EncodingKey)
|
||||
return pc.Decrypt(encrypt, w.ReceiveID)
|
||||
}
|
||||
17
backend/pkg/captcha/captcha.go
Normal file
17
backend/pkg/captcha/captcha.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package captcha
|
||||
|
||||
import gocap "github.com/ackcoder/go-cap"
|
||||
|
||||
type Captcha struct {
|
||||
*gocap.Cap
|
||||
}
|
||||
|
||||
func NewCaptcha() *Captcha {
|
||||
return &Captcha{
|
||||
Cap: gocap.New(
|
||||
gocap.WithChallenge(50, 32, 3),
|
||||
gocap.WithChallengeExpires(60*2),
|
||||
gocap.WithTokenExpires(60*5),
|
||||
),
|
||||
}
|
||||
}
|
||||
229
backend/pkg/cas/cas.go
Normal file
229
backend/pkg/cas/cas.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package cas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
logger *log.Logger
|
||||
ctx context.Context
|
||||
config *Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ServerURL string `json:"server_url"` // CAS服务器URL,如 https://cas.example.com/cas
|
||||
ServiceURL string `json:"service_url"` // 服务回调URL
|
||||
LoginPath string `json:"login_path"` // 登录路径,默认为 /login
|
||||
ValidatePath string `json:"validate_path"` // 验证路径,默认根据版本自动选择
|
||||
Version string `json:"version"` // CAS协议版本: "2" 或 "3"
|
||||
CASUrl string `json:"cas_url"`
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
Username string `json:"username"`
|
||||
Attributes map[string]string `json:"attributes"`
|
||||
}
|
||||
|
||||
// CAS2ServiceResponse CAS2服务验证响应结构
|
||||
type CAS2ServiceResponse struct {
|
||||
XMLName xml.Name `xml:"serviceResponse"`
|
||||
Success *CAS2AuthenticationSuccess `xml:"authenticationSuccess"`
|
||||
Failure *AuthenticationFailure `xml:"authenticationFailure"`
|
||||
}
|
||||
|
||||
type CAS2AuthenticationSuccess struct {
|
||||
User string `xml:"user"`
|
||||
}
|
||||
|
||||
// CAS3ServiceResponse CAS3服务验证响应结构
|
||||
type CAS3ServiceResponse struct {
|
||||
XMLName xml.Name `xml:"serviceResponse"`
|
||||
Success *CAS3AuthenticationSuccess `xml:"authenticationSuccess"`
|
||||
Failure *AuthenticationFailure `xml:"authenticationFailure"`
|
||||
}
|
||||
|
||||
type CAS3AuthenticationSuccess struct {
|
||||
User string `xml:"user"`
|
||||
Attributes CAS3Attributes `xml:"attributes"`
|
||||
}
|
||||
|
||||
type AuthenticationFailure struct {
|
||||
Code string `xml:"code,attr"`
|
||||
Message string `xml:",chardata"`
|
||||
}
|
||||
|
||||
type CAS3Attributes struct {
|
||||
Email string `xml:"email"`
|
||||
Name string `xml:"name"`
|
||||
AvatarURL string `xml:"avatar_url"`
|
||||
}
|
||||
|
||||
const (
|
||||
defaultLoginPath = "/login"
|
||||
defaultValidatePathCAS2 = "/serviceValidate"
|
||||
defaultValidatePathCAS3 = "/p3/serviceValidate"
|
||||
callbackPath = "/share/pro/v1/openapi/cas/callback"
|
||||
)
|
||||
|
||||
// NewClient 创建CAS客户端
|
||||
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
|
||||
// 设置默认登录路径
|
||||
if config.LoginPath == "" {
|
||||
config.LoginPath = defaultLoginPath
|
||||
}
|
||||
|
||||
// 如果版本为空,默认使用CAS3
|
||||
if config.Version == "" {
|
||||
config.Version = "3"
|
||||
}
|
||||
|
||||
// 根据版本设置默认验证路径
|
||||
if config.ValidatePath == "" {
|
||||
switch config.Version {
|
||||
case "3":
|
||||
config.ValidatePath = defaultValidatePathCAS3
|
||||
case "2", "":
|
||||
config.ValidatePath = defaultValidatePathCAS2
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported CAS version: %s, supported versions are '2' and '3'", config.Version)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建服务回调URL
|
||||
if config.ServiceURL != "" {
|
||||
serviceURL, err := url.Parse(config.ServiceURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid service URL: %w", err)
|
||||
}
|
||||
serviceURL.Path = callbackPath
|
||||
config.ServiceURL = serviceURL.String()
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
logger: logger.WithModule("pkg.cas"),
|
||||
config: &config,
|
||||
httpClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLoginURL 获取CAS登录URL
|
||||
func (c *Client) GetLoginURL(state string) string {
|
||||
loginURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.LoginPath
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("service", c.config.ServiceURL+"?state="+state)
|
||||
|
||||
return loginURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// ValidateTicket 验证CAS票据并获取用户信息
|
||||
func (c *Client) ValidateTicket(ticket, state string) (*UserInfo, error) {
|
||||
validateURL := strings.TrimSuffix(c.config.ServerURL, "/") + c.config.ValidatePath
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("service", c.config.ServiceURL+"?state="+state)
|
||||
params.Set("ticket", ticket)
|
||||
|
||||
fullURL := validateURL + "?" + params.Encode()
|
||||
|
||||
c.logger.Info("validating CAS ticket",
|
||||
log.String("url", fullURL),
|
||||
log.String("version", c.config.Version))
|
||||
|
||||
resp, err := c.httpClient.Get(fullURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate ticket: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("CAS validation response", log.String("response", string(body)))
|
||||
|
||||
// 根据CAS版本解析不同的响应格式
|
||||
switch c.config.Version {
|
||||
case "2":
|
||||
return c.parseCAS2Response(body)
|
||||
case "3":
|
||||
return c.parseCAS3Response(body)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported CAS version: %s", c.config.Version)
|
||||
}
|
||||
}
|
||||
|
||||
// parseCAS2Response 解析CAS2响应
|
||||
func (c *Client) parseCAS2Response(body []byte) (*UserInfo, error) {
|
||||
var serviceResp CAS2ServiceResponse
|
||||
if err := xml.Unmarshal(body, &serviceResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CAS2 response: %w", err)
|
||||
}
|
||||
|
||||
if serviceResp.Failure != nil {
|
||||
return nil, fmt.Errorf("CAS validation failed: %s - %s",
|
||||
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
|
||||
}
|
||||
|
||||
if serviceResp.Success == nil {
|
||||
return nil, fmt.Errorf("invalid CAS2 response: no success or failure element")
|
||||
}
|
||||
|
||||
userInfo := &UserInfo{
|
||||
Username: serviceResp.Success.User,
|
||||
Attributes: map[string]string{
|
||||
"name": serviceResp.Success.User, // CAS2通常只返回用户名
|
||||
},
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// parseCAS3Response 解析CAS3响应
|
||||
func (c *Client) parseCAS3Response(body []byte) (*UserInfo, error) {
|
||||
var serviceResp CAS3ServiceResponse
|
||||
if err := xml.Unmarshal(body, &serviceResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CAS3 response: %w", err)
|
||||
}
|
||||
|
||||
if serviceResp.Failure != nil {
|
||||
return nil, fmt.Errorf("CAS validation failed: %s - %s",
|
||||
serviceResp.Failure.Code, strings.TrimSpace(serviceResp.Failure.Message))
|
||||
}
|
||||
|
||||
if serviceResp.Success == nil {
|
||||
return nil, fmt.Errorf("invalid CAS3 response: no success or failure element")
|
||||
}
|
||||
|
||||
userInfo := &UserInfo{
|
||||
Username: serviceResp.Success.User,
|
||||
Attributes: map[string]string{
|
||||
"email": serviceResp.Success.Attributes.Email,
|
||||
"name": serviceResp.Success.Attributes.Name,
|
||||
"avatar_url": serviceResp.Success.Attributes.AvatarURL,
|
||||
},
|
||||
}
|
||||
|
||||
// 如果没有显示名称,使用用户名
|
||||
if userInfo.Attributes["name"] == "" {
|
||||
userInfo.Attributes["name"] = userInfo.Username
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
351
backend/pkg/dingtalk/dingtalk.go
Normal file
351
backend/pkg/dingtalk/dingtalk.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
|
||||
dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0"
|
||||
dingtalkoauth2_1_0 "github.com/alibabacloud-go/dingtalk/v2/oauth2_1_0"
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/store/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
callbackPath = "/share/pro/v1/openapi/dingtalk/callback"
|
||||
userInfoUrl = "https://api.dingtalk.com/v1.0/contact/users/me"
|
||||
DepartmentListUrl = "https://oapi.dingtalk.com/department/list"
|
||||
// https://open.dingtalk.com/document/isvapp/queries-the-complete-information-of-a-department-user
|
||||
UserListUrl = "https://oapi.dingtalk.com/topapi/v2/user/list"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
ctx context.Context
|
||||
logger *log.Logger
|
||||
httpClient *http.Client
|
||||
clientID string
|
||||
clientSecret string
|
||||
oauthClient *dingtalkoauth2_1_0.Client
|
||||
cardClient *dingtalkcard_1_0.Client
|
||||
dingTalkAuthURL string
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// UserInfo 用于解析获取用户信息的接口返回
|
||||
type UserInfo struct {
|
||||
Nick string `json:"nick"`
|
||||
UnionID string `json:"unionId"`
|
||||
OpenID string `json:"openId"`
|
||||
AvatarURL string `json:"avatarUrl"`
|
||||
StateCode string `json:"stateCode"`
|
||||
}
|
||||
|
||||
// DepartmentListRsp 用于解析组织信息接口返回
|
||||
type DepartmentListRsp struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Department []struct {
|
||||
CreateDeptGroup bool `json:"createDeptGroup"`
|
||||
Name string `json:"name"`
|
||||
Id int `json:"id"`
|
||||
AutoAddUser bool `json:"autoAddUser"`
|
||||
Parentid int `json:"parentid,omitempty"`
|
||||
} `json:"department"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
type GetUserListResp struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Result struct {
|
||||
HasMore bool `json:"has_more"`
|
||||
List []UserDetail `json:"list"`
|
||||
} `json:"result"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
type UserDetail struct {
|
||||
Active bool `json:"active"`
|
||||
Admin bool `json:"admin"`
|
||||
Avatar string `json:"avatar"`
|
||||
Boss bool `json:"boss"`
|
||||
DeptIdList []int `json:"dept_id_list"`
|
||||
DeptOrder int64 `json:"dept_order"`
|
||||
Email string `json:"email"`
|
||||
ExclusiveAccount bool `json:"exclusive_account"`
|
||||
HideMobile bool `json:"hide_mobile"`
|
||||
JobNumber string `json:"job_number"`
|
||||
Leader bool `json:"leader"`
|
||||
Mobile string `json:"mobile"`
|
||||
Name string `json:"name"`
|
||||
Remark string `json:"remark"`
|
||||
StateCode string `json:"state_code"`
|
||||
Telephone string `json:"telephone"`
|
||||
Title string `json:"title"`
|
||||
Unionid string `json:"unionid"`
|
||||
Userid string `json:"userid"`
|
||||
WorkPlace string `json:"work_place"`
|
||||
}
|
||||
|
||||
func NewDingTalkClient(ctx context.Context, logger *log.Logger, clientId, clientSecret string, cache *cache.Cache) (*Client, error) {
|
||||
config := &openapi.Config{}
|
||||
config.Protocol = tea.String("https")
|
||||
config.RegionId = tea.String("central")
|
||||
oauthClient, err := dingtalkoauth2_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create oauth client: %w", err)
|
||||
}
|
||||
cardClient, err := dingtalkcard_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create card client: %w", err)
|
||||
}
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
logger: logger.WithModule("pkg.dingtalk"),
|
||||
httpClient: &http.Client{},
|
||||
clientID: clientId,
|
||||
clientSecret: clientSecret,
|
||||
oauthClient: oauthClient,
|
||||
cardClient: cardClient,
|
||||
dingTalkAuthURL: "https://login.dingtalk.com/oauth2/auth",
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成钉钉授权URL
|
||||
func (c *Client) GenerateAuthURL(baseUrl string, state string) string {
|
||||
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to join path", log.Error(err))
|
||||
return ""
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("response_type", "code")
|
||||
params.Add("client_id", c.clientID)
|
||||
params.Add("redirect_uri", redirectURI)
|
||||
params.Add("scope", "openid")
|
||||
params.Add("state", state)
|
||||
params.Add("prompt", "consent")
|
||||
|
||||
return fmt.Sprintf("%s?%s", c.dingTalkAuthURL, params.Encode())
|
||||
}
|
||||
|
||||
func (c *Client) GetAccessTokenByCode(code string) (string, error) {
|
||||
request := &dingtalkoauth2_1_0.GetUserTokenRequest{
|
||||
ClientId: tea.String(c.clientID),
|
||||
ClientSecret: tea.String(c.clientSecret),
|
||||
Code: tea.String(code),
|
||||
GrantType: tea.String("authorization_code"),
|
||||
}
|
||||
response, err := c.oauthClient.GetUserToken(request)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get user access token: %w", err)
|
||||
}
|
||||
accessToken := tea.StringValue(response.Body.AccessToken)
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetAccessToken() (string, error) {
|
||||
ctx := context.Background()
|
||||
cacheKey := fmt.Sprintf("dingtalk-access-token:%s", c.clientID)
|
||||
cachedData, err := c.cache.Get(ctx, cacheKey).Result()
|
||||
if err == nil && cachedData != "" {
|
||||
return cachedData, nil
|
||||
}
|
||||
|
||||
request := &dingtalkoauth2_1_0.GetAccessTokenRequest{
|
||||
AppKey: tea.String(c.clientID),
|
||||
AppSecret: tea.String(c.clientSecret),
|
||||
}
|
||||
response, tryErr := func() (_resp *dingtalkoauth2_1_0.GetAccessTokenResponse, _e error) {
|
||||
defer func() {
|
||||
if r := tea.Recover(recover()); r != nil {
|
||||
_e = r
|
||||
}
|
||||
}()
|
||||
_resp, _err := c.oauthClient.GetAccessToken(request)
|
||||
if _err != nil {
|
||||
return nil, _err
|
||||
}
|
||||
|
||||
return _resp, nil
|
||||
}()
|
||||
if tryErr != nil {
|
||||
return "", tryErr
|
||||
}
|
||||
accessToken := *response.Body.AccessToken
|
||||
c.logger.Debug("get access token", log.String("access_token", accessToken), log.Int("expire_in", int(*response.Body.ExpireIn)))
|
||||
|
||||
if err := c.cache.Set(ctx, cacheKey, accessToken, time.Duration(*response.Body.ExpireIn-300)*time.Second).Err(); err != nil {
|
||||
c.logger.Warn("failed to set cache", log.Error(err))
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetUserInfoByCode(code string) (*UserInfo, error) {
|
||||
req, err := http.NewRequest("GET", userInfoUrl, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GET request: %w", err)
|
||||
}
|
||||
|
||||
accessToken, err := c.GetAccessTokenByCode(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set request headers
|
||||
req.Header.Set("x-acs-dingtalk-access-token", accessToken)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send GET request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetDepartmentList() (*DepartmentListRsp, error) {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("access_token", accessToken)
|
||||
requestURL := fmt.Sprintf("%s?%s", DepartmentListUrl, params.Encode())
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
c.logger.Debug("DepartmentListUrl:", log.String("body", string(body)))
|
||||
var departmentListRsp DepartmentListRsp
|
||||
if err := json.Unmarshal(body, &departmentListRsp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
|
||||
}
|
||||
|
||||
if departmentListRsp.Errcode != 0 {
|
||||
return nil, fmt.Errorf("DingTalk API error: errcode=%d errmsg=%s", departmentListRsp.Errcode, departmentListRsp.Errmsg)
|
||||
}
|
||||
|
||||
return &departmentListRsp, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetAllUserList(deptID int) ([]UserDetail, error) {
|
||||
depth := 0
|
||||
const maxDepth = 10
|
||||
|
||||
userList := make([]UserDetail, 0)
|
||||
for depth < maxDepth {
|
||||
resp, err := c.GetUserList(deptID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(resp.Result.List) > 0 {
|
||||
userList = append(userList, resp.Result.List...)
|
||||
}
|
||||
if !resp.Result.HasMore {
|
||||
break
|
||||
}
|
||||
depth++
|
||||
}
|
||||
return userList, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetUserList(deptID int) (*GetUserListResp, error) {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("access_token", accessToken)
|
||||
requestURL := fmt.Sprintf("%s?%s", UserListUrl, params.Encode())
|
||||
|
||||
bodyMap := map[string]interface{}{
|
||||
"dept_id": deptID,
|
||||
"size": 100,
|
||||
"cursor": 0,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bodyMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DingTalk API returned non-200 status: %s, response: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
c.logger.Debug("GetUserList:", log.String("body", string(body)))
|
||||
var getUserListResp GetUserListResp
|
||||
if err := json.Unmarshal(body, &getUserListResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal JSON response: %w", err)
|
||||
}
|
||||
|
||||
if getUserListResp.Errcode != 0 {
|
||||
return nil, fmt.Errorf("DingTalk GetUserList error: errcode=%d errcode=%s", getUserListResp.Errcode, getUserListResp.Errmsg)
|
||||
}
|
||||
|
||||
return &getUserListResp, nil
|
||||
}
|
||||
123
backend/pkg/feishu/feishu.go
Normal file
123
backend/pkg/feishu/feishu.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthURL = "https://accounts.feishu.cn/open-apis/authen/v1/authorize"
|
||||
TokenURL = "https://open.feishu.cn/open-apis/authen/v2/oauth/token"
|
||||
UserInfoURL = "https://open.feishu.cn/open-apis/authen/v1/user_info"
|
||||
callbackPath = "/share/pro/v1/openapi/feishu/callback"
|
||||
)
|
||||
|
||||
var oauthEndpoint = oauth2.Endpoint{
|
||||
AuthURL: AuthURL,
|
||||
TokenURL: TokenURL,
|
||||
}
|
||||
|
||||
// Client 飞书客户端
|
||||
type Client struct {
|
||||
context context.Context
|
||||
oauthConfig *oauth2.Config
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data UserInfo `json:"data"`
|
||||
}
|
||||
type UserInfo struct {
|
||||
Name string `json:"name"`
|
||||
EnName string `json:"en_name"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
AvatarThumb string `json:"avatar_thumb"`
|
||||
AvatarMiddle string `json:"avatar_middle"`
|
||||
AvatarBig string `json:"avatar_big"`
|
||||
OpenId string `json:"open_id"`
|
||||
UnionId string `json:"union_id"`
|
||||
Email string `json:"email"`
|
||||
EnterpriseEmail string `json:"enterprise_email"`
|
||||
UserId string `json:"user_id"`
|
||||
Mobile string `json:"mobile"`
|
||||
TenantKey string `json:"tenant_key"`
|
||||
EmployeeNo string `json:"employee_no"`
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, logger *log.Logger, appID, appSecret, baseUrl string) (*Client, error) {
|
||||
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: appID,
|
||||
ClientSecret: appSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: oauthEndpoint,
|
||||
Scopes: []string{},
|
||||
}
|
||||
|
||||
return &Client{
|
||||
context: ctx,
|
||||
logger: logger.WithModule("feishu.client"),
|
||||
oauthConfig: oauthConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成授权 URL
|
||||
func (c *Client) GenerateAuthURL(state string, verifier string) string {
|
||||
return c.oauthConfig.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))
|
||||
}
|
||||
|
||||
// GetAccessToken 通过授权码获取访问令牌
|
||||
func (c *Client) GetAccessToken(ctx context.Context, code string, codeVerifier string) (*oauth2.Token, error) {
|
||||
token, err := c.oauthConfig.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauthConfig.Exchange() failed: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetUserInfoByCode 获取用户信息
|
||||
func (c *Client) GetUserInfoByCode(ctx context.Context, code string, codeVerifier string) (*UserInfo, error) {
|
||||
token, err := c.oauthConfig.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauthConfig.Exchange() failed: %w", err)
|
||||
}
|
||||
|
||||
client := c.oauthConfig.Client(ctx, token)
|
||||
req, err := http.NewRequest("GET", UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var r Response
|
||||
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("GetUserInfoByCode", log.Any("resp", r))
|
||||
|
||||
if r.Code != 0 {
|
||||
return nil, fmt.Errorf("failed to get user info: %s", r.Msg)
|
||||
}
|
||||
|
||||
return &r.Data, nil
|
||||
}
|
||||
207
backend/pkg/ldap/ldap.go
Normal file
207
backend/pkg/ldap/ldap.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
logger *log.Logger
|
||||
ctx context.Context
|
||||
config *Config
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ServerURL string `json:"server_url"` // LDAP服务器URL,如 ldap://openldap.company.com:389
|
||||
BindDN string `json:"bind_dn"` // 绑定DN,如 cn=admin,dc=company,dc=com
|
||||
BindPassword string `json:"bind_password"` // 绑定密码
|
||||
UserBaseDN string `json:"user_base_dn"` // 用户基础DN,如 ou=People,dc=company,dc=com
|
||||
UserFilter string `json:"user_filter"` // 用户查询过滤器,如 (&(objectClass=person)(uid=%s))
|
||||
UserIDAttr string `json:"user_id_attr"` // 用户ID属性,默认 uid
|
||||
UserNameAttr string `json:"user_name_attr"` // 用户名属性,默认 cn
|
||||
UserEmailAttr string `json:"user_email_attr"` // 用户邮箱属性,默认 mail
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
DN string `json:"dn"` // Distinguished Name
|
||||
}
|
||||
|
||||
const (
|
||||
defaultUserIDAttr = "uid"
|
||||
defaultUserNameAttr = "cn"
|
||||
defaultUserEmailAttr = "mail"
|
||||
defaultUserFilter = "(&(objectClass=person)(uid=%s))"
|
||||
)
|
||||
|
||||
// NewClient 创建LDAP客户端
|
||||
func NewClient(ctx context.Context, logger *log.Logger, config Config) (*Client, error) {
|
||||
// 设置默认值
|
||||
if config.UserIDAttr == "" {
|
||||
config.UserIDAttr = defaultUserIDAttr
|
||||
}
|
||||
if config.UserNameAttr == "" {
|
||||
config.UserNameAttr = defaultUserNameAttr
|
||||
}
|
||||
if config.UserEmailAttr == "" {
|
||||
config.UserEmailAttr = defaultUserEmailAttr
|
||||
}
|
||||
if config.UserFilter == "" {
|
||||
config.UserFilter = defaultUserFilter
|
||||
}
|
||||
|
||||
// 验证必需的配置
|
||||
if config.ServerURL == "" {
|
||||
return nil, fmt.Errorf("LDAP server URL is required")
|
||||
}
|
||||
if config.BindDN == "" {
|
||||
return nil, fmt.Errorf("bind DN is required")
|
||||
}
|
||||
if config.UserBaseDN == "" {
|
||||
return nil, fmt.Errorf("user base DN is required")
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
logger: logger.WithModule("pkg.ldap"),
|
||||
config: &config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate 验证用户凭据并获取用户信息
|
||||
func (c *Client) Authenticate(username, password string) (*UserInfo, error) {
|
||||
// 连接到LDAP服务器
|
||||
conn, err := ldap.DialURL(c.config.ServerURL)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to connect to LDAP server", log.Error(err))
|
||||
return nil, fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 使用管理员账户绑定
|
||||
err = conn.Bind(c.config.BindDN, c.config.BindPassword)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to bind with admin credentials", log.Error(err))
|
||||
return nil, fmt.Errorf("failed to bind with admin credentials: %w", err)
|
||||
}
|
||||
|
||||
// 搜索用户
|
||||
userInfo, err := c.searchUser(conn, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证用户密码
|
||||
err = conn.Bind(userInfo.DN, password)
|
||||
if err != nil {
|
||||
c.logger.Error("user authentication failed",
|
||||
log.String("username", username),
|
||||
log.String("dn", userInfo.DN),
|
||||
log.Error(err))
|
||||
return nil, fmt.Errorf("authentication failed: invalid credentials")
|
||||
}
|
||||
|
||||
c.logger.Info("user authenticated successfully",
|
||||
log.String("username", username),
|
||||
log.String("dn", userInfo.DN))
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// searchUser 搜索用户信息
|
||||
func (c *Client) searchUser(conn *ldap.Conn, username string) (*UserInfo, error) {
|
||||
// 构建搜索过滤器
|
||||
filter := fmt.Sprintf(c.config.UserFilter, username)
|
||||
|
||||
// 构建搜索请求
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
c.config.UserBaseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0, // 不限制结果数量
|
||||
0, // 不限制搜索时间
|
||||
false,
|
||||
filter,
|
||||
[]string{c.config.UserIDAttr, c.config.UserNameAttr, c.config.UserEmailAttr},
|
||||
nil,
|
||||
)
|
||||
|
||||
c.logger.Info("searching for user",
|
||||
log.String("filter", filter),
|
||||
log.String("base_dn", c.config.UserBaseDN))
|
||||
|
||||
// 执行搜索
|
||||
searchResult, err := conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
c.logger.Error("user search failed", log.Error(err))
|
||||
return nil, fmt.Errorf("user search failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查搜索结果
|
||||
if len(searchResult.Entries) == 0 {
|
||||
c.logger.Warn("user not found", log.String("username", username))
|
||||
return nil, fmt.Errorf("user not found: %s", username)
|
||||
}
|
||||
|
||||
if len(searchResult.Entries) > 1 {
|
||||
c.logger.Warn("multiple users found",
|
||||
log.String("username", username),
|
||||
log.Int("count", len(searchResult.Entries)))
|
||||
return nil, fmt.Errorf("multiple users found for username: %s", username)
|
||||
}
|
||||
|
||||
// 解析用户信息
|
||||
entry := searchResult.Entries[0]
|
||||
userInfo := &UserInfo{
|
||||
DN: entry.DN,
|
||||
ID: c.getAttributeValue(entry, c.config.UserIDAttr),
|
||||
Username: c.getAttributeValue(entry, c.config.UserNameAttr),
|
||||
Email: c.getAttributeValue(entry, c.config.UserEmailAttr),
|
||||
}
|
||||
|
||||
// 如果没有获取到用户名,使用ID作为用户名
|
||||
if userInfo.Username == "" {
|
||||
userInfo.Username = userInfo.ID
|
||||
}
|
||||
|
||||
c.logger.Info("user found",
|
||||
log.String("dn", userInfo.DN),
|
||||
log.String("id", userInfo.ID),
|
||||
log.String("username", userInfo.Username),
|
||||
log.String("email", userInfo.Email))
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// getAttributeValue 获取LDAP属性值
|
||||
func (c *Client) getAttributeValue(entry *ldap.Entry, attrName string) string {
|
||||
values := entry.GetAttributeValues(attrName)
|
||||
if len(values) > 0 {
|
||||
return strings.TrimSpace(values[0])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestConnection 测试LDAP连接
|
||||
func (c *Client) TestConnection() error {
|
||||
conn, err := ldap.DialURL(c.config.ServerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.Bind(c.config.BindDN, c.config.BindPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind with admin credentials: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("LDAP connection test successful")
|
||||
return nil
|
||||
}
|
||||
137
backend/pkg/oauth/github.go
Normal file
137
backend/pkg/oauth/github.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/chaitin/panda-wiki/consts"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
const (
|
||||
githubAuthorizeURL = "https://github.com/login/oauth/authorize"
|
||||
githubTokenURL = "https://github.com/login/oauth/access_token"
|
||||
githubUserInfoURL = "https://api.github.com/user"
|
||||
githubUserEmailURL = "https://api.github.com/user/emails"
|
||||
githubCallbackPathPro = "/share/pro/v1/openapi/github/callback"
|
||||
githubCallbackPath = "/share/v1/openapi/github/callback"
|
||||
)
|
||||
|
||||
func NewGithubClient(ctx context.Context, logger *log.Logger, clientID, clientSecret, redirectURI, proxyURL string) (*Client, error) {
|
||||
licenseEdition, ok := ctx.Value(consts.ContextKeyEdition).(consts.LicenseEdition)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to retrieve license edition from context")
|
||||
}
|
||||
|
||||
redirectURL, _ := url.Parse(redirectURI)
|
||||
redirectURL.Path = githubCallbackPath
|
||||
|
||||
if licenseEdition > consts.LicenseEditionFree {
|
||||
redirectURL.Path = githubCallbackPathPro
|
||||
}
|
||||
|
||||
redirectURI = redirectURL.String()
|
||||
|
||||
var httpClient *http.Client
|
||||
if proxyURL != "" {
|
||||
proxyURLParsed, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
},
|
||||
}
|
||||
logger.Info("GitHub OAuth client configured with proxy", log.String("proxy", proxyURL))
|
||||
} else {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
config := Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: []string{"user:email"},
|
||||
AuthorizeURL: githubAuthorizeURL,
|
||||
TokenURL: githubTokenURL,
|
||||
UserInfoURL: githubUserInfoURL,
|
||||
IDField: "id",
|
||||
NameField: "login",
|
||||
AvatarField: "avatar_url",
|
||||
EmailField: "email",
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthorizeURL,
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: config.Scopes,
|
||||
}
|
||||
|
||||
if proxyURL != "" {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
logger: logger.WithModule("pkg.oauth"),
|
||||
oauth: oauthConfig,
|
||||
httpClient: httpClient,
|
||||
config: &config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetGithubPrimaryEmail(token *oauth2.Token) (string, error) {
|
||||
var client *http.Client
|
||||
if c.httpClient != nil {
|
||||
ctx := context.WithValue(c.ctx, oauth2.HTTPClient, c.httpClient)
|
||||
client = c.oauth.Client(ctx, token)
|
||||
} else {
|
||||
client = c.oauth.Client(c.ctx, token)
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
|
||||
resp, err := client.Get(githubUserEmailURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c.logger.Info("GetGithubPrimaryEmail:", log.Any("buf", string(buf)))
|
||||
|
||||
var emails []Email
|
||||
if err := json.Unmarshal(buf, &emails); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, email := range emails {
|
||||
if email.Primary && email.Verified {
|
||||
return email.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no primary verified email found")
|
||||
}
|
||||
110
backend/pkg/oauth/oauth.go
Normal file
110
backend/pkg/oauth/oauth.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
logger *log.Logger
|
||||
ctx context.Context
|
||||
config *Config
|
||||
oauth *oauth2.Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
const (
|
||||
callbackPath = "/share/pro/v1/openapi/oauth/callback"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
AuthorizeURL string `json:"authorize_url,omitempty"`
|
||||
TokenURL string `json:"token_url,omitempty"`
|
||||
UserInfoURL string `json:"user_info_url,omitempty"`
|
||||
IDField string `json:"id_field,omitempty"`
|
||||
NameField string `json:"name_field,omitempty"`
|
||||
AvatarField string `json:"avatar_field,omitempty"`
|
||||
EmailField string `json:"email_field,omitempty"`
|
||||
}
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
}
|
||||
|
||||
// NewClient 创建OAuth客户端
|
||||
func NewClient(ctx context.Context, logger *log.Logger, baseUrl string, config Config) (*Client, error) {
|
||||
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ctx: ctx,
|
||||
logger: logger.WithModule("pkg.oauth"),
|
||||
oauth: &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthorizeURL,
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: config.Scopes,
|
||||
},
|
||||
config: &config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetAuthorizeURL(state string) string {
|
||||
return c.oauth.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
func (c *Client) GetUserInfo(code string) (*UserInfo, error) {
|
||||
token, err := c.oauth.Exchange(c.ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := c.oauth.Client(c.ctx, token)
|
||||
res, err := client.Get(c.config.UserInfoURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
buf, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.logger.Info("oauth GetUserInfo:", log.Any("resp", string(buf)))
|
||||
|
||||
jsonString := string(buf)
|
||||
|
||||
email := gjson.Get(jsonString, c.config.EmailField).String()
|
||||
if email == "" && c.config.UserInfoURL == githubUserInfoURL {
|
||||
email, err = c.GetGithubPrimaryEmail(token)
|
||||
if err != nil {
|
||||
c.logger.Warn("GetGithubPrimaryEmail failed", log.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: gjson.Get(jsonString, c.config.IDField).String(),
|
||||
AvatarUrl: gjson.Get(jsonString, c.config.AvatarField).String(),
|
||||
Name: gjson.Get(jsonString, c.config.NameField).String(),
|
||||
Email: email,
|
||||
}, nil
|
||||
}
|
||||
102
backend/pkg/ratelimit/rate_limiter.go
Normal file
102
backend/pkg/ratelimit/rate_limiter.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/store/cache"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
logger *log.Logger
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewRateLimiter(logger *log.Logger, cache *cache.Cache) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
logger: logger,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
LockThreshold1 = 5 // 第一次锁定阈值
|
||||
LockThreshold2 = 10 // 第二次锁定阈值
|
||||
LockThreshold3 = 15 // 第三次锁定阈值
|
||||
AttemptsKeyExpiry = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CheckIPLocked checks if the IP is currently locked
|
||||
// Returns:
|
||||
// - bool: whether the IP is locked
|
||||
// - time.Duration: remaining lockout duration
|
||||
func (r *RateLimiter) CheckIPLocked(ctx context.Context, ip string) (bool, time.Duration) {
|
||||
lockKey := fmt.Sprintf("login_lock:%s", ip)
|
||||
|
||||
ttl, err := r.cache.TTL(ctx, lockKey).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check lock status", "error", err, "ip", ip)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
return true, ttl
|
||||
}
|
||||
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func (r *RateLimiter) LockAttempt(ctx context.Context, ip string) {
|
||||
attemptsKey := fmt.Sprintf("login_attempts:%s", ip)
|
||||
lockKey := fmt.Sprintf("login_lock:%s", ip)
|
||||
|
||||
attempts, err := r.cache.Incr(ctx, attemptsKey).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to increment attempts", "error", err, "ip", ip)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.cache.Expire(ctx, attemptsKey, AttemptsKeyExpiry).Err(); err != nil {
|
||||
r.logger.Error("failed to set expiry on attempts key", "error", err, "ip", ip)
|
||||
}
|
||||
|
||||
var lockDuration time.Duration
|
||||
|
||||
if attempts%5 == 0 {
|
||||
switch {
|
||||
case attempts == LockThreshold1:
|
||||
lockDuration = time.Minute
|
||||
case attempts == LockThreshold2:
|
||||
lockDuration = 15 * time.Minute
|
||||
case attempts >= LockThreshold3:
|
||||
lockDuration = time.Hour
|
||||
}
|
||||
if lockDuration > 0 {
|
||||
if err := r.cache.Set(ctx, lockKey, 1, lockDuration).Err(); err != nil {
|
||||
r.logger.Error("failed to set lock key", "error", err, "ip", ip)
|
||||
return
|
||||
}
|
||||
r.logger.Info("IP has been locked", "ip", ip, "lockDuration", lockDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResetLoginAttempts resets the login attempt counter and lock for an IP
|
||||
func (r *RateLimiter) ResetLoginAttempts(ctx context.Context, ip string) error {
|
||||
attemptsKey := fmt.Sprintf("login_attempts:%s", ip)
|
||||
lockKey := fmt.Sprintf("login_lock:%s", ip)
|
||||
|
||||
pipe := r.cache.Pipeline()
|
||||
pipe.Del(ctx, attemptsKey)
|
||||
pipe.Del(ctx, lockKey)
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to reset login attempts: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
347
backend/pkg/wecom/wecom.go
Normal file
347
backend/pkg/wecom/wecom.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/store/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthURL api doc https://developer.work.weixin.qq.com/document/path/98152
|
||||
AuthWebURL = "https://login.work.weixin.qq.com/wwlogin/sso/login"
|
||||
AuthAPPURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
|
||||
TokenURL = "https://qyapi.weixin.qq.com/cgi-bin/gettoken"
|
||||
UserInfoURL = "https://qyapi.weixin.qq.com/cgi-bin/auth/getuserinfo"
|
||||
UserDetailURL = "https://qyapi.weixin.qq.com/cgi-bin/user/get"
|
||||
// DepartmentListURL https://developer.work.weixin.qq.com/document/path/90344
|
||||
DepartmentListURL = "https://qyapi.weixin.qq.com/cgi-bin/department/list"
|
||||
// UserListUrl https://developer.work.weixin.qq.com/document/path/90337
|
||||
UserListUrl = "https://qyapi.weixin.qq.com/cgi-bin/user/list"
|
||||
callbackPath = "/share/pro/v1/openapi/wecom/callback"
|
||||
)
|
||||
|
||||
// Client 企业微信客户端
|
||||
type Client struct {
|
||||
context context.Context
|
||||
cache *cache.Cache
|
||||
httpClient *http.Client
|
||||
oauthConfig *oauth2.Config
|
||||
logger *log.Logger
|
||||
corpID string
|
||||
agentID string
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type UserInfoResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
UserID string `json:"userid"`
|
||||
UserTicket string `json:"user_ticket"`
|
||||
OpenID string `json:"openid"`
|
||||
ExternalUserid string `json:"external_userid"`
|
||||
}
|
||||
|
||||
type UserDetailResponse struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
Userid string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Mobile string `json:"mobile"`
|
||||
Gender string `json:"gender"`
|
||||
Email string `json:"email"`
|
||||
Avatar string `json:"avatar"`
|
||||
OpenUserid string `json:"open_userid"`
|
||||
}
|
||||
type DepartmentListResponse struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
Department []struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
NameEn string `json:"name_en"`
|
||||
DepartmentLeader []string `json:"department_leader"`
|
||||
Parentid int `json:"parentid"`
|
||||
Order int `json:"order"`
|
||||
} `json:"department"`
|
||||
}
|
||||
|
||||
type UserListResponse struct {
|
||||
Errcode int `json:"errcode"`
|
||||
Errmsg string `json:"errmsg"`
|
||||
Userlist []struct {
|
||||
Name string `json:"name"`
|
||||
Department []int `json:"department"`
|
||||
Position string `json:"position"`
|
||||
Status int `json:"status"`
|
||||
Email string `json:"email"`
|
||||
Avatar string `json:"avatar"`
|
||||
Enable int `json:"enable"`
|
||||
Isleader int `json:"isleader"`
|
||||
Extattr struct {
|
||||
Attrs []interface{} `json:"attrs"`
|
||||
} `json:"extattr"`
|
||||
HideMobile int `json:"hide_mobile"`
|
||||
Telephone string `json:"telephone"`
|
||||
Order []int `json:"order"`
|
||||
ExternalProfile struct {
|
||||
ExternalAttr []interface{} `json:"external_attr"`
|
||||
ExternalCorpName string `json:"external_corp_name"`
|
||||
} `json:"external_profile"`
|
||||
MainDepartment int `json:"main_department"`
|
||||
Alias string `json:"alias"`
|
||||
IsLeaderInDept []int `json:"is_leader_in_dept"`
|
||||
Userid string `json:"userid"`
|
||||
DirectLeader []interface{} `json:"direct_leader"`
|
||||
} `json:"userlist"`
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, logger *log.Logger, corpID, corpSecret, agentID, baseUrl string, cache *cache.Cache, isApp bool) (*Client, error) {
|
||||
redirectURI, err := url.JoinPath(baseUrl, callbackPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authUrl := AuthWebURL
|
||||
if isApp {
|
||||
authUrl = AuthAPPURL
|
||||
}
|
||||
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: fmt.Sprintf("%s-%s", corpID, agentID),
|
||||
ClientSecret: corpSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: authUrl,
|
||||
TokenURL: TokenURL,
|
||||
},
|
||||
Scopes: []string{"snsapi_privateinfo"},
|
||||
}
|
||||
|
||||
return &Client{
|
||||
context: ctx,
|
||||
httpClient: &http.Client{},
|
||||
cache: cache,
|
||||
logger: logger.WithModule("wecom.client"),
|
||||
oauthConfig: oauthConfig,
|
||||
corpID: corpID,
|
||||
agentID: agentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成授权 URL
|
||||
func (c *Client) GenerateAuthURL(state string) string {
|
||||
params := url.Values{}
|
||||
params.Set("appid", c.corpID)
|
||||
params.Set("redirect_uri", c.oauthConfig.RedirectURL)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", "snsapi_privateinfo")
|
||||
params.Set("login_type", "CorpApp")
|
||||
params.Set("agentid", c.agentID)
|
||||
params.Set("state", state)
|
||||
|
||||
authUrl := fmt.Sprintf("%s?%s", c.oauthConfig.Endpoint.AuthURL, params.Encode())
|
||||
if c.oauthConfig.Endpoint.AuthURL == AuthAPPURL {
|
||||
authUrl += "#wechat_redirect"
|
||||
}
|
||||
return authUrl
|
||||
}
|
||||
|
||||
// GetAccessToken 获取企业微信访问令牌
|
||||
func (c *Client) GetAccessToken(ctx context.Context) (string, error) {
|
||||
|
||||
cacheKey := fmt.Sprintf("wecom-access-token:%s", c.oauthConfig.ClientID)
|
||||
cachedData, err := c.cache.Get(ctx, cacheKey).Result()
|
||||
if err == nil && cachedData != "" {
|
||||
return cachedData, nil
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("corpid", c.corpID)
|
||||
params.Set("corpsecret", c.oauthConfig.ClientSecret)
|
||||
|
||||
resp, err := c.httpClient.Get(fmt.Sprintf("%s?%s", TokenURL, params.Encode()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return "", fmt.Errorf("failed to get access token: %s", tokenResp.ErrMsg)
|
||||
}
|
||||
|
||||
if err := c.cache.Set(ctx, cacheKey, tokenResp.AccessToken, time.Duration(tokenResp.ExpiresIn-300)*time.Second).Err(); err != nil {
|
||||
c.logger.Warn("failed to set cache", log.Error(err))
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
// GetUserInfoByCode 通过授权码获取用户信息
|
||||
func (c *Client) GetUserInfoByCode(ctx context.Context, code string) (*UserDetailResponse, error) {
|
||||
|
||||
accessToken, err := c.GetAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("access_token", accessToken)
|
||||
params.Set("code", code)
|
||||
|
||||
userInfoURL := fmt.Sprintf("%s?%s", UserInfoURL, params.Encode())
|
||||
|
||||
c.logger.Debug("GetUserInfoByCode", log.Any("userInfoURL", userInfoURL))
|
||||
|
||||
resp, err := c.httpClient.Get(userInfoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
c.logger.Debug("GetUserInfoByCode raw resp:", log.Any("raw", string(rawBody)))
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
|
||||
|
||||
var userInfoResp UserInfoResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfoResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("GetUserInfoByCode resp:", log.Any("resp", userInfoResp))
|
||||
|
||||
if userInfoResp.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("failed to get user info: %s", userInfoResp.ErrMsg)
|
||||
}
|
||||
|
||||
detailParams := url.Values{}
|
||||
detailParams.Set("access_token", accessToken)
|
||||
detailParams.Set("userid", userInfoResp.UserID)
|
||||
|
||||
userDetailURL := fmt.Sprintf("%s?%s", UserDetailURL, detailParams.Encode())
|
||||
|
||||
detailResp, err := c.httpClient.Get(userDetailURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user detail: %w", err)
|
||||
}
|
||||
defer detailResp.Body.Close()
|
||||
|
||||
var UserDetailResp UserDetailResponse
|
||||
if err := json.NewDecoder(detailResp.Body).Decode(&UserDetailResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user detail response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("GetUserInfoByCode detail info", log.Any("resp", UserDetailResp))
|
||||
|
||||
if UserDetailResp.Errcode != 0 {
|
||||
return nil, fmt.Errorf("failed to get user detail: %s", UserDetailResp.Errmsg)
|
||||
}
|
||||
|
||||
return &UserDetailResp, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetDepartmentList(ctx context.Context) (*DepartmentListResponse, error) {
|
||||
|
||||
accessToken, err := c.GetAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("access_token", accessToken)
|
||||
|
||||
departmentListURL := fmt.Sprintf("%s?%s", DepartmentListURL, params.Encode())
|
||||
|
||||
resp, err := c.httpClient.Get(departmentListURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get department list: %w", err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
c.logger.Debug("GetDepartmentList raw resp:", log.Any("raw", string(rawBody)))
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
|
||||
|
||||
var departmentListResponse DepartmentListResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&departmentListResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode department list response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("GetDepartmentList resp:", log.Any("resp", departmentListResponse))
|
||||
|
||||
if departmentListResponse.Errcode != 0 {
|
||||
return nil, fmt.Errorf("failed to get user info: %s", departmentListResponse.Errmsg)
|
||||
}
|
||||
|
||||
return &departmentListResponse, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetUserList(ctx context.Context, deptID string) (*UserListResponse, error) {
|
||||
|
||||
accessToken, err := c.GetAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("access_token", accessToken)
|
||||
params.Set("department_id", deptID)
|
||||
|
||||
userListUrl := fmt.Sprintf("%s?%s", UserListUrl, params.Encode())
|
||||
|
||||
resp, err := c.httpClient.Get(userListUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user list: %w", err)
|
||||
}
|
||||
|
||||
rawBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read body: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("GetUserList raw resp:", log.Any("raw", string(rawBody)))
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(rawBody))
|
||||
|
||||
var userListResponse UserListResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userListResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user list response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("GetUserList resp:", log.Any("resp", userListResponse))
|
||||
|
||||
if userListResponse.Errcode != 0 {
|
||||
return nil, fmt.Errorf("failed to get user info: %s", userListResponse.Errmsg)
|
||||
}
|
||||
|
||||
return &userListResponse, nil
|
||||
}
|
||||
Reference in New Issue
Block a user