528 lines
15 KiB
Go
528 lines
15 KiB
Go
package usecase
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"slices"
|
||
"strings"
|
||
"time"
|
||
|
||
modelkit "github.com/chaitin/ModelKit/v2/usecase"
|
||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||
"github.com/cloudwego/eino/components/model"
|
||
"github.com/cloudwego/eino/components/prompt"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/pkoukk/tiktoken-go"
|
||
"github.com/samber/lo"
|
||
|
||
"github.com/chaitin/panda-wiki/config"
|
||
"github.com/chaitin/panda-wiki/domain"
|
||
"github.com/chaitin/panda-wiki/log"
|
||
"github.com/chaitin/panda-wiki/repo/pg"
|
||
"github.com/chaitin/panda-wiki/store/rag"
|
||
"github.com/chaitin/panda-wiki/utils"
|
||
)
|
||
|
||
type LLMUsecase struct {
|
||
rag rag.RAGService
|
||
conversationRepo *pg.ConversationRepository
|
||
kbRepo *pg.KnowledgeBaseRepository
|
||
nodeRepo *pg.NodeRepository
|
||
modelRepo *pg.ModelRepository
|
||
promptRepo *pg.PromptRepo
|
||
config *config.Config
|
||
logger *log.Logger
|
||
modelkit *modelkit.ModelKit
|
||
}
|
||
|
||
// Limits applied when summarising a document.
const (
	// summaryChunkTokenLimit is the maximum number of tokens placed in a
	// single summary chunk (30 * 1024 = 30720 tokens).
	summaryChunkTokenLimit = 30 * 1024
	// summaryMaxChunks is the maximum number of chunks summarised per
	// document; chunks beyond this count are dropped.
	summaryMaxChunks = 4
)
|
||
|
||
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
||
tiktoken.SetBpeLoader(&utils.Localloader{})
|
||
modelkit := modelkit.NewModelKit(logger.Logger)
|
||
return &LLMUsecase{
|
||
config: config,
|
||
rag: rag,
|
||
conversationRepo: conversationRepo,
|
||
kbRepo: kbRepo,
|
||
nodeRepo: nodeRepo,
|
||
modelRepo: modelRepo,
|
||
promptRepo: promptRepo,
|
||
logger: logger.WithModule("usecase.llm"),
|
||
modelkit: modelkit,
|
||
}
|
||
}
|
||
|
||
func (u *LLMUsecase) BuildConversationMessageWithRAG(
|
||
ctx context.Context,
|
||
conversationID string,
|
||
kbID string,
|
||
groupIDs []int,
|
||
systemPrompt string,
|
||
) ([]*schema.Message, []*domain.RankedNodeChunks, error) {
|
||
messages := make([]*schema.Message, 0)
|
||
rankedNodes := make([]*domain.RankedNodeChunks, 0)
|
||
|
||
msgs, err := u.conversationRepo.GetConversationMessagesByID(ctx, conversationID)
|
||
if err != nil {
|
||
u.logger.Error("get conversation messages failed", log.Error(err))
|
||
return nil, nil, errors.New("get conversation messages failed")
|
||
}
|
||
if len(msgs) > 0 {
|
||
historyMessages := make([]*schema.Message, 0)
|
||
for _, msg := range msgs {
|
||
switch msg.Role {
|
||
case schema.Assistant:
|
||
historyMessages = append(historyMessages, schema.AssistantMessage(msg.Content, nil))
|
||
case schema.User:
|
||
content := u.formatMessageWithImages(msg.Content, msg.ImagePaths)
|
||
historyMessages = append(historyMessages, schema.UserMessage(content))
|
||
default:
|
||
continue
|
||
}
|
||
}
|
||
if len(historyMessages) > 0 {
|
||
question := historyMessages[len(historyMessages)-1].Content
|
||
var rewrittenQuery string
|
||
if systemPrompt == "" {
|
||
if settingPrompt, err := u.promptRepo.GetPromptContent(ctx, kbID); err != nil {
|
||
u.logger.Error("get prompt from settings failed", log.Error(err))
|
||
} else {
|
||
if settingPrompt != "" {
|
||
systemPrompt = settingPrompt
|
||
} else {
|
||
systemPrompt = domain.SystemDefaultPrompt
|
||
}
|
||
}
|
||
}
|
||
|
||
template := prompt.FromMessages(schema.GoTemplate,
|
||
schema.SystemMessage(systemPrompt),
|
||
schema.UserMessage(domain.UserQuestionFormatter),
|
||
)
|
||
kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, kbID)
|
||
if err != nil {
|
||
u.logger.Error("get kb failed", log.Error(err))
|
||
return nil, nil, errors.New("get kb failed")
|
||
}
|
||
rewrittenQuery, rankedNodes, err = u.GetRankNodes(ctx, GetRankNodesRequest{
|
||
DatasetID: kb.DatasetID,
|
||
Question: question,
|
||
GroupIDs: groupIDs,
|
||
SimilarityThreshold: 0.2,
|
||
HistoryMessages: historyMessages[:len(historyMessages)-1],
|
||
})
|
||
if err != nil {
|
||
u.logger.Error("get rank nodes failed", log.Error(err))
|
||
return nil, nil, errors.New("get rank nodes failed")
|
||
}
|
||
documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL)
|
||
u.logger.Debug("documents", log.String("documents", documents))
|
||
|
||
formattedMessages, err := template.Format(ctx, map[string]any{
|
||
"CurrentDate": time.Now().Format("2006-01-02"),
|
||
"Question": rewrittenQuery,
|
||
"Documents": documents,
|
||
})
|
||
if err != nil {
|
||
u.logger.Error("format messages failed", log.Error(err))
|
||
return nil, nil, errors.New("format messages failed")
|
||
}
|
||
messages = slices.Insert(formattedMessages, 1, historyMessages[:len(historyMessages)-1]...)
|
||
}
|
||
}
|
||
return messages, rankedNodes, nil
|
||
}
|
||
|
||
func (u *LLMUsecase) ChatWithAgent(
|
||
ctx context.Context,
|
||
chatModel model.BaseChatModel,
|
||
messages []*schema.Message,
|
||
usage *schema.TokenUsage,
|
||
onChunk func(ctx context.Context, dataType, chunk string) error,
|
||
) error {
|
||
resp, err := chatModel.Stream(ctx, messages)
|
||
if err != nil {
|
||
return fmt.Errorf("stream failed: %w", err)
|
||
}
|
||
firstReasoning := false
|
||
firstData := false
|
||
|
||
for {
|
||
msg, err := resp.Recv()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return fmt.Errorf("recv failed: %w", err)
|
||
}
|
||
reasoning, ok := deepseek.GetReasoningContent(msg)
|
||
if ok {
|
||
if !firstReasoning {
|
||
firstReasoning = true
|
||
reasoning = "<think>" + reasoning
|
||
}
|
||
if err := onChunk(ctx, "data", reasoning); err != nil {
|
||
return fmt.Errorf("on chunk reasoning: %w", err)
|
||
}
|
||
continue
|
||
}
|
||
if firstReasoning && !firstData {
|
||
firstData = true
|
||
msg.Content = "</think>\n" + msg.Content
|
||
if err := onChunk(ctx, "data", msg.Content); err != nil {
|
||
return fmt.Errorf("on chunk data: %w", err)
|
||
}
|
||
continue
|
||
}
|
||
if err := onChunk(ctx, "data", msg.Content); err != nil {
|
||
return fmt.Errorf("on chunk data: %w", err)
|
||
}
|
||
|
||
// set to usage
|
||
if msg.ResponseMeta.Usage != nil {
|
||
*usage = *msg.ResponseMeta.Usage
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (u *LLMUsecase) Generate(
|
||
ctx context.Context,
|
||
chatModel model.BaseChatModel,
|
||
messages []*schema.Message,
|
||
) (string, error) {
|
||
resp, err := chatModel.Generate(ctx, messages)
|
||
if err != nil {
|
||
return "", fmt.Errorf("generate failed: %w", err)
|
||
}
|
||
return resp.Content, nil
|
||
}
|
||
|
||
func (u *LLMUsecase) SummaryNode(ctx context.Context, kbID string, model *domain.Model, name, content string) (string, error) {
|
||
modelkitModel, err := model.ToModelkitModel()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if len(chunks) > summaryMaxChunks {
|
||
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
|
||
chunks = chunks[:summaryMaxChunks]
|
||
}
|
||
|
||
summaries := make([]string, 0, len(chunks))
|
||
for idx, chunk := range chunks {
|
||
summary, err := u.requestSummary(ctx, kbID, chatModel, name, chunk)
|
||
if err != nil {
|
||
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
|
||
continue
|
||
}
|
||
if summary == "" {
|
||
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
|
||
continue
|
||
}
|
||
summaries = append(summaries, summary)
|
||
}
|
||
|
||
if len(summaries) == 0 {
|
||
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
||
}
|
||
if len(summaries) == 1 {
|
||
return summaries[0], nil
|
||
}
|
||
|
||
// Join all summaries and generate final summary
|
||
joined := strings.Join(summaries, "\n\n")
|
||
finalSummary, err := u.requestSummary(ctx, kbID, chatModel, name, joined)
|
||
if err != nil {
|
||
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||
// Fallback: return the joined summaries directly
|
||
if len(joined) > 500 {
|
||
return joined[:500] + "...", nil
|
||
}
|
||
return joined, nil
|
||
}
|
||
return finalSummary, nil
|
||
}
|
||
|
||
func (u *LLMUsecase) StreamSummaryNode(
|
||
ctx context.Context,
|
||
kbID string,
|
||
model *domain.Model,
|
||
name, content string,
|
||
onChunk func(ctx context.Context, dataType, chunk string) error,
|
||
) error {
|
||
modelkitModel, err := model.ToModelkitModel()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(chunks) > summaryMaxChunks {
|
||
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
|
||
chunks = chunks[:summaryMaxChunks]
|
||
}
|
||
|
||
if len(chunks) == 1 {
|
||
return u.streamSummary(ctx, kbID, chatModel, name, chunks[0], onChunk)
|
||
}
|
||
|
||
summaries := make([]string, 0, len(chunks))
|
||
for idx, chunk := range chunks {
|
||
summary, summaryErr := u.requestSummary(ctx, kbID, chatModel, name, chunk)
|
||
if summaryErr != nil {
|
||
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(summaryErr))
|
||
continue
|
||
}
|
||
if summary == "" {
|
||
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
|
||
continue
|
||
}
|
||
summaries = append(summaries, summary)
|
||
}
|
||
|
||
if len(summaries) == 0 {
|
||
return fmt.Errorf("failed to generate summary for document %s", name)
|
||
}
|
||
if len(summaries) == 1 {
|
||
if err := onChunk(ctx, "data", summaries[0]); err != nil {
|
||
return fmt.Errorf("on chunk data: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
joined := strings.Join(summaries, "\n\n")
|
||
if err := u.streamSummary(ctx, kbID, chatModel, name, joined, onChunk); err != nil {
|
||
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||
if len(joined) > 500 {
|
||
joined = joined[:500] + "..."
|
||
}
|
||
if chunkErr := onChunk(ctx, "data", joined); chunkErr != nil {
|
||
return fmt.Errorf("on chunk data: %w", chunkErr)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (u *LLMUsecase) trimThinking(summary string) string {
|
||
if !strings.HasPrefix(summary, "<think>") {
|
||
return summary
|
||
}
|
||
endIndex := strings.Index(summary, "</think>")
|
||
if endIndex == -1 {
|
||
return summary
|
||
}
|
||
return strings.TrimSpace(summary[endIndex+len("</think>"):])
|
||
}
|
||
|
||
func (u *LLMUsecase) requestSummary(ctx context.Context, kbID string, chatModel model.BaseChatModel, name, content string) (string, error) {
|
||
summaryPrompt, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
||
{
|
||
Role: "system",
|
||
Content: summaryPrompt,
|
||
},
|
||
{
|
||
Role: "user",
|
||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content),
|
||
},
|
||
})
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return strings.TrimSpace(u.trimThinking(summary)), nil
|
||
}
|
||
|
||
func (u *LLMUsecase) streamSummary(
|
||
ctx context.Context,
|
||
kbID string,
|
||
chatModel model.BaseChatModel,
|
||
name, content string,
|
||
onChunk func(ctx context.Context, dataType, chunk string) error,
|
||
) error {
|
||
summaryPrompt, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
usage := schema.TokenUsage{}
|
||
filter := newThinkingStreamFilter()
|
||
return u.ChatWithAgent(ctx, chatModel, []*schema.Message{
|
||
{
|
||
Role: "system",
|
||
Content: summaryPrompt,
|
||
},
|
||
{
|
||
Role: "user",
|
||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content),
|
||
},
|
||
}, &usage, func(ctx context.Context, dataType, chunk string) error {
|
||
if dataType != "data" {
|
||
return onChunk(ctx, dataType, chunk)
|
||
}
|
||
cleaned := filter.Append(chunk)
|
||
if cleaned == "" {
|
||
return nil
|
||
}
|
||
return onChunk(ctx, dataType, cleaned)
|
||
})
|
||
}
|
||
|
||
// thinkingStreamFilter removes a leading "<think>...</think>" block from a
// text stream that arrives in arbitrary chunks.
//
// Until the filter has decided whether the stream starts with a think block,
// incoming text is held in buffer. Once decided, done is set and every later
// chunk passes through unchanged.
type thinkingStreamFilter struct {
	buffer strings.Builder
	done   bool
}

// newThinkingStreamFilter returns a filter that has not seen any input yet.
func newThinkingStreamFilter() *thinkingStreamFilter {
	return &thinkingStreamFilter{}
}

// Append feeds the next chunk into the filter and returns the text that may be
// forwarded downstream, which can be empty.
//
//   - While the buffered text is still a proper prefix of "<think>" (this
//     includes an empty first chunk and an opening tag split across chunks),
//     nothing is emitted and the text stays buffered.
//   - If the stream turns out not to start with "<think>", everything buffered
//     so far is emitted unchanged.
//   - If it does start with "<think>", nothing is emitted until "</think>" has
//     arrived; the text after the closing tag is then emitted with surrounding
//     whitespace trimmed.
//
// Text still buffered when the stream ends is never emitted, because the
// filter has no flush step.
func (f *thinkingStreamFilter) Append(chunk string) string {
	const (
		openTag  = "<think>"
		closeTag = "</think>"
	)

	if f.done {
		return chunk
	}

	f.buffer.WriteString(chunk)
	content := f.buffer.String()

	if !strings.HasPrefix(content, openTag) {
		// Too little text to decide yet: it could still grow into "<think>".
		if strings.HasPrefix(openTag, content) {
			return ""
		}
		f.done = true
		f.buffer.Reset()
		return content
	}

	_, answer, closed := strings.Cut(content, closeTag)
	if !closed {
		return ""
	}

	f.done = true
	f.buffer.Reset()
	return strings.TrimSpace(answer)
}
|
||
|
||
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
||
if maxTokens <= 0 {
|
||
return nil, fmt.Errorf("maxTokens must be greater than 0")
|
||
}
|
||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get encoding: %w", err)
|
||
}
|
||
tokens := encoding.Encode(text, nil, nil)
|
||
if len(tokens) <= maxTokens {
|
||
return []string{text}, nil
|
||
}
|
||
|
||
// 预先计算需要的片段数量并分配空间
|
||
numChunks := (len(tokens) + maxTokens - 1) / maxTokens // 向上取整
|
||
result := make([]string, 0, numChunks)
|
||
|
||
for i := 0; i < len(tokens); i += maxTokens {
|
||
end := i + maxTokens
|
||
if end > len(tokens) {
|
||
end = len(tokens)
|
||
}
|
||
|
||
chunk := tokens[i:end]
|
||
decodedChunk := encoding.Decode(chunk)
|
||
result = append(result, decodedChunk)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
type GetRankNodesRequest struct {
|
||
DatasetID string
|
||
Question string
|
||
GroupIDs []int
|
||
SimilarityThreshold float64
|
||
HistoryMessages []*schema.Message
|
||
MaxChunksPerDoc int
|
||
}
|
||
|
||
func (u *LLMUsecase) GetRankNodes(ctx context.Context, req GetRankNodesRequest) (string, []*domain.RankedNodeChunks, error) {
|
||
var rankedNodes []*domain.RankedNodeChunks
|
||
// get related documents from raglite
|
||
rewrittenQuery, records, err := u.rag.QueryRecords(ctx, &rag.QueryRecordsRequest{
|
||
DatasetID: req.DatasetID,
|
||
Query: req.Question,
|
||
GroupIDs: req.GroupIDs,
|
||
SimilarityThreshold: req.SimilarityThreshold,
|
||
HistoryMsgs: req.HistoryMessages,
|
||
MaxChunksPerDoc: req.MaxChunksPerDoc,
|
||
})
|
||
if err != nil {
|
||
return "", nil, fmt.Errorf("get records from raglite failed: %w", err)
|
||
}
|
||
u.logger.Info("get related documents from raglite", log.Any("record_count", len(records)))
|
||
rankedNodesMap := make(map[string]*domain.RankedNodeChunks)
|
||
// get raw node by doc_id
|
||
if len(records) > 0 {
|
||
docIDs := lo.Uniq(lo.Map(records, func(item *domain.NodeContentChunk, _ int) string {
|
||
return item.DocID
|
||
}))
|
||
u.logger.Info("node chunk doc ids", log.Any("docIDs", docIDs))
|
||
docIDNode, err := u.nodeRepo.GetNodeReleasesWithPathsByDocIDs(ctx, docIDs)
|
||
if err != nil {
|
||
return "", nil, fmt.Errorf("get nodes by ids failed: %w", err)
|
||
}
|
||
u.logger.Info("get node release by doc ids", log.Any("docIDNode", lo.Keys(docIDNode)))
|
||
for _, record := range records {
|
||
if nodeChunk, ok := rankedNodesMap[record.DocID]; !ok {
|
||
if docNode, ok := docIDNode[record.DocID]; ok {
|
||
rankNodeChunk := &domain.RankedNodeChunks{
|
||
NodeID: docNode.NodeID,
|
||
NodeName: docNode.Name,
|
||
NodeSummary: docNode.Meta.Summary,
|
||
NodeEmoji: docNode.Meta.Emoji,
|
||
NodePathNames: docNode.PathNames,
|
||
Chunks: []*domain.NodeContentChunk{record},
|
||
}
|
||
rankedNodes = append(rankedNodes, rankNodeChunk)
|
||
rankedNodesMap[record.DocID] = rankNodeChunk
|
||
}
|
||
} else {
|
||
nodeChunk.Chunks = append(nodeChunk.Chunks, record)
|
||
}
|
||
}
|
||
}
|
||
return rewrittenQuery, rankedNodes, nil
|
||
}
|
||
|
||
// formatMessageWithImages converts image paths to markdown format and appends to message
|
||
func (u *LLMUsecase) formatMessageWithImages(message string, imagePaths []string) string {
|
||
if len(imagePaths) == 0 {
|
||
return message
|
||
}
|
||
var builder strings.Builder
|
||
builder.WriteString(message)
|
||
for _, path := range imagePaths {
|
||
builder.WriteString("\n")
|
||
builder.WriteString(fmt.Sprintf("", path))
|
||
}
|
||
return builder.String()
|
||
}
|