Files
YouduWiki/backend/usecase/llm.go
2026-05-21 19:52:45 +08:00

528 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package usecase
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
modelkit "github.com/chaitin/ModelKit/v2/usecase"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/pg"
"github.com/chaitin/panda-wiki/store/rag"
"github.com/chaitin/panda-wiki/utils"
)
// LLMUsecase groups the LLM-backed operations of this package: building
// RAG-augmented conversations, streaming chat completions and summarising
// node content.
type LLMUsecase struct {
	// rag is queried for the chunks related to a question (see GetRankNodes).
	rag rag.RAGService
	// conversationRepo loads the stored messages of a conversation.
	conversationRepo *pg.ConversationRepository
	// kbRepo loads the knowledge base a conversation runs against.
	kbRepo *pg.KnowledgeBaseRepository
	// nodeRepo resolves RAG document IDs to node releases.
	nodeRepo *pg.NodeRepository
	// modelRepo is injected by NewLLMUsecase; it is not used in this part of the file.
	modelRepo *pg.ModelRepository
	// promptRepo provides the per-knowledge-base system and summary prompts.
	promptRepo *pg.PromptRepo
	// config is injected by NewLLMUsecase; it is not used in this part of the file.
	config *config.Config
	// logger is scoped to the "usecase.llm" module by NewLLMUsecase.
	logger *log.Logger
	// modelkit builds chat model clients from stored model definitions.
	modelkit *modelkit.ModelKit
}
// Limits applied when summarising a document (see SummaryNode and
// StreamSummaryNode).
const (
	summaryChunkTokenLimit = 30720 // max tokens per chunk (30 * 1024 tokens, not bytes)
	summaryMaxChunks = 4 // max chunks summarised per document; chunks beyond this are dropped
)
// NewLLMUsecase wires an LLMUsecase from its dependencies.
//
// As a side effect it installs utils.Localloader as the process-wide tiktoken
// BPE loader. The ModelKit instance is created from the raw logger, while the
// usecase itself keeps a logger scoped to the "usecase.llm" module.
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
	tiktoken.SetBpeLoader(&utils.Localloader{})

	// Named "kit" so the local does not shadow the imported modelkit package.
	kit := modelkit.NewModelKit(logger.Logger)
	scopedLogger := logger.WithModule("usecase.llm")

	usecase := &LLMUsecase{
		rag:              rag,
		conversationRepo: conversationRepo,
		kbRepo:           kbRepo,
		nodeRepo:         nodeRepo,
		modelRepo:        modelRepo,
		promptRepo:       promptRepo,
		config:           config,
		logger:           scopedLogger,
		modelkit:         kit,
	}
	return usecase
}
// BuildConversationMessageWithRAG assembles the message list to send to the
// chat model for a conversation, enriched with retrieved knowledge-base
// documents.
//
// The stored conversation is loaded and reduced to its user and assistant
// messages (other roles are skipped; image paths of user messages are appended
// as markdown images). The content of the last of those messages is used as
// the question and the preceding ones as history for retrieval.
//
// When systemPrompt is empty, the prompt configured for the knowledge base is
// used, falling back to domain.SystemDefaultPrompt when none is configured or
// when it cannot be loaded.
//
// The returned messages are: the system message, the history, then the
// formatted user question. The returned ranked nodes are the documents that
// were injected into the prompt. If the conversation contains no user or
// assistant message, empty (non-nil) slices are returned without error.
func (u *LLMUsecase) BuildConversationMessageWithRAG(
	ctx context.Context,
	conversationID string,
	kbID string,
	groupIDs []int,
	systemPrompt string,
) ([]*schema.Message, []*domain.RankedNodeChunks, error) {
	messages := make([]*schema.Message, 0)
	rankedNodes := make([]*domain.RankedNodeChunks, 0)

	msgs, err := u.conversationRepo.GetConversationMessagesByID(ctx, conversationID)
	if err != nil {
		u.logger.Error("get conversation messages failed", log.Error(err))
		return nil, nil, errors.New("get conversation messages failed")
	}

	historyMessages := make([]*schema.Message, 0, len(msgs))
	for _, msg := range msgs {
		switch msg.Role {
		case schema.Assistant:
			historyMessages = append(historyMessages, schema.AssistantMessage(msg.Content, nil))
		case schema.User:
			content := u.formatMessageWithImages(msg.Content, msg.ImagePaths)
			historyMessages = append(historyMessages, schema.UserMessage(content))
		default:
			continue
		}
	}
	if len(historyMessages) == 0 {
		return messages, rankedNodes, nil
	}

	// The last message is the question being asked; everything before it is history.
	question := historyMessages[len(historyMessages)-1].Content
	previous := historyMessages[:len(historyMessages)-1]

	if systemPrompt == "" {
		// Start from the default so that a failed lookup never leaves the
		// model with an empty system message.
		systemPrompt = domain.SystemDefaultPrompt
		if settingPrompt, err := u.promptRepo.GetPromptContent(ctx, kbID); err != nil {
			u.logger.Error("get prompt from settings failed", log.Error(err))
		} else if settingPrompt != "" {
			systemPrompt = settingPrompt
		}
	}
	template := prompt.FromMessages(schema.GoTemplate,
		schema.SystemMessage(systemPrompt),
		schema.UserMessage(domain.UserQuestionFormatter),
	)

	kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, kbID)
	if err != nil {
		u.logger.Error("get kb failed", log.Error(err))
		return nil, nil, errors.New("get kb failed")
	}

	rewrittenQuery, rankedNodes, err := u.GetRankNodes(ctx, GetRankNodesRequest{
		DatasetID:           kb.DatasetID,
		Question:            question,
		GroupIDs:            groupIDs,
		SimilarityThreshold: 0.2,
		HistoryMessages:     previous,
	})
	if err != nil {
		u.logger.Error("get rank nodes failed", log.Error(err))
		return nil, nil, errors.New("get rank nodes failed")
	}

	documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL)
	u.logger.Debug("documents", log.String("documents", documents))

	formattedMessages, err := template.Format(ctx, map[string]any{
		"CurrentDate": time.Now().Format("2006-01-02"),
		"Question":    rewrittenQuery,
		"Documents":   documents,
	})
	if err != nil {
		u.logger.Error("format messages failed", log.Error(err))
		return nil, nil, errors.New("format messages failed")
	}

	// Splice the history between the system message (index 0) and the
	// formatted user question.
	messages = slices.Insert(formattedMessages, 1, previous...)
	return messages, rankedNodes, nil
}
// ChatWithAgent streams a chat completion from chatModel and forwards every
// piece of output to onChunk using the data type "data".
//
// Reasoning content, as reported by deepseek.GetReasoningContent, is wrapped
// in think tags: "<think>" is prepended to the first reasoning chunk and
// "</think>\n" is prepended to the first regular chunk that follows reasoning.
//
// The latest token usage reported by any chunk of the stream is copied into
// usage; a nil usage is allowed and simply skips that bookkeeping.
//
// It returns an error if the stream cannot be opened, if receiving fails, or
// if onChunk returns an error. The stream is closed on every return path.
func (u *LLMUsecase) ChatWithAgent(
	ctx context.Context,
	chatModel model.BaseChatModel,
	messages []*schema.Message,
	usage *schema.TokenUsage,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	resp, err := chatModel.Stream(ctx, messages)
	if err != nil {
		return fmt.Errorf("stream failed: %w", err)
	}
	// Release the stream even when we stop early because Recv or onChunk failed.
	defer resp.Close()

	firstReasoning := false
	firstData := false
	for {
		msg, err := resp.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("recv failed: %w", err)
		}
		if msg == nil {
			continue
		}

		// Capture usage before branching so that it is not lost on chunks
		// that take one of the early paths below. ResponseMeta is optional.
		if usage != nil && msg.ResponseMeta != nil && msg.ResponseMeta.Usage != nil {
			*usage = *msg.ResponseMeta.Usage
		}

		if reasoning, ok := deepseek.GetReasoningContent(msg); ok {
			if !firstReasoning {
				firstReasoning = true
				reasoning = "<think>" + reasoning
			}
			if err := onChunk(ctx, "data", reasoning); err != nil {
				return fmt.Errorf("on chunk reasoning: %w", err)
			}
			continue
		}

		content := msg.Content
		if firstReasoning && !firstData {
			// First regular chunk after reasoning: close the think block.
			firstData = true
			content = "</think>\n" + content
		}
		if err := onChunk(ctx, "data", content); err != nil {
			return fmt.Errorf("on chunk data: %w", err)
		}
	}
	return nil
}
// Generate performs a single non-streaming completion on chatModel and
// returns the content of the reply. A failure of the model call is wrapped
// with the prefix "generate failed".
func (u *LLMUsecase) Generate(
	ctx context.Context,
	chatModel model.BaseChatModel,
	messages []*schema.Message,
) (string, error) {
	reply, genErr := chatModel.Generate(ctx, messages)
	if genErr == nil {
		return reply.Content, nil
	}
	return "", fmt.Errorf("generate failed: %w", genErr)
}
// SummaryNode produces a summary of a document using the given model.
//
// The content is split into chunks of at most summaryChunkTokenLimit tokens;
// only the first summaryMaxChunks chunks are used. Each chunk is summarised
// separately, and chunks whose summary fails or comes back empty are logged
// and skipped. With a single usable summary that summary is returned;
// otherwise the partial summaries are joined and summarised once more.
//
// If that final pass fails, the joined partial summaries are returned
// instead, cut to roughly 500 bytes (never in the middle of a character) and
// suffixed with "...".
//
// An error is returned when the model cannot be built, the content cannot be
// split, or no chunk yields a summary.
func (u *LLMUsecase) SummaryNode(ctx context.Context, kbID string, model *domain.Model, name, content string) (string, error) {
	modelkitModel, err := model.ToModelkitModel()
	if err != nil {
		return "", err
	}
	chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
	if err != nil {
		return "", err
	}

	chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
	if err != nil {
		return "", err
	}
	if len(chunks) > summaryMaxChunks {
		u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
		chunks = chunks[:summaryMaxChunks]
	}

	summaries := make([]string, 0, len(chunks))
	for idx, chunk := range chunks {
		summary, err := u.requestSummary(ctx, kbID, chatModel, name, chunk)
		if err != nil {
			u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
			continue
		}
		if summary == "" {
			u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
			continue
		}
		summaries = append(summaries, summary)
	}
	if len(summaries) == 0 {
		return "", fmt.Errorf("failed to generate summary for document %s", name)
	}
	if len(summaries) == 1 {
		return summaries[0], nil
	}

	// Join all summaries and generate final summary
	joined := strings.Join(summaries, "\n\n")
	finalSummary, err := u.requestSummary(ctx, kbID, chatModel, name, joined)
	if err != nil {
		u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
		// Fallback: return the joined summaries directly
		if len(joined) > 500 {
			// A byte slice can end inside a multi-byte character; drop the
			// broken tail so the result is still valid UTF-8.
			return strings.ToValidUTF8(joined[:500], "") + "...", nil
		}
		return joined, nil
	}
	return finalSummary, nil
}
// StreamSummaryNode is the streaming counterpart of SummaryNode: the summary
// is delivered through onChunk (data type "data") instead of being returned.
//
// The content is split into chunks of at most summaryChunkTokenLimit tokens;
// only the first summaryMaxChunks chunks are used. A single chunk is
// summarised with a streamed request. With several chunks, each one is first
// summarised with a blocking request (failed or empty ones are logged and
// skipped), then:
//   - one usable summary: it is emitted as a single chunk;
//   - several: they are joined and a streamed request summarises the result.
//
// If that final streamed request fails, the joined partial summaries are
// emitted instead, cut to roughly 500 bytes (never in the middle of a
// character) and suffixed with "...". Output already streamed before the
// failure is not retracted.
//
// An error is returned when the model cannot be built, the content cannot be
// split, no chunk yields a summary, or onChunk fails.
func (u *LLMUsecase) StreamSummaryNode(
	ctx context.Context,
	kbID string,
	model *domain.Model,
	name, content string,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	modelkitModel, err := model.ToModelkitModel()
	if err != nil {
		return err
	}
	chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
	if err != nil {
		return err
	}

	chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
	if err != nil {
		return err
	}
	if len(chunks) > summaryMaxChunks {
		u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
		chunks = chunks[:summaryMaxChunks]
	}
	if len(chunks) == 1 {
		return u.streamSummary(ctx, kbID, chatModel, name, chunks[0], onChunk)
	}

	summaries := make([]string, 0, len(chunks))
	for idx, chunk := range chunks {
		summary, summaryErr := u.requestSummary(ctx, kbID, chatModel, name, chunk)
		if summaryErr != nil {
			u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(summaryErr))
			continue
		}
		if summary == "" {
			u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
			continue
		}
		summaries = append(summaries, summary)
	}
	if len(summaries) == 0 {
		return fmt.Errorf("failed to generate summary for document %s", name)
	}
	if len(summaries) == 1 {
		if err := onChunk(ctx, "data", summaries[0]); err != nil {
			return fmt.Errorf("on chunk data: %w", err)
		}
		return nil
	}

	joined := strings.Join(summaries, "\n\n")
	if err := u.streamSummary(ctx, kbID, chatModel, name, joined, onChunk); err != nil {
		u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
		if len(joined) > 500 {
			// A byte slice can end inside a multi-byte character; drop the
			// broken tail so the emitted text is still valid UTF-8.
			joined = strings.ToValidUTF8(joined[:500], "") + "..."
		}
		if chunkErr := onChunk(ctx, "data", joined); chunkErr != nil {
			return fmt.Errorf("on chunk data: %w", chunkErr)
		}
	}
	return nil
}
// trimThinking removes a leading "<think>…</think>" section from summary and
// returns what follows it with surrounding whitespace trimmed. The input is
// returned unchanged when it does not start with "<think>" or when the block
// is never closed.
func (u *LLMUsecase) trimThinking(summary string) string {
	if strings.HasPrefix(summary, "<think>") {
		if _, answer, closed := strings.Cut(summary, "</think>"); closed {
			return strings.TrimSpace(answer)
		}
	}
	return summary
}
// requestSummary asks chatModel, in a single blocking call, to summarise
// content for the document called name. The knowledge base's summary prompt
// is sent as the system message. Any leading think block is removed from the
// reply and the result is whitespace-trimmed.
func (u *LLMUsecase) requestSummary(ctx context.Context, kbID string, chatModel model.BaseChatModel, name, content string) (string, error) {
	systemText, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
	if err != nil {
		return "", err
	}

	userText := fmt.Sprintf("文档名称:%s\n文档内容%s", name, content)
	conversation := []*schema.Message{
		schema.SystemMessage(systemText),
		schema.UserMessage(userText),
	}

	raw, err := u.Generate(ctx, chatModel, conversation)
	if err != nil {
		return "", err
	}
	withoutThinking := u.trimThinking(raw)
	return strings.TrimSpace(withoutThinking), nil
}
// streamSummary asks chatModel to summarise content for the document called
// name and streams the reply to onChunk. The knowledge base's summary prompt
// is sent as the system message. Chunks of type "data" pass through a
// thinkingStreamFilter first, and chunks that the filter reduces to the empty
// string are not forwarded; chunks of any other type are forwarded untouched.
func (u *LLMUsecase) streamSummary(
	ctx context.Context,
	kbID string,
	chatModel model.BaseChatModel,
	name, content string,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	systemText, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
	if err != nil {
		return err
	}

	conversation := []*schema.Message{
		schema.SystemMessage(systemText),
		schema.UserMessage(fmt.Sprintf("文档名称:%s\n文档内容%s", name, content)),
	}

	thinking := newThinkingStreamFilter()
	forward := func(ctx context.Context, dataType, chunk string) error {
		if dataType == "data" {
			chunk = thinking.Append(chunk)
			if chunk == "" {
				return nil
			}
		}
		return onChunk(ctx, dataType, chunk)
	}

	// The usage is required by ChatWithAgent but not needed by the caller.
	var usage schema.TokenUsage
	return u.ChatWithAgent(ctx, chatModel, conversation, &usage, forward)
}
// thinkingStreamFilter removes a leading "<think>…</think>" section from text
// that arrives in chunks. Feed every chunk to Append and forward whatever it
// returns.
type thinkingStreamFilter struct {
	// buffer accumulates input until it is known whether the stream starts
	// with a think block and, if so, until the answer after it begins.
	buffer strings.Builder
	// done is set once the decision has been made; later chunks pass through.
	done bool
}

// newThinkingStreamFilter returns a filter ready to receive the first chunk
// of a stream.
func newThinkingStreamFilter() *thinkingStreamFilter {
	return &thinkingStreamFilter{}
}

// Append consumes the next chunk of the stream and returns the text that
// should be forwarded, which may be empty.
//
//   - If the stream does not start with "<think>", everything is forwarded
//     unchanged.
//   - If it does, nothing is forwarded until "</think>" has been seen; the
//     whitespace directly after the closing tag is dropped and the stream is
//     forwarded unchanged from the first non-blank character on.
//
// The tags may be split across chunks. While the buffered input is still only
// a proper prefix of "<think>" (for example "<th"), the decision is deferred
// and "" is returned; since the filter has no end-of-stream signal, a stream
// that ends in that state yields nothing for that buffered text.
func (f *thinkingStreamFilter) Append(chunk string) string {
	const (
		openTag  = "<think>"
		closeTag = "</think>"
	)

	if f.done {
		return chunk
	}
	f.buffer.WriteString(chunk)
	content := f.buffer.String()

	// Too little input to tell whether this is the opening tag: keep buffering.
	if len(content) < len(openTag) && strings.HasPrefix(openTag, content) {
		return ""
	}
	if !strings.HasPrefix(content, openTag) {
		f.done = true
		f.buffer.Reset()
		return content
	}

	_, answer, closed := strings.Cut(content, closeTag)
	if !closed {
		return ""
	}
	// Only leading whitespace is removed: trailing whitespace may separate
	// this chunk from the next one and must survive.
	answer = strings.TrimLeft(answer, " \t\r\n")
	if answer == "" {
		// The answer has not started yet; stay in buffering mode so that
		// whitespace at the start of the next chunk is skipped as well.
		return ""
	}
	f.done = true
	f.buffer.Reset()
	return answer
}
// SplitByTokenLimit cuts text into consecutive pieces of at most maxTokens
// tokens each, counted with the "cl100k_base" tiktoken encoding. Text that
// already fits is returned unchanged as the only element. Pieces are produced
// by decoding consecutive token ranges.
//
// It returns an error when maxTokens is not positive or when the encoding
// cannot be loaded.
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
	if maxTokens <= 0 {
		return nil, fmt.Errorf("maxTokens must be greater than 0")
	}

	encoding, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return nil, fmt.Errorf("failed to get encoding: %w", err)
	}

	tokens := encoding.Encode(text, nil, nil)
	total := len(tokens)
	if total <= maxTokens {
		return []string{text}, nil
	}

	// Reserve room for every piece up front: ceil(total / maxTokens).
	pieces := make([]string, 0, (total+maxTokens-1)/maxTokens)
	for start := 0; start < total; start += maxTokens {
		stop := min(start+maxTokens, total)
		pieces = append(pieces, encoding.Decode(tokens[start:stop]))
	}
	return pieces, nil
}
// GetRankNodesRequest carries the parameters of GetRankNodes. Every field is
// forwarded to the RAG service as the corresponding field of
// rag.QueryRecordsRequest.
type GetRankNodesRequest struct {
	// DatasetID identifies the RAG dataset to search.
	DatasetID string
	// Question is the query text.
	Question string
	// GroupIDs is passed to the RAG query as-is.
	GroupIDs []int
	// SimilarityThreshold is passed to the RAG query as-is.
	SimilarityThreshold float64
	// HistoryMessages are the earlier conversation messages, sent along with the query.
	HistoryMessages []*schema.Message
	// MaxChunksPerDoc is passed to the RAG query as-is.
	MaxChunksPerDoc int
}
// GetRankNodes queries the RAG service with req and groups the returned
// chunks by the node they belong to.
//
// It returns the query as rewritten by the RAG service together with one
// entry per document, in order of first appearance in the RAG results, each
// holding that document's chunks in result order. Chunks whose document ID
// has no matching node release are left out. When the RAG service returns no
// records the node slice is nil.
func (u *LLMUsecase) GetRankNodes(ctx context.Context, req GetRankNodesRequest) (string, []*domain.RankedNodeChunks, error) {
	var rankedNodes []*domain.RankedNodeChunks

	// Fetch the related chunks from the RAG service.
	query := &rag.QueryRecordsRequest{
		DatasetID:           req.DatasetID,
		Query:               req.Question,
		GroupIDs:            req.GroupIDs,
		SimilarityThreshold: req.SimilarityThreshold,
		HistoryMsgs:         req.HistoryMessages,
		MaxChunksPerDoc:     req.MaxChunksPerDoc,
	}
	rewrittenQuery, records, err := u.rag.QueryRecords(ctx, query)
	if err != nil {
		return "", nil, fmt.Errorf("get records from raglite failed: %w", err)
	}
	u.logger.Info("get related documents from raglite", log.Any("record_count", len(records)))
	if len(records) == 0 {
		return rewrittenQuery, rankedNodes, nil
	}

	// Resolve the distinct document IDs to their node releases.
	docIDs := lo.Uniq(lo.Map(records, func(item *domain.NodeContentChunk, _ int) string {
		return item.DocID
	}))
	u.logger.Info("node chunk doc ids", log.Any("docIDs", docIDs))
	docIDNode, err := u.nodeRepo.GetNodeReleasesWithPathsByDocIDs(ctx, docIDs)
	if err != nil {
		return "", nil, fmt.Errorf("get nodes by ids failed: %w", err)
	}
	u.logger.Info("get node release by doc ids", log.Any("docIDNode", lo.Keys(docIDNode)))

	// Group chunks per document, keeping first-seen order.
	grouped := make(map[string]*domain.RankedNodeChunks)
	for _, record := range records {
		if entry, seen := grouped[record.DocID]; seen {
			entry.Chunks = append(entry.Chunks, record)
			continue
		}
		docNode, known := docIDNode[record.DocID]
		if !known {
			continue
		}
		entry := &domain.RankedNodeChunks{
			NodeID:        docNode.NodeID,
			NodeName:      docNode.Name,
			NodeSummary:   docNode.Meta.Summary,
			NodeEmoji:     docNode.Meta.Emoji,
			NodePathNames: docNode.PathNames,
			Chunks:        []*domain.NodeContentChunk{record},
		}
		grouped[record.DocID] = entry
		rankedNodes = append(rankedNodes, entry)
	}
	return rewrittenQuery, rankedNodes, nil
}
// formatMessageWithImages returns message followed by one markdown image
// reference ("![](path)") per entry of imagePaths, each on its own line.
// Without image paths the message is returned as-is.
func (u *LLMUsecase) formatMessageWithImages(message string, imagePaths []string) string {
	if len(imagePaths) == 0 {
		return message
	}

	var out strings.Builder
	out.WriteString(message)
	for i := range imagePaths {
		fmt.Fprintf(&out, "\n![](%s)", imagePaths[i])
	}
	return out.String()
}