package usecase
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
modelkit "github.com/chaitin/ModelKit/v2/usecase"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/pg"
"github.com/chaitin/panda-wiki/store/rag"
"github.com/chaitin/panda-wiki/utils"
)
// LLMUsecase groups the LLM-facing operations of this package: building
// retrieval-augmented chat prompts, streaming chat completions and
// summarising documents.
type LLMUsecase struct {
// rag executes retrieval queries (QueryRecords) for a question.
rag rag.RAGService
// conversationRepo loads the stored messages of a conversation.
conversationRepo *pg.ConversationRepository
// kbRepo looks up knowledge bases by ID.
kbRepo *pg.KnowledgeBaseRepository
// nodeRepo resolves retrieved doc IDs to their node releases.
nodeRepo *pg.NodeRepository
// modelRepo is injected by the constructor; it is not referenced by the
// methods visible in this part of the file.
modelRepo *pg.ModelRepository
// promptRepo supplies the per-knowledge-base chat and summary prompts.
promptRepo *pg.PromptRepo
// config is injected by the constructor; it is not referenced by the
// methods visible in this part of the file.
config *config.Config
// logger is scoped to the "usecase.llm" module by the constructor.
logger *log.Logger
// modelkit builds chat model clients from model definitions.
modelkit *modelkit.ModelKit
}
// Limits applied when a document is summarised chunk by chunk.
const (
summaryChunkTokenLimit = 30720 // max tokens per summary chunk (30 * 1024 tokens, not bytes)
summaryMaxChunks = 4 // max chunks to process for summary; later chunks are dropped
)
// NewLLMUsecase assembles an LLMUsecase from its repositories and services.
//
// Besides storing the dependencies it installs utils.Localloader through the
// package-level tiktoken.SetBpeLoader, creates a ModelKit from the supplied
// logger's underlying Logger, and scopes the stored logger to the
// "usecase.llm" module.
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
	tiktoken.SetBpeLoader(&utils.Localloader{})

	// The ModelKit is created before the logger is re-scoped, matching the
	// order in which the two calls have always been made.
	uc := &LLMUsecase{
		modelkit:         modelkit.NewModelKit(logger.Logger),
		logger:           logger.WithModule("usecase.llm"),
		config:           config,
		rag:              rag,
		conversationRepo: conversationRepo,
		kbRepo:           kbRepo,
		nodeRepo:         nodeRepo,
		modelRepo:        modelRepo,
		promptRepo:       promptRepo,
	}
	return uc
}
// BuildConversationMessageWithRAG assembles the message list to send to the
// chat model for a conversation, enriching the latest message with documents
// retrieved from the knowledge base.
//
// The stored conversation is reduced to its user and assistant messages
// (user messages get their image paths appended). The last of those is taken
// as the question; the ones before it are passed to retrieval as history and
// are inserted between the system message and the formatted question.
//
// If systemPrompt is empty, the knowledge base's configured prompt is used;
// when none is configured, or when it cannot be loaded, the call falls back
// to domain.SystemDefaultPrompt so the model never receives an empty system
// message.
//
// It returns the prompt messages and the ranked nodes used to build them.
// Both are empty (non-nil) slices when the conversation has no usable
// messages. Failures are logged with their cause and reported to the caller
// as a short, generic error.
func (u *LLMUsecase) BuildConversationMessageWithRAG(
	ctx context.Context,
	conversationID string,
	kbID string,
	groupIDs []int,
	systemPrompt string,
) ([]*schema.Message, []*domain.RankedNodeChunks, error) {
	messages := make([]*schema.Message, 0)
	rankedNodes := make([]*domain.RankedNodeChunks, 0)

	msgs, err := u.conversationRepo.GetConversationMessagesByID(ctx, conversationID)
	if err != nil {
		u.logger.Error("get conversation messages failed", log.Error(err))
		return nil, nil, errors.New("get conversation messages failed")
	}

	// Keep only user/assistant turns; any other role is ignored.
	historyMessages := make([]*schema.Message, 0, len(msgs))
	for _, msg := range msgs {
		switch msg.Role {
		case schema.Assistant:
			historyMessages = append(historyMessages, schema.AssistantMessage(msg.Content, nil))
		case schema.User:
			content := u.formatMessageWithImages(msg.Content, msg.ImagePaths)
			historyMessages = append(historyMessages, schema.UserMessage(content))
		}
	}
	if len(historyMessages) == 0 {
		return messages, rankedNodes, nil
	}

	lastIdx := len(historyMessages) - 1
	question := historyMessages[lastIdx].Content
	previous := historyMessages[:lastIdx]

	if systemPrompt == "" {
		// Start from the default so a lookup failure cannot leave the system
		// prompt empty.
		systemPrompt = domain.SystemDefaultPrompt
		settingPrompt, promptErr := u.promptRepo.GetPromptContent(ctx, kbID)
		if promptErr != nil {
			u.logger.Error("get prompt from settings failed", log.Error(promptErr))
		} else if settingPrompt != "" {
			systemPrompt = settingPrompt
		}
	}

	template := prompt.FromMessages(schema.GoTemplate,
		schema.SystemMessage(systemPrompt),
		schema.UserMessage(domain.UserQuestionFormatter),
	)

	kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, kbID)
	if err != nil {
		u.logger.Error("get kb failed", log.Error(err))
		return nil, nil, errors.New("get kb failed")
	}

	var rewrittenQuery string
	rewrittenQuery, rankedNodes, err = u.GetRankNodes(ctx, GetRankNodesRequest{
		DatasetID:           kb.DatasetID,
		Question:            question,
		GroupIDs:            groupIDs,
		SimilarityThreshold: 0.2,
		HistoryMessages:     previous,
	})
	if err != nil {
		u.logger.Error("get rank nodes failed", log.Error(err))
		return nil, nil, errors.New("get rank nodes failed")
	}

	documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL)
	u.logger.Debug("documents", log.String("documents", documents))

	formattedMessages, err := template.Format(ctx, map[string]any{
		"CurrentDate": time.Now().Format("2006-01-02"),
		"Question":    rewrittenQuery,
		"Documents":   documents,
	})
	if err != nil {
		u.logger.Error("format messages failed", log.Error(err))
		return nil, nil, errors.New("format messages failed")
	}

	// Result layout: system message, earlier turns, formatted question.
	messages = slices.Insert(formattedMessages, 1, previous...)
	return messages, rankedNodes, nil
}
// ChatWithAgent streams a chat completion for messages and forwards every
// piece of output to onChunk with dataType "data".
//
// Reasoning content (as reported by deepseek.GetReasoningContent) is wrapped
// in a <think> ... </think> block: the opening tag is prepended to the first
// reasoning chunk and the closing tag, followed by a newline, is prepended to
// the first answer chunk that comes after reasoning. Models that emit no
// reasoning are forwarded unchanged.
//
// Whenever a streamed message carries token usage it is copied into *usage;
// usage may be nil if the caller does not need it.
//
// The call returns nil once the stream is exhausted, or the first error from
// opening the stream, receiving from it, or onChunk. The stream is always
// closed before returning.
func (u *LLMUsecase) ChatWithAgent(
	ctx context.Context,
	chatModel model.BaseChatModel,
	messages []*schema.Message,
	usage *schema.TokenUsage,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	stream, err := chatModel.Stream(ctx, messages)
	if err != nil {
		return fmt.Errorf("stream failed: %w", err)
	}
	// Release the stream on every exit path, including early error returns.
	defer stream.Close()

	reasoningStarted := false
	answerStarted := false
	for {
		msg, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("recv failed: %w", err)
		}

		// Usage is recorded for every message kind; ResponseMeta is optional
		// on streamed messages, so guard each pointer.
		if usage != nil && msg.ResponseMeta != nil && msg.ResponseMeta.Usage != nil {
			*usage = *msg.ResponseMeta.Usage
		}

		if reasoning, ok := deepseek.GetReasoningContent(msg); ok {
			if !reasoningStarted {
				reasoningStarted = true
				reasoning = "<think>" + reasoning
			}
			if err := onChunk(ctx, "data", reasoning); err != nil {
				return fmt.Errorf("on chunk reasoning: %w", err)
			}
			continue
		}

		content := msg.Content
		if reasoningStarted && !answerStarted {
			// First answer chunk after reasoning: close the thinking block.
			answerStarted = true
			content = "</think>\n" + content
		}
		if err := onChunk(ctx, "data", content); err != nil {
			return fmt.Errorf("on chunk data: %w", err)
		}
	}
}
// Generate runs a single, non-streaming completion of messages on chatModel
// and returns the content of the reply. A model failure is returned wrapped
// as "generate failed".
func (u *LLMUsecase) Generate(
	ctx context.Context,
	chatModel model.BaseChatModel,
	messages []*schema.Message,
) (string, error) {
	reply, genErr := chatModel.Generate(ctx, messages)
	if genErr == nil {
		return reply.Content, nil
	}
	return "", fmt.Errorf("generate failed: %w", genErr)
}
// SummaryNode returns a summary of the document called name with the given
// content, generated by model.
//
// The content is split into chunks of at most summaryChunkTokenLimit tokens
// and only the first summaryMaxChunks chunks are used. Each chunk is
// summarised on its own; chunks that fail or yield an empty summary are
// logged and skipped. A single surviving summary is returned as is. Several
// summaries are joined with blank lines and summarised once more.
//
// If that final pass fails, the joined chunk summaries are returned instead,
// capped at 500 bytes plus "...". The cut never splits a multi-byte UTF-8
// character.
//
// An error is returned when the model cannot be built, the content cannot be
// split, or no chunk produced a summary.
func (u *LLMUsecase) SummaryNode(ctx context.Context, kbID string, model *domain.Model, name, content string) (string, error) {
	// Byte budget of the fallback summary, excluding the trailing "...".
	const fallbackMaxBytes = 500

	modelkitModel, err := model.ToModelkitModel()
	if err != nil {
		return "", err
	}
	chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
	if err != nil {
		return "", err
	}

	chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
	if err != nil {
		return "", err
	}
	if len(chunks) > summaryMaxChunks {
		u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
		chunks = chunks[:summaryMaxChunks]
	}

	summaries := make([]string, 0, len(chunks))
	for idx, chunk := range chunks {
		summary, err := u.requestSummary(ctx, kbID, chatModel, name, chunk)
		if err != nil {
			u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
			continue
		}
		if summary == "" {
			u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
			continue
		}
		summaries = append(summaries, summary)
	}

	switch len(summaries) {
	case 0:
		return "", fmt.Errorf("failed to generate summary for document %s", name)
	case 1:
		return summaries[0], nil
	}

	// Condense the per-chunk summaries into one final summary.
	joined := strings.Join(summaries, "\n\n")
	finalSummary, err := u.requestSummary(ctx, kbID, chatModel, name, joined)
	if err == nil {
		return finalSummary, nil
	}

	u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
	if len(joined) > fallbackMaxBytes {
		// A byte slice may end inside a multi-byte character; ToValidUTF8
		// drops the dangling partial rune so the result stays valid UTF-8.
		return strings.ToValidUTF8(joined[:fallbackMaxBytes], "") + "...", nil
	}
	return joined, nil
}
// StreamSummaryNode generates a summary of the document called name and
// delivers it through onChunk with dataType "data".
//
// The content is split into chunks of at most summaryChunkTokenLimit tokens
// and only the first summaryMaxChunks chunks are used. A document that fits
// in one chunk is summarised with a single streamed request. Otherwise each
// chunk is summarised with a blocking request (failed or empty results are
// logged and skipped) and then:
//   - one surviving summary is sent as a single chunk;
//   - several summaries are joined with blank lines and a final summary of
//     them is streamed.
//
// If the final streamed pass fails, the joined chunk summaries are sent as a
// fallback, capped at 500 bytes plus "...". The cut never splits a multi-byte
// UTF-8 character. Output streamed before the failure has already been
// delivered to onChunk and is not retracted.
//
// An error is returned when the model cannot be built, the content cannot be
// split, no chunk produced a summary, or onChunk fails while delivering a
// non-streamed result.
func (u *LLMUsecase) StreamSummaryNode(
	ctx context.Context,
	kbID string,
	model *domain.Model,
	name, content string,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	// Byte budget of the fallback summary, excluding the trailing "...".
	const fallbackMaxBytes = 500

	modelkitModel, err := model.ToModelkitModel()
	if err != nil {
		return err
	}
	chatModel, err := u.modelkit.GetChatModel(ctx, modelkitModel)
	if err != nil {
		return err
	}

	chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
	if err != nil {
		return err
	}
	if len(chunks) > summaryMaxChunks {
		u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
		chunks = chunks[:summaryMaxChunks]
	}
	if len(chunks) == 1 {
		return u.streamSummary(ctx, kbID, chatModel, name, chunks[0], onChunk)
	}

	summaries := make([]string, 0, len(chunks))
	for idx, chunk := range chunks {
		summary, summaryErr := u.requestSummary(ctx, kbID, chatModel, name, chunk)
		if summaryErr != nil {
			u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(summaryErr))
			continue
		}
		if summary == "" {
			u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
			continue
		}
		summaries = append(summaries, summary)
	}

	switch len(summaries) {
	case 0:
		return fmt.Errorf("failed to generate summary for document %s", name)
	case 1:
		if err := onChunk(ctx, "data", summaries[0]); err != nil {
			return fmt.Errorf("on chunk data: %w", err)
		}
		return nil
	}

	joined := strings.Join(summaries, "\n\n")
	streamErr := u.streamSummary(ctx, kbID, chatModel, name, joined, onChunk)
	if streamErr == nil {
		return nil
	}

	u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(streamErr))
	if len(joined) > fallbackMaxBytes {
		// A byte slice may end inside a multi-byte character; ToValidUTF8
		// drops the dangling partial rune so the result stays valid UTF-8.
		joined = strings.ToValidUTF8(joined[:fallbackMaxBytes], "") + "..."
	}
	if chunkErr := onChunk(ctx, "data", joined); chunkErr != nil {
		return fmt.Errorf("on chunk data: %w", chunkErr)
	}
	return nil
}
// trimThinking removes a leading <think> ... </think> block from summary and
// returns what follows it with surrounding whitespace trimmed.
//
// The input is returned unchanged when it does not start with "<think>" or
// when the block is never closed.
func (u *LLMUsecase) trimThinking(summary string) string {
	const (
		thinkOpen  = "<think>"
		thinkClose = "</think>"
	)
	if !strings.HasPrefix(summary, thinkOpen) {
		return summary
	}
	_, answer, closed := strings.Cut(summary, thinkClose)
	if !closed {
		return summary
	}
	return strings.TrimSpace(answer)
}
// requestSummary asks chatModel, in one blocking call, for a summary of
// content belonging to the document called name.
//
// The knowledge base's summary prompt is sent as the system message and the
// document name and content as the user message. The reply is passed through
// trimThinking and stripped of surrounding whitespace before it is returned.
func (u *LLMUsecase) requestSummary(ctx context.Context, kbID string, chatModel model.BaseChatModel, name, content string) (string, error) {
	summaryPrompt, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
	if err != nil {
		return "", err
	}

	request := []*schema.Message{
		schema.SystemMessage(summaryPrompt),
		schema.UserMessage(fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content)),
	}
	raw, err := u.Generate(ctx, chatModel, request)
	if err != nil {
		return "", err
	}

	withoutThinking := u.trimThinking(raw)
	return strings.TrimSpace(withoutThinking), nil
}
// streamSummary streams a summary of content (belonging to the document
// called name) from chatModel to onChunk.
//
// The knowledge base's summary prompt is sent as the system message and the
// document name and content as the user message. Chunks of type "data" pass
// through a thinkingStreamFilter first and are forwarded only when the filter
// yields non-empty text; chunks of any other type are forwarded untouched.
// Token usage reported by the model is collected locally and discarded.
func (u *LLMUsecase) streamSummary(
	ctx context.Context,
	kbID string,
	chatModel model.BaseChatModel,
	name, content string,
	onChunk func(ctx context.Context, dataType, chunk string) error,
) error {
	summaryPrompt, err := u.promptRepo.GetSummaryPrompt(ctx, kbID)
	if err != nil {
		return err
	}

	var usage schema.TokenUsage
	filter := newThinkingStreamFilter()

	request := []*schema.Message{
		schema.SystemMessage(summaryPrompt),
		schema.UserMessage(fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content)),
	}
	relay := func(ctx context.Context, dataType, chunk string) error {
		if dataType == "data" {
			chunk = filter.Append(chunk)
			if chunk == "" {
				// Nothing to show yet (or ever) for this chunk.
				return nil
			}
		}
		return onChunk(ctx, dataType, chunk)
	}
	return u.ChatWithAgent(ctx, chatModel, request, &usage, relay)
}
// thinkingStreamFilter strips a leading <think> ... </think> block from text
// that arrives in chunks.
//
// Chunks are buffered until it is known whether the stream starts with a
// thinking block. Once that is decided the filter becomes a pass-through.
// The zero value is ready to use; a filter must not be shared between
// streams.
type thinkingStreamFilter struct {
	buffer strings.Builder // text received while the decision is pending
	done   bool            // true once the filter only passes chunks through
}

// newThinkingStreamFilter returns a filter for a new stream.
func newThinkingStreamFilter() *thinkingStreamFilter {
	return &thinkingStreamFilter{}
}

// Append feeds the next chunk to the filter and returns the text that may be
// shown to the user, which can be empty.
//
//   - While the buffered text is empty or still only a partial "<think>" tag,
//     or a thinking block is open, nothing is returned.
//   - If the stream turns out not to start with "<think>", everything
//     buffered so far is returned as is.
//   - When "</think>" arrives, the text after it is returned with surrounding
//     whitespace trimmed.
//
// After either of the last two cases every later chunk is returned unchanged.
// Text that is still buffered when the stream ends (an unfinished tag or an
// unclosed thinking block) is never returned.
func (f *thinkingStreamFilter) Append(chunk string) string {
	const (
		thinkOpen  = "<think>"
		thinkClose = "</think>"
	)
	if f.done {
		return chunk
	}

	f.buffer.WriteString(chunk)
	content := f.buffer.String()

	if !strings.HasPrefix(content, thinkOpen) {
		if strings.HasPrefix(thinkOpen, content) {
			// Empty so far, or the opening tag is split across chunks:
			// too early to decide, keep buffering.
			return ""
		}
		f.done = true
		f.buffer.Reset()
		return content
	}

	_, answer, closed := strings.Cut(content, thinkClose)
	if !closed {
		return ""
	}
	f.done = true
	f.buffer.Reset()
	return strings.TrimSpace(answer)
}
// SplitByTokenLimit cuts text into consecutive pieces of at most maxTokens
// tokens each, measured with the "cl100k_base" encoding.
//
// Text that already fits is returned as a single-element slice holding the
// original string. Otherwise the token sequence is sliced every maxTokens
// tokens and each slice is decoded back to a string.
//
// It returns an error when maxTokens is not positive or the encoding cannot
// be loaded.
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
	if maxTokens < 1 {
		return nil, fmt.Errorf("maxTokens must be greater than 0")
	}

	enc, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return nil, fmt.Errorf("failed to get encoding: %w", err)
	}

	ids := enc.Encode(text, nil, nil)
	total := len(ids)
	if total <= maxTokens {
		return []string{text}, nil
	}

	// Allocate room for every piece up front: ceil(total / maxTokens).
	pieceCount := (total + maxTokens - 1) / maxTokens
	pieces := make([]string, 0, pieceCount)
	for start := 0; start < total; start += maxTokens {
		stop := min(start+maxTokens, total)
		pieces = append(pieces, enc.Decode(ids[start:stop]))
	}
	return pieces, nil
}
// GetRankNodesRequest carries the inputs of LLMUsecase.GetRankNodes. Every
// field is forwarded to the RAG service's QueryRecords request.
type GetRankNodesRequest struct {
// DatasetID identifies the dataset to search.
DatasetID string
// Question is the query text.
Question string
// GroupIDs is forwarded as the query's group IDs.
GroupIDs []int
// SimilarityThreshold is forwarded as the query's similarity threshold.
SimilarityThreshold float64
// HistoryMessages are the earlier conversation turns sent along with the
// question.
HistoryMessages []*schema.Message
// MaxChunksPerDoc is forwarded as the query's per-document chunk limit.
MaxChunksPerDoc int
}
// GetRankNodes queries the RAG service for chunks related to req.Question and
// groups them by the node they belong to.
//
// It returns the rewritten query reported by the RAG service and one
// RankedNodeChunks per document, ordered by the first appearance of each
// document in the returned records. Records whose document has no matching
// node release are dropped. The node slice is nil when nothing was retrieved.
func (u *LLMUsecase) GetRankNodes(ctx context.Context, req GetRankNodesRequest) (string, []*domain.RankedNodeChunks, error) {
	// Fetch related chunks from raglite.
	query := &rag.QueryRecordsRequest{
		DatasetID:           req.DatasetID,
		Query:               req.Question,
		GroupIDs:            req.GroupIDs,
		SimilarityThreshold: req.SimilarityThreshold,
		HistoryMsgs:         req.HistoryMessages,
		MaxChunksPerDoc:     req.MaxChunksPerDoc,
	}
	rewrittenQuery, records, err := u.rag.QueryRecords(ctx, query)
	if err != nil {
		return "", nil, fmt.Errorf("get records from raglite failed: %w", err)
	}
	u.logger.Info("get related documents from raglite", log.Any("record_count", len(records)))
	if len(records) == 0 {
		return rewrittenQuery, nil, nil
	}

	// Resolve the distinct doc IDs to their node releases.
	docIDs := lo.Uniq(lo.Map(records, func(item *domain.NodeContentChunk, _ int) string {
		return item.DocID
	}))
	u.logger.Info("node chunk doc ids", log.Any("docIDs", docIDs))
	docIDNode, err := u.nodeRepo.GetNodeReleasesWithPathsByDocIDs(ctx, docIDs)
	if err != nil {
		return "", nil, fmt.Errorf("get nodes by ids failed: %w", err)
	}
	u.logger.Info("get node release by doc ids", log.Any("docIDNode", lo.Keys(docIDNode)))

	var ranked []*domain.RankedNodeChunks
	byDocID := make(map[string]*domain.RankedNodeChunks)
	for _, record := range records {
		if group, seen := byDocID[record.DocID]; seen {
			group.Chunks = append(group.Chunks, record)
			continue
		}
		docNode, known := docIDNode[record.DocID]
		if !known {
			// No node release for this document: skip the chunk.
			continue
		}
		group := &domain.RankedNodeChunks{
			NodeID:        docNode.NodeID,
			NodeName:      docNode.Name,
			NodeSummary:   docNode.Meta.Summary,
			NodeEmoji:     docNode.Meta.Emoji,
			NodePathNames: docNode.PathNames,
			Chunks:        []*domain.NodeContentChunk{record},
		}
		byDocID[record.DocID] = group
		ranked = append(ranked, group)
	}
	return rewrittenQuery, ranked, nil
}
// formatMessageWithImages appends each image path to message as a markdown
// image reference, one per line, and returns the combined text.
//
// Each path is written on its own line as "![image](path)". The message is
// returned unchanged when imagePaths is empty.
func (u *LLMUsecase) formatMessageWithImages(message string, imagePaths []string) string {
	if len(imagePaths) == 0 {
		return message
	}
	var builder strings.Builder
	builder.WriteString(message)
	for _, path := range imagePaths {
		builder.WriteString("\n![image](")
		builder.WriteString(path)
		builder.WriteString(")")
	}
	return builder.String()
}