init push
This commit is contained in:
502
backend/pkg/bot/dingtalk/stream.go
Normal file
502
backend/pkg/bot/dingtalk/stream.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
|
||||
dingtalkcard_1_0 "github.com/alibabacloud-go/dingtalk/card_1_0"
|
||||
dingtalkoauth2_1_0 "github.com/alibabacloud-go/dingtalk/oauth2_1_0"
|
||||
util "github.com/alibabacloud-go/tea-utils/v2/service"
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
"github.com/google/uuid"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
"github.com/chaitin/panda-wiki/log"
|
||||
"github.com/chaitin/panda-wiki/pkg/bot"
|
||||
)
|
||||
|
||||
type DingTalkClient struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clientID string
|
||||
clientSecret string
|
||||
templateID string // 4d18414c-aabc-4ec8-9e67-4ceefeada72a.schema
|
||||
oauthClient *dingtalkoauth2_1_0.Client
|
||||
cardClient *dingtalkcard_1_0.Client
|
||||
getQA bot.GetQAFun
|
||||
logger *log.Logger
|
||||
tokenCache struct {
|
||||
accessToken string
|
||||
expireAt time.Time
|
||||
}
|
||||
tokenMutex sync.RWMutex
|
||||
messageMu sync.Mutex
|
||||
messageSeenAt map[string]messageMark
|
||||
messageTTL time.Duration
|
||||
nowFunc func() time.Time
|
||||
processMessageFn func(ctx context.Context, data *chatbot.BotCallbackDataModel) error
|
||||
}
|
||||
|
||||
type messageMark struct {
|
||||
seenAt time.Time
|
||||
inFlight bool
|
||||
}
|
||||
|
||||
func NewDingTalkClient(ctx context.Context, cancel context.CancelFunc, clientId, clientSecret, templateID string, logger *log.Logger, getQA bot.GetQAFun) (*DingTalkClient, error) {
|
||||
config := &openapi.Config{}
|
||||
config.Protocol = tea.String("https")
|
||||
config.RegionId = tea.String("central")
|
||||
oauthClient, err := dingtalkoauth2_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create oauth client: %w", err)
|
||||
}
|
||||
cardClient, err := dingtalkcard_1_0.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create card client: %w", err)
|
||||
}
|
||||
client := &DingTalkClient{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
clientID: clientId,
|
||||
clientSecret: clientSecret,
|
||||
templateID: templateID,
|
||||
oauthClient: oauthClient,
|
||||
cardClient: cardClient,
|
||||
getQA: getQA,
|
||||
logger: logger,
|
||||
messageSeenAt: make(map[string]messageMark),
|
||||
messageTTL: 5 * time.Minute,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
client.startMessageCleanup()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetAccessToken() (string, error) {
|
||||
c.tokenMutex.RLock()
|
||||
// TODO: use redis cache
|
||||
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
|
||||
token := c.tokenCache.accessToken
|
||||
c.tokenMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
c.tokenMutex.RUnlock()
|
||||
|
||||
c.tokenMutex.Lock()
|
||||
defer c.tokenMutex.Unlock()
|
||||
|
||||
if c.tokenCache.accessToken != "" && time.Now().Before(c.tokenCache.expireAt) {
|
||||
return c.tokenCache.accessToken, nil
|
||||
}
|
||||
|
||||
request := &dingtalkoauth2_1_0.GetAccessTokenRequest{
|
||||
AppKey: tea.String(c.clientID),
|
||||
AppSecret: tea.String(c.clientSecret),
|
||||
}
|
||||
response, tryErr := func() (_resp *dingtalkoauth2_1_0.GetAccessTokenResponse, _e error) {
|
||||
defer func() {
|
||||
if r := tea.Recover(recover()); r != nil {
|
||||
_e = r
|
||||
}
|
||||
}()
|
||||
_resp, _err := c.oauthClient.GetAccessToken(request)
|
||||
if _err != nil {
|
||||
return nil, _err
|
||||
}
|
||||
|
||||
return _resp, nil
|
||||
}()
|
||||
if tryErr != nil {
|
||||
return "", tryErr
|
||||
}
|
||||
accessToken := *response.Body.AccessToken
|
||||
c.logger.Info("get access token", log.String("access_token", accessToken), log.Int("expire_in", int(*response.Body.ExpireIn)))
|
||||
c.tokenCache.accessToken = accessToken
|
||||
c.tokenCache.expireAt = time.Now().Add(time.Duration(*response.Body.ExpireIn-300) * time.Second)
|
||||
|
||||
return c.tokenCache.accessToken, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) UpdateAIStreamCard(trackID, content string, isFinalize bool) error {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get access token while updating interactive card: %w", err)
|
||||
}
|
||||
|
||||
headers := &dingtalkcard_1_0.StreamingUpdateHeaders{
|
||||
XAcsDingtalkAccessToken: tea.String(accessToken),
|
||||
}
|
||||
request := &dingtalkcard_1_0.StreamingUpdateRequest{
|
||||
OutTrackId: tea.String(trackID),
|
||||
Guid: tea.String(uuid.New().String()),
|
||||
Key: tea.String("content"),
|
||||
Content: tea.String(content),
|
||||
IsFull: tea.Bool(true),
|
||||
IsFinalize: tea.Bool(isFinalize),
|
||||
IsError: tea.Bool(false),
|
||||
}
|
||||
_, err = c.cardClient.StreamingUpdateWithOptions(request, headers, &util.RuntimeOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update card: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) CreateAndDeliverCard(ctx context.Context, trackID string, data *chatbot.BotCallbackDataModel) error {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
|
||||
}
|
||||
|
||||
createAndDeliverHeaders := &dingtalkcard_1_0.CreateAndDeliverHeaders{}
|
||||
createAndDeliverHeaders.XAcsDingtalkAccessToken = tea.String(accessToken)
|
||||
|
||||
cardDataCardParamMap := map[string]*string{
|
||||
"content": tea.String(""),
|
||||
}
|
||||
cardData := &dingtalkcard_1_0.CreateAndDeliverRequestCardData{
|
||||
CardParamMap: cardDataCardParamMap,
|
||||
}
|
||||
|
||||
createAndDeliverRequest := &dingtalkcard_1_0.CreateAndDeliverRequest{
|
||||
CardTemplateId: tea.String(c.templateID),
|
||||
OutTrackId: tea.String(trackID),
|
||||
CardData: cardData,
|
||||
CallbackType: tea.String("STREAM"),
|
||||
ImGroupOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenSpaceModel{
|
||||
SupportForward: tea.Bool(true),
|
||||
},
|
||||
ImRobotOpenSpaceModel: &dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenSpaceModel{
|
||||
SupportForward: tea.Bool(true),
|
||||
},
|
||||
UserIdType: tea.Int32(1),
|
||||
}
|
||||
switch data.ConversationType {
|
||||
case "2": // 群聊
|
||||
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_GROUP", data.ConversationId)
|
||||
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
|
||||
createAndDeliverRequest.SetImGroupOpenDeliverModel(
|
||||
&dingtalkcard_1_0.CreateAndDeliverRequestImGroupOpenDeliverModel{
|
||||
RobotCode: tea.String(c.clientID),
|
||||
})
|
||||
case "1": // Im机器人单聊
|
||||
openSpaceId := fmt.Sprintf("dtv1.card//%s.%s", "IM_ROBOT", data.SenderStaffId)
|
||||
createAndDeliverRequest.SetOpenSpaceId(openSpaceId)
|
||||
createAndDeliverRequest.SetImRobotOpenDeliverModel(&dingtalkcard_1_0.CreateAndDeliverRequestImRobotOpenDeliverModel{
|
||||
SpaceType: tea.String("IM_GROUP"),
|
||||
})
|
||||
default:
|
||||
return fmt.Errorf("invalid conversation type: %s", data.ConversationType)
|
||||
}
|
||||
|
||||
_, err = c.cardClient.CreateAndDeliverWithOptions(createAndDeliverRequest, createAndDeliverHeaders, &util.RuntimeOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create and deliver card: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) startMessageCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.cleanupExpiredMessages()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) cleanupExpiredMessages() {
|
||||
now := c.nowFunc()
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
for msgID, mark := range c.messageSeenAt {
|
||||
if mark.inFlight {
|
||||
continue
|
||||
}
|
||||
if now.Sub(mark.seenAt) > c.messageTTL {
|
||||
delete(c.messageSeenAt, msgID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) tryMarkMessage(msgID string) bool {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
now := c.nowFunc()
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
if mark, ok := c.messageSeenAt[msgID]; ok {
|
||||
if mark.inFlight || now.Sub(mark.seenAt) <= c.messageTTL {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
c.messageSeenAt[msgID] = messageMark{
|
||||
seenAt: now,
|
||||
inFlight: true,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) markMessageCompleted(msgID string) {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
c.messageSeenAt[msgID] = messageMark{
|
||||
seenAt: c.nowFunc(),
|
||||
inFlight: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) clearMessageMark(msgID string) {
|
||||
if strings.TrimSpace(msgID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.messageMu.Lock()
|
||||
defer c.messageMu.Unlock()
|
||||
|
||||
delete(c.messageSeenAt, msgID)
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
c.logger.Info("dingtalk bot is disabled, ignoring message", log.String("client_id", c.clientID))
|
||||
return nil, nil
|
||||
default:
|
||||
}
|
||||
|
||||
if !c.tryMarkMessage(data.MsgId) {
|
||||
c.logger.Info("ignore duplicate dingtalk message", log.String("msg_id", data.MsgId))
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
processor := c.processMessageFn
|
||||
if processor == nil {
|
||||
processor = c.processMessage
|
||||
}
|
||||
|
||||
payload := *data
|
||||
go c.processMessageAsync(c.ctx, &payload, processor)
|
||||
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) processMessageAsync(ctx context.Context, data *chatbot.BotCallbackDataModel, processor func(context.Context, *chatbot.BotCallbackDataModel) error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
c.clearMessageMark(data.MsgId)
|
||||
c.logger.Error("process dingtalk message panicked", log.String("msg_id", data.MsgId), log.Any("panic", r))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := processor(ctx, data); err != nil {
|
||||
c.clearMessageMark(data.MsgId)
|
||||
c.logger.Error("process dingtalk message failed", log.String("msg_id", data.MsgId), log.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.markMessageCompleted(data.MsgId)
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) processMessage(ctx context.Context, data *chatbot.BotCallbackDataModel) error {
|
||||
question := data.Text.Content
|
||||
question = strings.TrimSpace(question)
|
||||
trackID := uuid.New().String()
|
||||
// conversation_type == 1 表示机器人单聊,==2 表示群聊中@机器人
|
||||
c.logger.Info("dingtalk client received message", log.String("question", question), log.String("track_id", trackID), log.String("conversation_type", data.ConversationType))
|
||||
// create and deliver card
|
||||
if err := c.CreateAndDeliverCard(ctx, trackID, data); err != nil {
|
||||
c.logger.Error("CreateAndDeliverCard", log.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
initialContent := fmt.Sprintf("**%s**\n\n%s", question, "稍等,让我想一想……")
|
||||
|
||||
if err := c.UpdateAIStreamCard(trackID, initialContent, false); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard", log.Error(err))
|
||||
return err
|
||||
}
|
||||
// 初始化 默认为空
|
||||
convInfo := &domain.ConversationInfo{
|
||||
UserInfo: domain.UserInfo{
|
||||
From: domain.MessageFromPrivate, // 默认是私聊
|
||||
},
|
||||
}
|
||||
// 之前创建并且发送卡片消息,获取用户基本信息
|
||||
userinfo, err := c.GetUserInfo(data.SenderStaffId)
|
||||
if err != nil {
|
||||
c.logger.Error("GetUserInfo failed", log.Error(err))
|
||||
} else {
|
||||
c.logger.Info("GetUserInfo success", log.Any("userinfo", userinfo))
|
||||
convInfo.UserInfo.UserID = userinfo.Result.Userid
|
||||
convInfo.UserInfo.NickName = userinfo.Result.Name
|
||||
convInfo.UserInfo.Avatar = userinfo.Result.Avatar
|
||||
convInfo.UserInfo.Email = userinfo.Result.Email
|
||||
}
|
||||
if data.ConversationType == "2" { // 群聊
|
||||
convInfo.UserInfo.From = domain.MessageFromGroup
|
||||
} else { // 单聊
|
||||
convInfo.UserInfo.From = domain.MessageFromPrivate
|
||||
}
|
||||
|
||||
contentCh, err := c.getQA(ctx, question, *convInfo, "")
|
||||
if err != nil {
|
||||
c.logger.Error("dingtalk client failed to get answer", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
|
||||
return fmt.Errorf("get answer failed: %w; update error card failed: %w", err, updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
updateTicker := time.NewTicker(1500 * time.Millisecond)
|
||||
defer updateTicker.Stop()
|
||||
|
||||
ans := fmt.Sprintf("**%s**\n\n", question)
|
||||
fullContent := fmt.Sprintf("**%s**\n\n", question)
|
||||
for {
|
||||
select {
|
||||
case content, ok := <-contentCh:
|
||||
if !ok {
|
||||
if err := c.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in contentCh failed", log.Error(updateErr))
|
||||
return fmt.Errorf("final update card failed: %w; fallback update failed: %w", err, updateErr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
fullContent += content
|
||||
case <-updateTicker.C:
|
||||
if fullContent == ans {
|
||||
continue
|
||||
}
|
||||
if err := c.UpdateAIStreamCard(trackID, fullContent, false); err != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in ticker", log.Error(err))
|
||||
if updateErr := c.UpdateAIStreamCard(trackID, "出错了,请稍后再试", true); updateErr != nil {
|
||||
c.logger.Error("UpdateInteractiveCard in ticker failed", log.Error(updateErr))
|
||||
return fmt.Errorf("stream update card failed: %w; fallback update failed: %w", err, updateErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) Start() error {
|
||||
cli := client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(
|
||||
c.clientID,
|
||||
c.clientSecret,
|
||||
)))
|
||||
cli.RegisterChatBotCallbackRouter(c.OnChatBotMessageReceived)
|
||||
if err := cli.Start(c.ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-c.ctx.Done()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) Stop() {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
// 钉钉的用户信息
|
||||
type UserDetailResponse struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Result UserDetails `json:"result"`
|
||||
}
|
||||
|
||||
type UserDetails struct {
|
||||
Unionid string `json:"unionid"`
|
||||
Userid string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Avatar string `json:"avatar"`
|
||||
Mobile string `json:"mobile"`
|
||||
Email string `json:"email"`
|
||||
Title string `json:"title"`
|
||||
Active bool `json:"active"`
|
||||
Admin bool `json:"admin"`
|
||||
Boss bool `json:"boss"`
|
||||
DeptIDList []int64 `json:"dept_id_list"`
|
||||
JobNumber string `json:"job_number"`
|
||||
HiredDate int64 `json:"hired_date"`
|
||||
ManagerUserid string `json:"manager_userid"`
|
||||
}
|
||||
|
||||
// 使用原始的http请求来获取用户的信息 - > 需要设置获取用户的权限功能:企业员工手机号信息和邮箱等个人信息、成员信息读权限
|
||||
func (c *DingTalkClient) GetUserInfo(userID string) (*UserDetailResponse, error) {
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token while creating and delivering card: %w", err)
|
||||
}
|
||||
// 1. 构建URL和请求体
|
||||
url := "https://oapi.dingtalk.com/topapi/v2/user/get"
|
||||
payload := map[string]string{"userid": userID, "language": "zh_CN"} // 默认是中文
|
||||
jsonPayload, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
query := req.URL.Query()
|
||||
query.Add("access_token", accessToken)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to get user info from dingtalk: %v", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 获取到用户信息
|
||||
c.logger.Info("Get user info from dingtalk success", log.Any("resp 原始的消息:", resp))
|
||||
|
||||
var result UserDetailResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
c.logger.Error("Failed to unmarshal user info response: %v", log.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.ErrCode != 0 {
|
||||
c.logger.Error("Failed to get result info", log.Any("ErrCode", result.ErrCode), log.String("ErrMsg", result.ErrMsg))
|
||||
return nil, fmt.Errorf("result.ErrCode:%d", result.ErrCode)
|
||||
}
|
||||
// success
|
||||
c.logger.Info("Get user info from dingtalk success", log.Any("userinfo:", result))
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
263
backend/pkg/bot/dingtalk/stream_test.go
Normal file
263
backend/pkg/bot/dingtalk/stream_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package dingtalk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
pwlog "github.com/chaitin/panda-wiki/log"
|
||||
)
|
||||
|
||||
func newTestLogger() *pwlog.Logger {
|
||||
return &pwlog.Logger{
|
||||
Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestDingTalkClient(t *testing.T) *DingTalkClient {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := NewDingTalkClient(
|
||||
ctx,
|
||||
cancel,
|
||||
"client-id",
|
||||
"client-secret",
|
||||
"template-id",
|
||||
newTestLogger(),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
client.messageTTL = time.Minute
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestTryMarkMessageDeduplicatesWithinTTL(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
now := time.Now()
|
||||
client.nowFunc = func() time.Time {
|
||||
return now
|
||||
}
|
||||
|
||||
require.True(t, client.tryMarkMessage("msg-1"))
|
||||
require.False(t, client.tryMarkMessage("msg-1"))
|
||||
|
||||
client.markMessageCompleted("msg-1")
|
||||
require.False(t, client.tryMarkMessage("msg-1"))
|
||||
|
||||
now = now.Add(client.messageTTL + time.Second)
|
||||
|
||||
require.True(t, client.tryMarkMessage("msg-1"))
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedIgnoresDuplicateMsgID(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
processed := make(chan struct{}, 2)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
processed <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-1",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
t.Fatal("expected duplicate message to be ignored")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedReturnsBeforeProcessingCompletes(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
started := make(chan struct{})
|
||||
unblock := make(chan struct{})
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
close(started)
|
||||
<-unblock
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, _ = client.OnChatBotMessageReceived(context.Background(), &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-2",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "slow question",
|
||||
},
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected callback to return before background processing finishes")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected background processing to start")
|
||||
}
|
||||
|
||||
close(unblock)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedAllowsRetryAfterProcessingError(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
attempts := make(chan struct{}, 4)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
attempts <- struct{}{}
|
||||
return assert.AnError
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-retry",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "retry please",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, callErr)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedRecoversBackgroundPanic(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
attempts := make(chan struct{}, 4)
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
attempts <- struct{}{}
|
||||
panic("boom")
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-panic",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "panic please",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected background processing to start")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, callErr := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, callErr)
|
||||
|
||||
select {
|
||||
case <-attempts:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestOnChatBotMessageReceivedKeepsInFlightMessageMarkedPastTTL(t *testing.T) {
|
||||
client := newTestDingTalkClient(t)
|
||||
|
||||
now := time.Now()
|
||||
client.nowFunc = func() time.Time {
|
||||
return now
|
||||
}
|
||||
|
||||
processed := make(chan struct{}, 2)
|
||||
unblock := make(chan struct{})
|
||||
client.processMessageFn = func(context.Context, *chatbot.BotCallbackDataModel) error {
|
||||
processed <- struct{}{}
|
||||
<-unblock
|
||||
return nil
|
||||
}
|
||||
|
||||
data := &chatbot.BotCallbackDataModel{
|
||||
MsgId: "msg-inflight",
|
||||
Text: chatbot.BotCallbackDataTextModel{
|
||||
Content: "long running question",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected first message to be processed")
|
||||
}
|
||||
|
||||
now = now.Add(client.messageTTL + time.Second)
|
||||
client.cleanupExpiredMessages()
|
||||
|
||||
resp, err = client.OnChatBotMessageReceived(context.Background(), data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(""), resp)
|
||||
|
||||
select {
|
||||
case <-processed:
|
||||
t.Fatal("expected in-flight duplicate message to be ignored after ttl cleanup")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(unblock)
|
||||
}
|
||||
Reference in New Issue
Block a user