package usecase import ( "context" "fmt" "slices" "strings" "time" modelkit "github.com/chaitin/ModelKit/v2/usecase" "github.com/cloudwego/eino/schema" "github.com/google/uuid" "github.com/samber/lo" "gorm.io/gorm" "github.com/chaitin/panda-wiki/consts" "github.com/chaitin/panda-wiki/domain" "github.com/chaitin/panda-wiki/log" "github.com/chaitin/panda-wiki/repo/pg" "github.com/chaitin/panda-wiki/utils" ) type ChatUsecase struct { llmUsecase *LLMUsecase conversationUsecase *ConversationUsecase modelUsecase *ModelUsecase appRepo *pg.AppRepository blockWordRepo *pg.BlockWordRepo kbRepo *pg.KnowledgeBaseRepository nodeRepo *pg.NodeRepository AuthRepo *pg.AuthRepo logger *log.Logger modelkit *modelkit.ModelKit } func NewChatUsecase(llmUsecase *LLMUsecase, kbRepo *pg.KnowledgeBaseRepository, conversationUsecase *ConversationUsecase, modelUsecase *ModelUsecase, appRepo *pg.AppRepository, blockWordRepo *pg.BlockWordRepo, nodeRepo *pg.NodeRepository, authRepo *pg.AuthRepo, logger *log.Logger) (*ChatUsecase, error) { modelkit := modelkit.NewModelKit(logger.Logger) u := &ChatUsecase{ llmUsecase: llmUsecase, conversationUsecase: conversationUsecase, modelUsecase: modelUsecase, appRepo: appRepo, blockWordRepo: blockWordRepo, kbRepo: kbRepo, nodeRepo: nodeRepo, AuthRepo: authRepo, logger: logger.WithModule("usecase.chat"), modelkit: modelkit, } if err := u.initDFA(); err != nil { u.logger.Error("failed to init dfa", log.Error(err)) return nil, err } return u, nil } func (u *ChatUsecase) initDFA() error { ctx := context.Background() kbList, err := u.kbRepo.GetKnowledgeBaseList(context.Background()) if err != nil { return fmt.Errorf("failed to get kb list: %w", err) } for _, kb := range kbList { if kb != nil { words, err := u.blockWordRepo.GetBlockWords(ctx, kb.ID) if err != nil { u.logger.Error("failed to get words", log.Error(err), log.String("kb_id", kb.ID)) return fmt.Errorf("failed to get words for kb: %w", err) } if len(words) > 0 { utils.InitDFA(kb.ID, words) } } } return nil } func (u *ChatUsecase) Chat(ctx context.Context, req *domain.ChatRequest) (<-chan domain.SSEEvent, error) { eventCh := make(chan domain.SSEEvent, 100) go func() { defer close(eventCh) // 1. get app detail and validate app app, err := u.appRepo.GetOrCreateAppByKBIDAndType(ctx, req.KBID, req.AppType) if err != nil { eventCh <- domain.SSEEvent{Type: "error", Content: "app not found"} return } req.KBID = app.KBID req.AppID = app.ID req.AppType = app.Type // 2. get model and validate model model, err := u.modelUsecase.GetChatModel(ctx) if err != nil { if err == gorm.ErrRecordNotFound { eventCh <- domain.SSEEvent{Type: "error", Content: "请前往管理后台,点击右上角的“系统设置”配置推理大模型。"} } else { eventCh <- domain.SSEEvent{Type: "error", Content: "模型获取失败"} } return } req.ModelInfo = model // 3. conversation management if req.AppType == domain.AppTypeWechatServiceBot || req.AppType == domain.AppTypeWechatBot || req.AppType == domain.AppTypeWecomAIBot { // wechat service has its own id nonce := uuid.New().String() eventCh <- domain.SSEEvent{Type: "conversation_id", Content: req.ConversationID} eventCh <- domain.SSEEvent{Type: "nonce", Content: nonce} err = u.conversationUsecase.CreateConversation(ctx, &domain.Conversation{ ID: req.ConversationID, Nonce: nonce, AppID: req.AppID, KBID: req.KBID, Subject: req.Message, RemoteIP: req.RemoteIP, Info: req.Info, CreatedAt: time.Now(), }) if err != nil { u.logger.Error("failed to create chat conversation", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to create chat conversation"} return } } else if req.ConversationID == "" { id, err := uuid.NewV7() if err != nil { u.logger.Error("failed to generate conversation uuid", log.Error(err)) id = uuid.New() } conversationID := id.String() req.ConversationID = conversationID nonce := uuid.New().String() eventCh <- domain.SSEEvent{Type: "conversation_id", Content: conversationID} eventCh <- domain.SSEEvent{Type: "nonce", Content: nonce} err = u.conversationUsecase.CreateConversation(ctx, &domain.Conversation{ ID: conversationID, Nonce: nonce, AppID: req.AppID, KBID: req.KBID, Subject: req.Message, RemoteIP: req.RemoteIP, Info: req.Info, CreatedAt: time.Now(), }) if err != nil { u.logger.Error("failed to create chat conversation", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to create chat conversation"} return } } else { if req.Nonce == "" { eventCh <- domain.SSEEvent{Type: "error", Content: "nonce is required"} return } err := u.conversationUsecase.ValidateConversationNonce(ctx, req.ConversationID, req.Nonce) if err != nil { u.logger.Error("failed to validate chat conversation nonce", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "validate chat conversation nonce failed"} return } } messageId := uuid.New().String() eventCh <- domain.SSEEvent{Type: "message_id", Content: messageId} userMessageId := uuid.New().String() // save user question to conversation message if err := u.conversationUsecase.CreateChatConversationMessage(ctx, req.KBID, &domain.ConversationMessage{ ID: userMessageId, ConversationID: req.ConversationID, KBID: req.KBID, AppID: req.AppID, Role: schema.User, Content: req.Message, ImagePaths: req.ImagePaths, RemoteIP: req.RemoteIP, }); err != nil { u.logger.Error("failed to save user question to conversation message", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save user question to conversation message"} return } // extra1. if user set question block words then check it blockWords, err := u.blockWordRepo.GetBlockWords(ctx, req.KBID) if err != nil { u.logger.Error("failed to get question block words", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get question block words"} return } if len(blockWords) > 0 { // check --> filter questionFilter := utils.GetDFA(req.KBID) if err := questionFilter.DFA.Check(req.Message); err != nil { // exist then return err answer := "**您的问题包含敏感词, AI 无法回答您的问题。**" eventCh <- domain.SSEEvent{Type: "error", Content: answer} // save ai answer and set it err if err := u.conversationUsecase.CreateChatConversationMessage(context.Background(), req.KBID, &domain.ConversationMessage{ ID: messageId, ConversationID: req.ConversationID, KBID: req.KBID, AppID: req.AppID, Role: schema.Assistant, Content: answer, Provider: req.ModelInfo.Provider, Model: string(req.ModelInfo.Model), RemoteIP: req.RemoteIP, ParentID: userMessageId, }); err != nil { u.logger.Error("failed to save assistant answer to conversation message", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save assistant answer to conversation message"} return } return } } if req.Info.UserInfo.AuthUserID == 0 { auth, _ := u.AuthRepo.GetAuthBySourceType(ctx, req.AppType.ToSourceType()) if auth != nil { req.Info.UserInfo.AuthUserID = auth.ID } } groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.Info.UserInfo.AuthUserID) if err != nil { u.logger.Error("failed to get auth groupIds", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get auth groupIds"} return } messages, rankedNodes, err := u.llmUsecase.BuildConversationMessageWithRAG(ctx, req.ConversationID, req.KBID, groupIds, req.Prompt) if err != nil { u.logger.Error("build messages failed", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: err.Error()} return } u.logger.Debug("message:", log.Any("schema", messages)) for _, node := range rankedNodes { chunkResult := domain.NodeContentChunkSSE{ NodeID: node.NodeID, Name: node.NodeName, Summary: node.NodeSummary, NodePathNames: node.NodePathNames, } eventCh <- domain.SSEEvent{Type: "chunk_result", ChunkResult: &chunkResult} } // 5. LLM inference (streaming callback), message storage, token statistics answer := "" usage := schema.TokenUsage{} modelkitModel, err := req.ModelInfo.ToModelkitModel() if err != nil { u.logger.Error("failed to convert model to modelkit model", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to convert model to modelkit model"} return } chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel) if err != nil { u.logger.Error("failed to get chat model", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get chat model"} return } // get words onChunkAC, flushBuffer := u.CreateAcOnChunk(ctx, req.KBID, &answer, eventCh, blockWords) chatErr := u.llmUsecase.ChatWithAgent(ctx, chatModel, messages, &usage, onChunkAC) // 处理缓冲区中剩余的内容 if flushBuffer != nil { flushBuffer(ctx, "data") } // save assistant answer to conversation message if err := u.conversationUsecase.CreateChatConversationMessage(ctx, req.KBID, &domain.ConversationMessage{ ID: messageId, ConversationID: req.ConversationID, KBID: req.KBID, AppID: req.AppID, Role: schema.Assistant, Content: answer, Provider: req.ModelInfo.Provider, Model: string(req.ModelInfo.Model), PromptTokens: usage.PromptTokens, CompletionTokens: usage.CompletionTokens, TotalTokens: usage.TotalTokens, RemoteIP: req.RemoteIP, ParentID: userMessageId, }); err != nil { u.logger.Error("failed to save assistant answer to conversation message", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save assistant answer to conversation message"} return } // update model usage if err := u.modelUsecase.UpdateUsage(ctx, req.ModelInfo.ID, &usage); err != nil { u.logger.Error("failed to update model usage", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to update model usage"} return } if chatErr != nil { u.logger.Error("对话失败", log.Error(chatErr)) eventCh <- domain.SSEEvent{Type: "error", Content: "对话失败,请稍后再试"} return } eventCh <- domain.SSEEvent{Type: "done"} }() return eventCh, nil } func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRagOnlyRequest) (<-chan domain.SSEEvent, error) { eventCh := make(chan domain.SSEEvent, 100) go func() { defer close(eventCh) // extra1. if user set question block words then check it blockWords, err := u.blockWordRepo.GetBlockWords(ctx, req.KBID) if err != nil { u.logger.Error("failed to get question block words", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get question block words"} return } if len(blockWords) > 0 { // check --> filter questionFilter := utils.GetDFA(req.KBID) if err := questionFilter.DFA.Check(req.Message); err != nil { // exist then return err answer := "**您的问题包含敏感词, AI 无法回答您的问题。**" eventCh <- domain.SSEEvent{Type: "error", Content: answer} return } } if req.UserInfo.AuthUserID == 0 { auth, _ := u.AuthRepo.GetAuthBySourceType(ctx, req.AppType.ToSourceType()) if auth != nil { req.UserInfo.AuthUserID = auth.ID } } groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.UserInfo.AuthUserID) if err != nil { u.logger.Error("failed to get auth groupIds", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get auth groupIds"} return } // retrieve documents kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, req.KBID) if err != nil { u.logger.Error("failed to get kb", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get kb"} return } _, rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, GetRankNodesRequest{ DatasetID: kb.DatasetID, Question: req.Message, GroupIDs: groupIds, HistoryMessages: nil, SimilarityThreshold: 0, MaxChunksPerDoc: 1, }) if err != nil { u.logger.Error("failed to get rank nodes", log.Error(err)) eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get rank nodes"} return } documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL) u.logger.Debug("documents", log.String("documents", documents)) // send only the documents part eventCh <- domain.SSEEvent{Type: "data", Content: documents} eventCh <- domain.SSEEvent{Type: "done"} }() return eventCh, nil } func (u *ChatUsecase) CreateAcOnChunk(ctx context.Context, kbID string, answer *string, eventCh chan<- domain.SSEEvent, blockWords []string) (func(ctx context.Context, dataType, chunk string) error, func(ctx context.Context, dataType string)) { var buffer strings.Builder // 如果用户没有设置敏感词,不需要处理 if len(blockWords) == 0 { onChunk := func(ctx context.Context, dataType, chunk string) error { *answer += chunk eventCh <- domain.SSEEvent{Type: dataType, Content: chunk} return nil } return onChunk, nil } // get filter --> exist filter := utils.GetDFA(kbID) onChunk := func(ctx context.Context, dataType, chunk string) error { buffer.WriteString(chunk) // 将缓冲区内容转换为 rune 切片,以便正确处理多字节字符 bufferRunes := []rune(buffer.String()) // 基于 rune 长度与 bufferSize 进行比较,确保正确处理多字节字符 if len(bufferRunes) >= filter.BuffSize { fullContent := buffer.String() // get buffer string // 直接处理完整内容 processedContent := u.replaceWithSimpleString(fullContent, filter.DFA) processedRunes := []rune(processedContent) // 输出前面的部分,保留后面bufferSize - 1个rune outputPart := string(processedRunes[:len(processedRunes)-filter.BuffSize+1]) *answer += outputPart eventCh <- domain.SSEEvent{Type: dataType, Content: outputPart} // 清空缓冲区 newBufferContent := string(processedRunes[len(processedRunes)-filter.BuffSize+1:]) buffer.Reset() buffer.WriteString(newBufferContent) } return nil } flushBuffer := func(ctx context.Context, dataType string) { //小于bufferSize的内容 bufferRunes := []rune(buffer.String()) if len(bufferRunes) > 0 { fullContent := buffer.String() processedContent := u.replaceWithSimpleString(fullContent, filter.DFA) *answer += processedContent eventCh <- domain.SSEEvent{Type: dataType, Content: processedContent} } } return onChunk, flushBuffer } // replaceWithSimpleString func (u *ChatUsecase) replaceWithSimpleString(content string, filter *utils.DFA) string { r1 := filter.Filter(content) return r1 } func (u *ChatUsecase) Search(ctx context.Context, req *domain.ChatSearchReq) (*domain.ChatSearchResp, error) { groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.AuthUserID) if err != nil { return nil, err } kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, req.KBID) if err != nil { return nil, err } _, rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, GetRankNodesRequest{ DatasetID: kb.DatasetID, Question: req.Message, GroupIDs: groupIds, SimilarityThreshold: 0.2, HistoryMessages: nil, }) if err != nil { return nil, err } // Get node IDs from ranked nodes for permission check nodeIDs := lo.Map(rankedNodes, func(node *domain.RankedNodeChunks, _ int) string { return node.NodeID }) // Get nodes with permissions nodesMap, err := u.nodeRepo.GetNodesByIDs(ctx, nodeIDs) if err != nil { return nil, err } // Get user's visitable node IDs (for partial permission check) userGroupIds := lo.Map(groupIds, func(id int, _ int) uint { return uint(id) }) visitableNodeGroups, err := u.nodeRepo.GetNodeGroupsByGroupIdsPerm(ctx, userGroupIds, consts.NodePermNameVisitable) if err != nil { return nil, err } visitableNodeIds := lo.Map(visitableNodeGroups, func(v domain.NodeAuthGroup, _ int) string { return v.NodeID }) resp := domain.ChatSearchResp{} for _, node := range rankedNodes { // Check visitable permission if nodeInfo, ok := nodesMap[node.NodeID]; ok { switch nodeInfo.Permissions.Visitable { case consts.NodeAccessPermClosed: // Skip nodes with closed visitable permission continue case consts.NodeAccessPermPartial: // Skip if user doesn't have visitable permission for this node if !slices.Contains(visitableNodeIds, node.NodeID) { continue } } } chunkResult := domain.NodeContentChunkSSE{ NodeID: node.NodeID, Name: node.NodeName, Summary: node.NodeSummary, Emoji: node.NodeEmoji, NodePathNames: node.NodePathNames, } resp.NodeResult = append(resp.NodeResult, chunkResult) } return &resp, nil }