Files
YouduWiki/backend/usecase/node.go
2026-05-21 19:52:45 +08:00

876 lines
24 KiB
Go

package usecase
import (
"context"
"errors"
"fmt"
"slices"
"github.com/gomarkdown/markdown"
"github.com/gomarkdown/markdown/html"
"github.com/gomarkdown/markdown/parser"
"github.com/microcosm-cc/bluemonday"
"github.com/samber/lo"
"gorm.io/gorm"
navV1 "github.com/chaitin/panda-wiki/api/nav/v1"
v1 "github.com/chaitin/panda-wiki/api/node/v1"
shareV1 "github.com/chaitin/panda-wiki/api/share/v1"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/mq"
"github.com/chaitin/panda-wiki/repo/pg"
"github.com/chaitin/panda-wiki/store/rag"
"github.com/chaitin/panda-wiki/store/s3"
"github.com/chaitin/panda-wiki/utils"
)
// NodeUsecase holds the node (document) business logic. It coordinates the
// node/nav/app/knowledge-base/model/user/auth repositories, the RAG message
// queue and the RAG service.
type NodeUsecase struct {
	nodeRepo     *pg.NodeRepository
	navRepo      *pg.NavRepository
	appRepo      *pg.AppRepository
	ragRepo      *mq.RAGRepository // used to queue async node-release vector requests (delete/summary/upsert/update_group_ids)
	kbRepo       *pg.KnowledgeBaseRepository
	modelRepo    *pg.ModelRepository
	userRepo     *pg.UserRepository
	authRepo     *pg.AuthRepo
	llmUsecase   *LLMUsecase // streams node summaries
	logger       *log.Logger
	s3Client     *s3.MinioClient
	rAGService   rag.RAGService // queried directly for document status in SyncRagNodeStatus
	modelUsecase *ModelUsecase  // resolves the configured chat model
}
// NewNodeUsecase builds a NodeUsecase from its dependencies. The logger is
// scoped to the "usecase.node" module; every other argument is stored as-is.
func NewNodeUsecase(
	nodeRepo *pg.NodeRepository,
	navRepo *pg.NavRepository,
	appRepo *pg.AppRepository,
	ragRepo *mq.RAGRepository,
	userRepo *pg.UserRepository,
	kbRepo *pg.KnowledgeBaseRepository,
	llmUsecase *LLMUsecase,
	ragService rag.RAGService,
	logger *log.Logger,
	s3Client *s3.MinioClient,
	modelRepo *pg.ModelRepository,
	authRepo *pg.AuthRepo,
	modelUsecase *ModelUsecase,
) *NodeUsecase {
	uc := new(NodeUsecase)
	uc.logger = logger.WithModule("usecase.node")

	// Persistence layer.
	uc.nodeRepo, uc.navRepo, uc.appRepo = nodeRepo, navRepo, appRepo
	uc.kbRepo, uc.userRepo, uc.authRepo, uc.modelRepo = kbRepo, userRepo, authRepo, modelRepo

	// RAG queue / service and model-related collaborators.
	uc.ragRepo, uc.rAGService = ragRepo, ragService
	uc.llmUsecase, uc.modelUsecase = llmUsecase, modelUsecase

	uc.s3Client = s3Client
	return uc
}
// ragSyncChunkSize is the maximum number of document IDs that SyncRagNodeStatus
// passes in one ListDocuments call to the RAG service and in the matching
// doc-to-node lookup.
const ragSyncChunkSize = 100
// Create delegates node creation for userId to the node repository and
// returns the new node's ID. On failure the returned ID is always "".
func (u *NodeUsecase) Create(ctx context.Context, req *domain.CreateNodeReq, userId string) (string, error) {
	id, createErr := u.nodeRepo.Create(ctx, req, userId)
	if createErr == nil {
		return id, nil
	}
	return "", createErr
}
// GetList returns the node list selected by req. Each node found in the
// knowledge base's release publisher map gets its PublisherId filled in.
func (u *NodeUsecase) GetList(ctx context.Context, req *domain.GetNodeListReq) ([]*domain.NodeListItemResp, error) {
	list, err := u.nodeRepo.GetList(ctx, req)
	switch {
	case err != nil:
		return nil, err
	case len(list) == 0:
		// Nothing to annotate, so skip the publisher lookup.
		return list, nil
	}

	publishers, err := u.nodeRepo.GetNodeReleasePublisherMap(ctx, req.KBID)
	if err != nil {
		return nil, err
	}
	for i := range list {
		if pid, found := publishers[list[i].ID]; found {
			list[i].PublisherId = pid
		}
	}
	return list, nil
}
// GetNodeByKBID loads node id of knowledge base kbId for the admin side. It
// adds the latest release's publisher (when a release exists) and the node's
// PV count.
//
// Content of non-Markdown nodes that does not already look like HTML is
// converted from Markdown to sanitized HTML, unless format is "raw".
func (u *NodeUsecase) GetNodeByKBID(ctx context.Context, id, kbId, format string) (*v1.NodeDetailResp, error) {
	node, err := u.nodeRepo.GetByID(ctx, id, kbId)
	if err != nil {
		return nil, err
	}

	release, err := u.nodeRepo.GetLatestNodeReleaseWithPublishAccount(ctx, node.ID)
	if err != nil {
		return nil, err
	}
	if release != nil {
		node.PublisherId, node.PublisherAccount = release.PublisherId, release.PublisherAccount
	}

	stats, err := u.nodeRepo.GetNodeStatsByNodeId(ctx, node.ID)
	if err != nil {
		return nil, err
	}
	node.PV = stats.PV

	// Markdown nodes, raw requests and content that already looks like HTML
	// are returned untouched.
	needsHTML := node.Meta.ContentType != domain.ContentTypeMD &&
		format != "raw" &&
		!utils.IsLikelyHTML(node.Content)
	if needsHTML {
		node.Content = u.convertMDToHTML(node.Content)
	}
	return node, nil
}
// NodeAction applies a bulk action to the nodes req.IDs of knowledge base
// req.KBID.
//
// Only "delete" is handled. The nodes are deleted through the repository, and
// one "delete" vector request is queued for each document ID the repository
// returns. Any other action is a no-op that returns nil.
func (u *NodeUsecase) NodeAction(ctx context.Context, req *domain.NodeActionReq) error {
	switch req.Action {
	case "delete":
		docIDs, err := u.nodeRepo.Delete(ctx, req.KBID, req.IDs)
		if err != nil {
			return err
		}
		// No document IDs came back, so there is nothing to queue. Don't
		// publish an empty batch (NodePermissionsEdit applies the same guard).
		if len(docIDs) == 0 {
			return nil
		}
		requests := make([]*domain.NodeReleaseVectorRequest, 0, len(docIDs))
		for _, docID := range docIDs {
			requests = append(requests, &domain.NodeReleaseVectorRequest{
				KBID:   req.KBID,
				DocID:  docID,
				Action: "delete",
			})
		}
		if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, requests); err != nil {
			return err
		}
	}
	return nil
}
// Update validates the optional target nav and then delegates the node update
// for userId to the repository.
//
// If req.NavId refers to a nav that does not exist, Update returns the
// "invalid nav_id" error. Any other lookup failure, such as a database error,
// is returned wrapped rather than being reported as bad input.
//
// NOTE(review): unlike MoveNodeNav, this does not check that the nav belongs
// to the node's knowledge base. Confirm whether UpdateNodeContent enforces it.
func (u *NodeUsecase) Update(ctx context.Context, req *domain.UpdateNodeReq, userId string) error {
	if req.NavId != nil {
		if _, err := u.navRepo.GetById(ctx, *req.NavId); err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return errors.New("invalid nav_id")
			}
			return fmt.Errorf("get nav %v: %w", *req.NavId, err)
		}
	}
	return u.nodeRepo.UpdateNodeContent(ctx, req, userId)
}
// ValidateNodePerm checks whether the visitor authId may visit the released
// node nodeId of knowledge base kbID. It returns nil when access is allowed,
// and otherwise:
//   - ErrCodeNotFound when the node release cannot be loaded;
//   - ErrCodePermissionDenied when visiting is closed, or is partial and none
//     of the visitor's auth groups (including parent groups) grant the
//     visitable permission on this node;
//   - ErrCodeInternalError on group lookup failures or an unknown permission
//     value.
func (u *NodeUsecase) ValidateNodePerm(ctx context.Context, kbID, nodeId string, authId uint) *domain.PWResponseErrCode {
	node, err := u.nodeRepo.GetNodeReleaseDetailByKBIDAndID(ctx, kbID, nodeId)
	if err != nil {
		return &domain.ErrCodeNotFound
	}

	switch node.Permissions.Visitable {
	case consts.NodeAccessPermOpen:
		return nil
	case consts.NodeAccessPermClosed:
		return &domain.ErrCodePermissionDenied
	case consts.NodeAccessPermPartial:
		// Reuse the shared auth-group -> node-ID resolution instead of
		// duplicating it here.
		allowedNodeIds, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisitable)
		if err != nil {
			return &domain.ErrCodeInternalError
		}
		if !slices.Contains(allowedNodeIds, nodeId) {
			u.logger.Error("ValidateNodePerm failed", log.Any("node_group_ids", allowedNodeIds), log.Any("node_id", nodeId))
			return &domain.ErrCodePermissionDenied
		}
		return nil
	default:
		return &domain.ErrCodeInternalError
	}
}
// GetNodeReleaseDetailByKBIDAndID loads the released node nodeId of kbID for
// the share side. It fills in the creator, editor and publisher account names
// when those user IDs are known.
//
// The page-view count is added only when the edition allows node stats and the
// knowledge base's web app has PV enabled. Content of non-Markdown nodes that
// does not already look like HTML is converted from Markdown to sanitized HTML,
// unless format is "raw".
func (u *NodeUsecase) GetNodeReleaseDetailByKBIDAndID(ctx context.Context, kbID, nodeId, format string) (*shareV1.ShareNodeDetailResp, error) {
	node, err := u.nodeRepo.GetNodeReleaseDetailByKBIDAndID(ctx, kbID, nodeId)
	if err != nil {
		return nil, err
	}

	accounts, err := u.userRepo.GetUsersAccountMap(ctx)
	if err != nil {
		return nil, err
	}
	// Only overwrite an account field when its ID resolves to a known user.
	if name, known := accounts[node.CreatorId]; known {
		node.CreatorAccount = name
	}
	if name, known := accounts[node.EditorId]; known {
		node.EditorAccount = name
	}
	if name, known := accounts[node.PublisherId]; known {
		node.PublisherAccount = name
	}

	statsAllowed := domain.GetBaseEditionLimitation(ctx).AllowNodeStats
	if statsAllowed {
		app, err := u.appRepo.GetOrCreateAppByKBIDAndType(ctx, kbID, domain.AppTypeWeb)
		if err != nil {
			return nil, err
		}
		if app.Settings.StatsSetting.PVEnable {
			stats, err := u.nodeRepo.GetNodeStatsByNodeId(ctx, nodeId)
			if err != nil {
				return nil, err
			}
			node.PV = stats.PV
		}
	}

	if node.Meta.ContentType != domain.ContentTypeMD && format != "raw" && !utils.IsLikelyHTML(node.Content) {
		node.Content = u.convertMDToHTML(node.Content)
	}
	return node, nil
}
// MoveNode repositions node req.ID under req.ParentID, between the siblings
// req.PrevID and req.NextID, in knowledge base req.KbID. The repository does
// all the work.
func (u *NodeUsecase) MoveNode(ctx context.Context, req *domain.MoveNodeReq) error {
	id, parent, prev, next := req.ID, req.ParentID, req.PrevID, req.NextID
	return u.nodeRepo.MoveNodeBetween(ctx, id, parent, prev, next, req.KbID)
}
// SummaryNode queues an asynchronous "summary" request for every node in
// req.IDs. It fails fast with domain.ErrModelNotConfigured when no chat model
// is configured.
func (u *NodeUsecase) SummaryNode(ctx context.Context, req *domain.NodeSummaryReq) error {
	if _, err := u.modelUsecase.GetChatModel(ctx); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return domain.ErrModelNotConfigured
		}
		return err
	}

	jobs := make([]*domain.NodeReleaseVectorRequest, len(req.IDs))
	for i, nodeID := range req.IDs {
		jobs[i] = &domain.NodeReleaseVectorRequest{
			KBID:   req.KBID,
			NodeID: nodeID,
			Action: "summary",
		}
	}
	return u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, jobs)
}
// StreamSummaryNode summarizes exactly one node (req.IDs[0]) with the
// configured chat model and streams the output through onChunk.
//
// It returns domain.ErrModelNotConfigured when no chat model is set up, and an
// error when req.IDs does not hold exactly one ID. Failures from the node
// lookup or the LLM call are returned wrapped.
func (u *NodeUsecase) StreamSummaryNode(ctx context.Context, req *domain.NodeSummaryReq, onChunk func(ctx context.Context, dataType, chunk string) error) error {
	model, err := u.modelUsecase.GetChatModel(ctx)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return domain.ErrModelNotConfigured
		}
		return err
	}
	if len(req.IDs) != 1 {
		return errors.New("stream summary only supports single node")
	}

	// NOTE(review): GetNodeByID is not scoped to req.KBID. Confirm the caller
	// has already verified that the node belongs to that knowledge base.
	nodeID := req.IDs[0]
	node, err := u.nodeRepo.GetNodeByID(ctx, nodeID)
	if err != nil {
		// This reads the working node, not a release, so say so in the error.
		return fmt.Errorf("get node %q failed: %w", nodeID, err)
	}
	if err := u.llmUsecase.StreamSummaryNode(ctx, req.KBID, model, node.Name, node.Content, onChunk); err != nil {
		return fmt.Errorf("summary node failed: %w", err)
	}
	return nil
}
// GetRecommendNodeList returns recommended nodes from the latest release of
// knowledge base req.KBID.
//
// Nodes are selected by req.NavIds when any are given, and otherwise by
// req.NodeIDs. In the NodeIDs case the result follows the order of
// req.NodeIDs. Each folder node gets its released children attached as
// RecommendNodes. It returns (nil, nil) when the knowledge base has no release
// or when the repository finds no matching nodes.
func (u *NodeUsecase) GetRecommendNodeList(ctx context.Context, req *domain.GetRecommendNodeListReq) ([]*domain.RecommendNodeListResp, error) {
	release, err := u.kbRepo.GetLatestRelease(ctx, req.KBID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	case err != nil:
		return nil, err
	}

	// NavIds take precedence; NodeIDs are only used when no NavIds are given.
	byNav := len(req.NavIds) > 0
	var nodes []*domain.RecommendNodeListResp
	switch {
	case byNav:
		nodes, err = u.nodeRepo.GetRecommendNodeListByNavIDs(ctx, req.KBID, release.ID, req.NavIds)
	case len(req.NodeIDs) > 0:
		nodes, err = u.nodeRepo.GetRecommendNodeListByIDs(ctx, req.KBID, release.ID, req.NodeIDs)
	}
	if err != nil {
		return nil, err
	}
	if len(nodes) == 0 {
		return nil, nil
	}

	// When selected by NodeIDs, reorder to match the caller's ID order.
	if len(req.NodeIDs) > 0 && !byNav {
		index := make(map[string]*domain.RecommendNodeListResp, len(nodes))
		for _, n := range nodes {
			index[n.ID] = n
		}
		ordered := make([]*domain.RecommendNodeListResp, 0, len(req.NodeIDs))
		for _, id := range req.NodeIDs {
			if n, ok := index[id]; ok {
				ordered = append(ordered, n)
			}
		}
		nodes = ordered
	}

	// Attach the children of every folder node.
	var folderIDs []string
	for _, n := range nodes {
		if n.Type == domain.NodeTypeFolder {
			folderIDs = append(folderIDs, n.ID)
		}
	}
	if len(folderIDs) == 0 {
		return nodes, nil
	}
	children, err := u.nodeRepo.GetRecommendNodeListByParentIDs(ctx, req.KBID, release.ID, folderIDs)
	if err != nil {
		return nil, err
	}
	for _, n := range nodes {
		if sub, ok := children[n.ID]; ok {
			n.RecommendNodes = sub
		}
	}
	return nodes, nil
}
// BatchMoveNode moves several nodes at once as described by req. All of the
// work is delegated to the node repository.
func (u *NodeUsecase) BatchMoveNode(ctx context.Context, req *domain.BatchMoveReq) error {
	repo := u.nodeRepo
	return repo.BatchMove(ctx, req)
}
// MoveNodeNav reassigns the nodes req.IDs to nav req.NavID. The nav must exist
// and must belong to knowledge base req.KbID.
func (u *NodeUsecase) MoveNodeNav(ctx context.Context, req *v1.NodeMoveNavReq) error {
	target, err := u.navRepo.GetById(ctx, req.NavID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return fmt.Errorf("nav not found: %w", err)
	case err != nil:
		return err
	case target.KbID != req.KbID:
		// Refuse to move nodes into a nav owned by another knowledge base.
		return fmt.Errorf("nav does not belong to kb %s", req.KbID)
	}
	return u.nodeRepo.MoveNodeNav(ctx, req.KbID, req.NavID, req.IDs)
}
func (u *NodeUsecase) convertMDToHTML(mdStr string) string {
extensions := parser.CommonExtensions & ^parser.Autolink & ^parser.MathJax
p := parser.NewWithExtensions(extensions)
doc := p.Parse([]byte(mdStr))
// create HTML renderer with extensions
htmlFlags := html.CommonFlags | html.HrefTargetBlank
opts := html.RendererOptions{Flags: htmlFlags}
renderer := html.NewRenderer(opts)
maybeUnsafeHTML := markdown.Render(doc, renderer)
html := bluemonday.UGCPolicy().SanitizeBytes(maybeUnsafeHTML)
return string(html)
}
// GetShareNodeList groups the released nodes of kbId under the released navs,
// in the order GetReleaseList returns them, for the share side.
//
// Nodes with open visibility are always listed. Partially visible nodes are
// listed only when one of authId's auth groups grants the visible permission.
// Every other node is dropped, as is any node whose nav is not in the
// released nav list.
func (u *NodeUsecase) GetShareNodeList(ctx context.Context, kbId string, authId uint) ([]*shareV1.NodeListGroupNavResp, error) {
	released, err := u.nodeRepo.GetNodeReleaseListByKBID(ctx, kbId)
	if err != nil {
		return nil, err
	}
	granted, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisible)
	if err != nil {
		return nil, err
	}
	navs, err := u.navRepo.GetReleaseList(ctx, kbId)
	if err != nil {
		return nil, err
	}

	// One (possibly empty) group per released nav, also indexed by nav ID.
	groups := make([]*shareV1.NodeListGroupNavResp, len(navs))
	groupByNav := make(map[string]*shareV1.NodeListGroupNavResp, len(navs))
	for i, nav := range navs {
		g := &shareV1.NodeListGroupNavResp{
			NavID:    nav.ID,
			NavName:  nav.Name,
			Position: nav.Position,
			List:     []domain.ShareNodeListItemResp{},
		}
		groups[i] = g
		groupByNav[nav.ID] = g
	}

	grantedSet := make(map[string]struct{}, len(granted))
	for _, id := range granted {
		grantedSet[id] = struct{}{}
	}
	canSee := func(n *domain.ShareNodeListItemResp) bool {
		switch n.Permissions.Visible {
		case consts.NodeAccessPermOpen:
			return true
		case consts.NodeAccessPermPartial:
			_, ok := grantedSet[n.ID]
			return ok
		}
		return false
	}

	for _, node := range released {
		g, ok := groupByNav[node.NavId]
		if !ok || !canSee(node) {
			continue
		}
		g.List = append(g.List, *node)
		g.Count++
	}
	return groups, nil
}
// GetNodeReleaseListByParentID returns the tree of released nodes below
// parentID in kbID that the visitor authId may see.
//
// Open nodes are always included. Partially visible nodes are included only
// when one of authId's auth groups grants the visible permission. Any other
// permission value hides the node. Descendants of a hidden folder are never
// reached, because buildNodeTree only descends through included nodes; the
// exception is when that folder is parentID itself.
func (u *NodeUsecase) GetNodeReleaseListByParentID(ctx context.Context, kbID, parentID string, authId uint) ([]*domain.ShareNodeDetailItem, error) {
	// Load every released node of the knowledge base in one query.
	allNodes, err := u.nodeRepo.GetNodeReleaseListByKBID(ctx, kbID)
	if err != nil {
		return nil, err
	}
	grantedIds, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisible)
	if err != nil {
		return nil, err
	}
	// A set keeps the permission filter O(n) instead of O(n*m) with
	// slices.Contains. GetShareNodeList does the same.
	granted := lo.SliceToMap(grantedIds, func(id string) (string, struct{}) {
		return id, struct{}{}
	})

	// Filter by permission and build the parent -> children index in one
	// pass, keeping the repository's node order within each parent.
	childrenMap := make(map[string][]*domain.ShareNodeListItemResp)
	for _, node := range allNodes {
		switch node.Permissions.Visible {
		case consts.NodeAccessPermOpen:
		case consts.NodeAccessPermPartial:
			if _, ok := granted[node.ID]; !ok {
				continue
			}
		default:
			continue
		}
		childrenMap[node.ParentID] = append(childrenMap[node.ParentID], node)
	}

	return u.buildNodeTree(parentID, childrenMap), nil
}
// buildNodeTree recursively turns the flat parent -> children index into nested
// ShareNodeDetailItem values, starting from the children of parentID. It
// descends only into folders. Every returned item, leaf or folder, has a
// non-nil (possibly empty) Children slice, and sibling order follows
// childrenMap.
//
// NOTE(review): there is no cycle guard. A folder whose parent chain loops back
// to parentID would make this recurse without bound. Confirm that the data
// model rules this out.
func (u *NodeUsecase) buildNodeTree(parentID string, childrenMap map[string][]*domain.ShareNodeListItemResp) []*domain.ShareNodeDetailItem {
	siblings := childrenMap[parentID]
	out := make([]*domain.ShareNodeDetailItem, len(siblings))
	for i, src := range siblings {
		var kids []*domain.ShareNodeDetailItem
		if src.Type == domain.NodeTypeFolder {
			kids = u.buildNodeTree(src.ID, childrenMap)
		} else {
			kids = make([]*domain.ShareNodeDetailItem, 0)
		}
		out[i] = &domain.ShareNodeDetailItem{
			ID:        src.ID,
			Name:      src.Name,
			Type:      src.Type,
			ParentID:  src.ParentID,
			Position:  src.Position,
			Meta:      src.Meta,
			Emoji:     src.Emoji,
			UpdatedAt: src.UpdatedAt,
			Children:  kids,
		}
	}
	return out
}
// GetNodeIdsByAuthId returns the IDs of nodes that grant PermName to any auth
// group of authId. The groups include their parent groups, as resolved by
// GetAuthGroupWithParentsByAuthId. When authId has no groups, it returns an
// empty, non-nil slice.
func (u *NodeUsecase) GetNodeIdsByAuthId(ctx context.Context, authId uint, PermName consts.NodePermName) ([]string, error) {
	groups, err := u.authRepo.GetAuthGroupWithParentsByAuthId(ctx, authId)
	if err != nil {
		return nil, err
	}
	if len(groups) == 0 {
		return make([]string, 0), nil
	}

	groupIDs := make([]uint, 0, len(groups))
	for _, g := range groups {
		groupIDs = append(groupIDs, g.ID)
	}
	links, err := u.nodeRepo.GetNodeGroupsByGroupIdsPerm(ctx, groupIDs, PermName)
	if err != nil {
		return nil, err
	}

	nodeIDs := make([]string, 0, len(links))
	for _, l := range links {
		nodeIDs = append(nodeIDs, l.NodeID)
	}
	return nodeIDs, nil
}
// GetNodePermissionsByID returns the permission settings of node id in kbID.
// The node's auth groups are split into answerable, visitable and visible
// lists; each list is non-nil even when empty.
func (u *NodeUsecase) GetNodePermissionsByID(ctx context.Context, id, kbID string) (*v1.NodePermissionResp, error) {
	node, err := u.nodeRepo.GetByID(ctx, id, kbID)
	if err != nil {
		return nil, err
	}
	groups, err := u.nodeRepo.GetNodeGroupByNodeId(ctx, node.ID)
	if err != nil {
		return nil, err
	}

	answerable := make([]domain.NodeGroupDetail, 0)
	visitable := make([]domain.NodeGroupDetail, 0)
	visible := make([]domain.NodeGroupDetail, 0)
	for _, g := range groups {
		switch g.Perm {
		case consts.NodePermNameAnswerable:
			answerable = append(answerable, g)
		case consts.NodePermNameVisitable:
			visitable = append(visitable, g)
		case consts.NodePermNameVisible:
			visible = append(visible, g)
		}
	}

	return &v1.NodePermissionResp{
		ID:               node.ID,
		Permissions:      node.Permissions,
		AnswerableGroups: answerable,
		VisitableGroups:  visitable,
		VisibleGroups:    visible,
	}, nil
}
// ValidateNodePermissionsEdit rejects permission edits that the license
// edition does not allow. Business and enterprise editions may make any edit.
// Other editions get domain.ErrPermissionDenied if they set any permission to
// "partial" or submit any group list.
//
// req.Permissions may be nil, for example in a groups-only edit, so it is
// nil-checked before its fields are read.
func (u *NodeUsecase) ValidateNodePermissionsEdit(req v1.NodePermissionEditReq, edition consts.LicenseEdition) error {
	if slices.Contains([]consts.LicenseEdition{consts.LicenseEditionBusiness, consts.LicenseEditionEnterprise}, edition) {
		return nil
	}
	if p := req.Permissions; p != nil &&
		(p.Answerable == consts.NodeAccessPermPartial ||
			p.Visitable == consts.NodeAccessPermPartial ||
			p.Visible == consts.NodeAccessPermPartial) {
		return domain.ErrPermissionDenied
	}
	if req.AnswerableGroups != nil || req.VisitableGroups != nil || req.VisibleGroups != nil {
		return domain.ErrPermissionDenied
	}
	return nil
}
// NodePermissionsEdit applies a permission edit to the nodes req.IDs of
// knowledge base req.KbId.
//
// When req.Permissions is set, it is stored on the nodes. The answerable group
// IDs are then queued as "update_group_ids" requests for the latest released
// documents of those nodes: nil for open, *req.AnswerableGroups for partial,
// and an empty list for closed. Each non-nil group list in req replaces the
// node groups for its permission.
//
// A nil req.Permissions and a partial permission without AnswerableGroups are
// both handled instead of being dereferenced.
//
// NOTE(review): when only AnswerableGroups changes (req.Permissions is nil), no
// RAG group update is queued. Confirm whether the current answerable
// permission should be read so the RAG groups stay in sync.
func (u *NodeUsecase) NodePermissionsEdit(ctx context.Context, req v1.NodePermissionEditReq) error {
	if req.Permissions != nil {
		updateMap := map[string]interface{}{
			"permissions": req.Permissions,
		}
		if err := u.nodeRepo.UpdateNodesByKbID(ctx, req.IDs, req.KbId, updateMap); err != nil {
			return err
		}

		var groupIds []int
		syncGroups := true
		switch req.Permissions.Answerable {
		case consts.NodeAccessPermOpen:
			groupIds = nil
		case consts.NodeAccessPermPartial:
			if req.AnswerableGroups == nil {
				// The effective groups are unknown here, and sending nil would
				// mean "open" (see the case above). Leave the documents as they are.
				syncGroups = false
				u.logger.Warn("answerable is partial but answerable_groups is missing, skip rag group sync",
					log.String("kb_id", req.KbId),
					log.Any("node_ids", req.IDs))
			} else {
				groupIds = *req.AnswerableGroups
			}
		case consts.NodeAccessPermClosed:
			groupIds = make([]int, 0)
		}

		if syncGroups {
			nodeReleases, err := u.nodeRepo.GetLatestNodeReleaseByNodeIDs(ctx, req.KbId, req.IDs)
			if err != nil {
				return fmt.Errorf("get latest node release failed: %w", err)
			}
			requests := make([]*domain.NodeReleaseVectorRequest, 0, len(nodeReleases))
			for _, nodeRelease := range nodeReleases {
				// Skip releases without a document ID; there is nothing to update.
				if nodeRelease.DocID == "" {
					continue
				}
				requests = append(requests, &domain.NodeReleaseVectorRequest{
					KBID:     req.KbId,
					DocID:    nodeRelease.DocID,
					Action:   "update_group_ids",
					GroupIds: groupIds,
				})
			}
			if len(requests) != 0 {
				if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, requests); err != nil {
					return err
				}
			}
		}
	}

	if req.AnswerableGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.AnswerableGroups, consts.NodePermNameAnswerable); err != nil {
			return err
		}
	}
	if req.VisibleGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.VisibleGroups, consts.NodePermNameVisible); err != nil {
			return err
		}
	}
	if req.VisitableGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.VisitableGroups, consts.NodePermNameVisitable); err != nil {
			return err
		}
	}
	return nil
}
// SyncRagNodeStatus pulls document status from the RAG service for the
// documents returned by GetNodeIdsWithoutStatusByKbId in every knowledge base,
// and stores it as the owning nodes' rag_info. Documents are processed in
// chunks of ragSyncChunkSize.
//
// A failure for one knowledge base or chunk is logged and skipped so the rest
// still sync. The sync stops early and returns ctx.Err() once ctx is done.
// Without that check, a cancelled context would make every remaining call
// fail and log.
func (u *NodeUsecase) SyncRagNodeStatus(ctx context.Context) error {
	kbs, err := u.kbRepo.GetKnowledgeBaseList(ctx)
	if err != nil {
		return err
	}
	for _, kb := range kbs {
		if err := ctx.Err(); err != nil {
			return err
		}
		docIds, err := u.nodeRepo.GetNodeIdsWithoutStatusByKbId(ctx, kb.ID)
		if err != nil {
			u.logger.Error("get node ids without status failed",
				log.String("kb_id", kb.ID),
				log.Error(err))
			continue
		}
		if len(docIds) == 0 {
			continue
		}
		for _, chunk := range lo.Chunk(docIds, ragSyncChunkSize) {
			if err := ctx.Err(); err != nil {
				return err
			}
			u.syncRagNodeStatusChunk(ctx, kb.ID, kb.DatasetID, chunk)
		}
	}
	return nil
}

// syncRagNodeStatusChunk syncs one chunk of document IDs for knowledge base
// kbID. It fetches the documents' status from dataset datasetID, maps the
// documents back to nodes, and writes rag_info with one update per distinct
// (status, message) pair. Errors are logged, not returned.
func (u *NodeUsecase) syncRagNodeStatusChunk(ctx context.Context, kbID, datasetID string, docIDs []string) {
	docs, err := u.rAGService.ListDocuments(ctx, datasetID, docIDs)
	if err != nil {
		u.logger.Error("list documents from RAG failed",
			log.String("kb_id", kbID),
			log.String("dataset_id", datasetID),
			log.Error(err))
		return
	}
	if len(docs) == 0 {
		return
	}
	docToNodeMap, err := u.nodeRepo.GetNodeIdsByDocIds(ctx, docIDs)
	if err != nil {
		u.logger.Error("get node ids by doc ids failed",
			log.String("kb_id", kbID),
			log.Error(err))
		return
	}

	// Group nodes by identical status and message so each group is written
	// with a single batch update.
	type statusInfo struct {
		status  string
		message string
	}
	statusGroups := make(map[statusInfo][]string)
	for _, doc := range docs {
		nodeID, exists := docToNodeMap[doc.ID]
		if !exists {
			u.logger.Warn("doc_id not found in node_releases",
				log.String("doc_id", doc.ID))
			continue
		}
		key := statusInfo{status: doc.Status, message: doc.ProgressMsg}
		statusGroups[key] = append(statusGroups[key], nodeID)
	}

	for info, nodeIDs := range statusGroups {
		updateMap := map[string]interface{}{
			"rag_info": domain.RagInfo{
				Status:  consts.NodeRagInfoStatus(info.status),
				Message: info.message,
			},
		}
		if err := u.nodeRepo.UpdateNodesByKbID(ctx, nodeIDs, kbID, updateMap); err != nil {
			u.logger.Error("batch update node rag status failed",
				log.String("kb_id", kbID),
				log.Int("node_count", len(nodeIDs)),
				log.String("status", info.status),
				log.Error(err))
			continue
		}
		u.logger.Debug("batch updated node rag status",
			log.String("kb_id", kbID),
			log.Int("node_count", len(nodeIDs)),
			log.String("status", info.status))
	}
}
// NodeRestudy queues an "upsert" vector request for the latest release of each
// node in req.NodeIds. Releases without a DocID are skipped. It fails when none
// of the nodes has been released yet. A queue failure for one release is
// logged and does not fail the call.
func (u *NodeUsecase) NodeRestudy(ctx context.Context, req *v1.NodeRestudyReq) error {
	releases, err := u.nodeRepo.GetLatestNodeReleaseByNodeIDs(ctx, req.KbId, req.NodeIds)
	if err != nil {
		u.logger.Error("get latest node release failed", log.Error(err))
		return fmt.Errorf("get latest node release failed")
	}
	if len(releases) == 0 {
		// User-facing message: "the document has not been published yet and cannot be re-learned".
		return fmt.Errorf("文档未首次发布,无法重新学习")
	}

	for _, rel := range releases {
		if rel.DocID == "" {
			continue
		}
		job := &domain.NodeReleaseVectorRequest{
			KBID:          rel.KBID,
			NodeReleaseID: rel.ID,
			Action:        "upsert",
		}
		if qErr := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, []*domain.NodeReleaseVectorRequest{job}); qErr != nil {
			u.logger.Error("async update node release vector failed",
				log.String("node_release_id", rel.ID),
				log.Error(qErr))
		}
	}
	return nil
}
// GetNodeStats returns the repository's node statistics for kbId. It also
// increments UnreleasedNavCount for every nav that has no released version, or
// whose released name or position differs from the current one.
func (u *NodeUsecase) GetNodeStats(ctx context.Context, kbId string) (*v1.NodeStatsResp, error) {
	resp, err := u.nodeRepo.GetNodeStats(ctx, kbId)
	if err != nil {
		return nil, err
	}
	navs, err := u.navRepo.GetList(ctx, kbId)
	if err != nil {
		return nil, err
	}
	navsReleased, err := u.navRepo.GetReleaseList(ctx, kbId)
	if err != nil {
		return nil, err
	}

	// Point into the slice instead of at the range variable. With pre-Go 1.22
	// loop semantics, &loopVar would make every entry alias the last element.
	navsReleasedMap := make(map[string]*navV1.NavListResp, len(navsReleased))
	for i := range navsReleased {
		navsReleasedMap[navsReleased[i].ID] = &navsReleased[i]
	}

	for _, nav := range navs {
		released, found := navsReleasedMap[nav.ID]
		if !found || released.Position != nav.Position || released.Name != nav.Name {
			resp.UnreleasedNavCount++
		}
	}
	return resp, nil
}
// GetNodeListGroupByNav returns the nodes that GetNodeListByStatus selects for
// req.Status and req.Search. They are grouped under the navs req.NavIds, in the
// order GetListByIds returns those navs.
//
// Each group reports whether its nav matches its released version (same name
// and position). Nodes whose nav is not among the requested navs are omitted.
// When req.Search is non-empty, groups with no matching nodes are dropped.
func (u *NodeUsecase) GetNodeListGroupByNav(ctx context.Context, req v1.NodeListGroupNavReq) ([]*v1.NodeListGroupNavResp, error) {
	nodes, err := u.nodeRepo.GetNodeListByStatus(ctx, req.KbId, req.Status, req.Search)
	if err != nil {
		return nil, err
	}
	navs, err := u.navRepo.GetListByIds(ctx, req.KbId, req.NavIds)
	if err != nil {
		return nil, err
	}
	navsReleased, err := u.navRepo.GetReleaseList(ctx, req.KbId)
	if err != nil {
		return nil, err
	}

	// Point into the slice instead of at the range variable. With pre-Go 1.22
	// loop semantics, &loopVar would make every entry alias the last element.
	navsReleasedMap := make(map[string]*navV1.NavListResp, len(navsReleased))
	for i := range navsReleased {
		navsReleasedMap[navsReleased[i].ID] = &navsReleased[i]
	}

	// Pre-build the groups in nav order, with a map from nav ID to group index
	// for O(1) lookup.
	result := make([]*v1.NodeListGroupNavResp, 0, len(navs))
	navIndexMap := make(map[string]int, len(navs))
	for _, nav := range navs {
		release, found := navsReleasedMap[nav.ID]
		navIndexMap[nav.ID] = len(result)
		result = append(result, &v1.NodeListGroupNavResp{
			NavID:      nav.ID,
			NavName:    nav.Name,
			Position:   nav.Position,
			IsReleased: found && release.Position == nav.Position && release.Name == nav.Name,
			List:       []domain.NodeListItemResp{},
		})
	}

	for _, node := range nodes {
		if idx, ok := navIndexMap[node.NavId]; ok {
			result[idx].List = append(result[idx].List, *node)
			result[idx].Count++
		}
	}

	// When searching, drop the groups that have no matching nodes.
	if req.Search != "" {
		filtered := make([]*v1.NodeListGroupNavResp, 0, len(result))
		for _, group := range result {
			if group.Count > 0 {
				filtered = append(filtered, group)
			}
		}
		return filtered, nil
	}
	return result, nil
}