172 lines
6.1 KiB
Go
172 lines
6.1 KiB
Go
package mq
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
|
|
"github.com/chaitin/panda-wiki/consts"
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/mq"
|
|
"github.com/chaitin/panda-wiki/mq/types"
|
|
"github.com/chaitin/panda-wiki/repo/pg"
|
|
"github.com/chaitin/panda-wiki/store/rag"
|
|
"github.com/chaitin/panda-wiki/usecase"
|
|
)
|
|
|
|
type RAGMQHandler struct {
|
|
consumer mq.MQConsumer
|
|
logger *log.Logger
|
|
rag rag.RAGService
|
|
nodeRepo *pg.NodeRepository
|
|
kbRepo *pg.KnowledgeBaseRepository
|
|
llmUsecase *usecase.LLMUsecase
|
|
modelUsecase *usecase.ModelUsecase
|
|
}
|
|
|
|
func NewRAGMQHandler(consumer mq.MQConsumer, logger *log.Logger, rag rag.RAGService, nodeRepo *pg.NodeRepository, kbRepo *pg.KnowledgeBaseRepository, llmUsecase *usecase.LLMUsecase, modelUsecase *usecase.ModelUsecase) (*RAGMQHandler, error) {
|
|
h := &RAGMQHandler{
|
|
consumer: consumer,
|
|
logger: logger.WithModule("mq.rag"),
|
|
rag: rag,
|
|
nodeRepo: nodeRepo,
|
|
kbRepo: kbRepo,
|
|
llmUsecase: llmUsecase,
|
|
modelUsecase: modelUsecase,
|
|
}
|
|
if err := consumer.RegisterHandler(domain.VectorTaskTopic, h.HandleNodeContentVectorRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
return h, nil
|
|
}
|
|
|
|
func (h *RAGMQHandler) HandleNodeContentVectorRequest(ctx context.Context, msg types.Message) error {
|
|
var request domain.NodeReleaseVectorRequest
|
|
err := json.Unmarshal(msg.GetData(), &request)
|
|
if err != nil {
|
|
h.logger.Error("unmarshal node content vector request failed", log.Error(err))
|
|
return nil
|
|
}
|
|
switch request.Action {
|
|
case "update_group_ids":
|
|
h.logger.Info("update node group request", log.Any("request", request), log.Any("group_id", request.GroupIds))
|
|
kb, err := h.kbRepo.GetKnowledgeBaseByID(ctx, request.KBID)
|
|
if err != nil {
|
|
h.logger.Error("get kb failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if err := h.rag.UpdateDocumentGroupIDs(ctx, kb.DatasetID, request.DocID, request.GroupIds); err != nil {
|
|
h.logger.Error("update node group failed", log.Error(err))
|
|
return nil
|
|
}
|
|
h.logger.Info("update node group success", log.Any("doc_id", request.DocID), log.Any("group_ids", request.GroupIds))
|
|
|
|
case "upsert":
|
|
h.logger.Debug("upsert node content vector request", "request", request)
|
|
nodeRelease, err := h.nodeRepo.GetNodeReleaseWithDirPathByID(ctx, request.NodeReleaseID)
|
|
if err != nil {
|
|
h.logger.Error("get node content by ids failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if nodeRelease.Type == domain.NodeTypeFolder {
|
|
h.logger.Info("node is folder, skip upsert", log.Any("node_release_id", request.NodeReleaseID))
|
|
return nil
|
|
}
|
|
kb, err := h.kbRepo.GetKnowledgeBaseByID(ctx, request.KBID)
|
|
if err != nil {
|
|
h.logger.Error("get kb failed", log.Error(err), log.String("kb_id", request.KBID))
|
|
return nil
|
|
}
|
|
|
|
groupIds, err := h.nodeRepo.GetNodeAuthGroupIdsByNodeId(ctx, nodeRelease.NodeID, consts.NodePermNameAnswerable)
|
|
if err != nil {
|
|
h.logger.Error("get groupIds failed", log.Error(err), log.String("kb_id", request.KBID))
|
|
return nil
|
|
}
|
|
|
|
// upsert node content chunks
|
|
docID, err := h.rag.UpsertRecords(ctx, &rag.UpsertRecordsRequest{
|
|
ID: nodeRelease.ID,
|
|
Title: nodeRelease.Name,
|
|
DatasetID: kb.DatasetID,
|
|
DocID: nodeRelease.DocID,
|
|
Content: nodeRelease.Content,
|
|
GroupIDs: groupIds,
|
|
})
|
|
if err != nil {
|
|
h.logger.Error("upsert node content vector failed", log.Error(err))
|
|
return nil
|
|
}
|
|
// update node doc_id
|
|
if err := h.nodeRepo.UpdateNodeReleaseDocID(ctx, request.NodeReleaseID, docID); err != nil {
|
|
h.logger.Error("update node doc_id failed", log.String("node_id", request.NodeReleaseID), log.Error(err))
|
|
return nil
|
|
}
|
|
// delete old RAG records
|
|
// get old doc_ids by node_id
|
|
oldDocIDs, err := h.nodeRepo.GetOldNodeDocIDsByNodeID(ctx, nodeRelease.ID, nodeRelease.NodeID)
|
|
if err != nil {
|
|
h.logger.Error("get old doc_ids by node_id failed", log.String("node_id", nodeRelease.NodeID), log.Error(err))
|
|
return nil
|
|
}
|
|
if len(oldDocIDs) > 0 {
|
|
// delete old RAG records
|
|
if err := h.rag.DeleteRecords(ctx, kb.DatasetID, oldDocIDs); err != nil {
|
|
h.logger.Error("delete old RAG records failed", log.String("kb_id", kb.ID), log.Error(err))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
h.logger.Info("upsert node content vector success", log.Any("updated_ids", request.NodeReleaseID))
|
|
case "delete":
|
|
h.logger.Info("delete node content vector request", log.Any("request", request))
|
|
kb, err := h.kbRepo.GetKnowledgeBaseByID(ctx, request.KBID)
|
|
if err != nil {
|
|
h.logger.Error("get kb failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if err := h.rag.DeleteRecords(ctx, kb.DatasetID, []string{request.DocID}); err != nil {
|
|
h.logger.Error("delete node content vector failed", log.Error(err))
|
|
return nil
|
|
}
|
|
h.logger.Info("delete node content vector success", log.Any("deleted_id", request.NodeReleaseID), log.Any("deleted_doc_id", request.DocID))
|
|
case "summary":
|
|
h.logger.Info("summary node content vector request", log.Any("request", request))
|
|
node, err := h.nodeRepo.GetNodeByID(ctx, request.NodeID)
|
|
if err != nil {
|
|
h.logger.Error("get node by id failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if node.Type == domain.NodeTypeFolder {
|
|
h.logger.Info("node is folder, skip summary", log.Any("node_id", request.NodeID))
|
|
return nil
|
|
}
|
|
|
|
model, err := h.modelUsecase.GetChatModel(ctx)
|
|
if err != nil {
|
|
h.logger.Error("get chat model failed", log.Error(err))
|
|
return nil
|
|
}
|
|
|
|
summary, err := h.llmUsecase.SummaryNode(ctx, request.KBID, model, node.Name, node.Content)
|
|
if err != nil {
|
|
h.logger.Error("summary node content failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if err := h.nodeRepo.UpdateNodeSummary(ctx, request.KBID, request.NodeID, summary); err != nil {
|
|
h.logger.Error("update node summary failed", log.Error(err))
|
|
return nil
|
|
}
|
|
if node.Status == domain.NodeStatusPublished {
|
|
if err := h.nodeRepo.UpdateNodeStatus(ctx, request.KBID, request.NodeID, domain.NodeStatusDraft); err != nil {
|
|
h.logger.Error("update node status failed", log.Error(err))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
h.logger.Info("summary node content vector success", log.Any("summary_id", request.NodeReleaseID), log.Any("summary", summary))
|
|
}
|
|
|
|
return nil
|
|
}
|