package pg

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"github.com/google/uuid"
	"github.com/lib/pq"
	"github.com/samber/lo"
	"github.com/samber/lo/mutable"

	v1 "github.com/chaitin/panda-wiki/api/node/v1"
	shareV1 "github.com/chaitin/panda-wiki/api/share/v1"
	"github.com/chaitin/panda-wiki/consts"
	"github.com/chaitin/panda-wiki/domain"
	"github.com/chaitin/panda-wiki/log"
	"github.com/chaitin/panda-wiki/store/pg"
)

// NodeRepository persists knowledge-base nodes, their published releases and
// their auth-group permissions in PostgreSQL.
type NodeRepository struct {
	db     *pg.DB
	logger *log.Logger
}

// NewNodeRepository builds a NodeRepository whose logger is tagged with the
// "repo.pg.node" module name.
func NewNodeRepository(db *pg.DB, logger *log.Logger) *NodeRepository {
	return &NodeRepository{db: db, logger: logger.WithModule("repo.pg.node")}
}

// Create inserts a new unreleased node into req.KBID and returns its UUIDv7 ID.
//
// Everything runs in one transaction. It returns domain.ErrMaxNodeLimitReached
// when the knowledge base already holds req.MaxNode nodes. An explicit
// req.Position must lie in [0, domain.MaxPosition]. Without one, the node is
// appended after its last sibling; if the remaining gap would be smaller than
// domain.MinPositionGap, the siblings are renumbered first.
func (r *NodeRepository) Create(ctx context.Context, req *domain.CreateNodeReq, userId string) (string, error) {
	nodeID, err := uuid.NewV7()
	if err != nil {
		return "", err
	}
	nodeIDStr := nodeID.String()

	err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// enforce the per-KB node limit
		var count int64
		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ?", req.KBID).
			Count(&count).Error; err != nil {
			return err
		}
		if count >= int64(req.MaxNode) {
			return domain.ErrMaxNodeLimitReached
		}

		var newPos float64
		if req.Position != nil {
			// caller-specified position
			if *req.Position > domain.MaxPosition || *req.Position < 0 {
				return errors.New("specified position is out of range")
			}
			newPos = *req.Position
		} else {
			// default: halfway between the last sibling and MaxPosition
			maxPos, err := r.maxSiblingPosition(tx, req.KBID, req.ParentID)
			if err != nil {
				return err
			}
			newPos = maxPos + (domain.MaxPosition-maxPos)/2.0
			if newPos-maxPos < domain.MinPositionGap {
				if err := r.reorderPositionsByParentID(tx, req.KBID, req.ParentID); err != nil {
					return err
				}
				// Siblings were renumbered. Recompute from their new maximum;
				// otherwise the node keeps the cramped slot and every later
				// append triggers another full reorder.
				if maxPos, err = r.maxSiblingPosition(tx, req.KBID, req.ParentID); err != nil {
					return err
				}
				newPos = maxPos + (domain.MaxPosition-maxPos)/2.0
			}
		}

		now := time.Now()
		meta := domain.NodeMeta{Emoji: req.Emoji}
		if req.Summary != nil {
			meta.Summary = *req.Summary
		}
		if req.ContentType != nil {
			meta.ContentType = *req.ContentType
		}
		node := &domain.Node{
			ID:        nodeIDStr,
			KBID:      req.KBID,
			NavId:     req.NavId,
			Name:      req.Name,
			Content:   req.Content,
			Meta:      meta,
			Type:      req.Type,
			ParentID:  req.ParentID,
			Position:  newPos,
			Status:    domain.NodeStatusUnreleased,
			CreatorId: userId,
			EditorId:  userId,
			CreatedAt: now,
			UpdatedAt: now,
			EditTime:  now,
			RagInfo: domain.RagInfo{
				Status:  consts.NodeRagStatusPending,
				Message: "",
			},
			Permissions: domain.NodePermissions{
				Answerable: consts.NodeAccessPermOpen,
				Visitable:  consts.NodeAccessPermOpen,
				Visible:    consts.NodeAccessPermOpen,
			},
		}
		return tx.Create(node).Error
	})
	if err != nil {
		return "", err
	}
	return nodeIDStr, nil
}

// maxSiblingPosition returns the highest position among the children of
// parentID in kbID, or 0 when there are none. An empty parentID selects root
// nodes, whose parent_id is either NULL or ''.
func (r *NodeRepository) maxSiblingPosition(tx *gorm.DB, kbID, parentID string) (float64, error) {
	query := tx.Model(&domain.Node{}).Where("kb_id = ?", kbID)
	if parentID == "" {
		query = query.Where("parent_id IS NULL OR parent_id = ''")
	} else {
		query = query.Where("parent_id = ?", parentID)
	}
	var maxPos float64
	err := query.Select("COALESCE(MAX(position::float), 0)").Scan(&maxPos).Error
	return maxPos, err
}

// GetList lists the nodes of req.KBID together with the creator's and the
// editor's account names.
//
// The list is optionally restricted to req.NavId and to nodes whose name or
// content contains req.Search (SQL LIKE; '%' and '_' in the search text act
// as wildcards).
func (r *NodeRepository) GetList(ctx context.Context, req *domain.GetNodeListReq) ([]*domain.NodeListItemResp, error) {
	var nodes []*domain.NodeListItemResp
	query := r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Joins("LEFT JOIN users cu ON nodes.creator_id = cu.id").
		Joins("LEFT JOIN users eu ON nodes.editor_id = eu.id").
		Where("nodes.kb_id = ?", req.KBID).
		Select("cu.account AS creator, eu.account AS editor, nodes.editor_id, nodes.nav_id, nodes.rag_info, nodes.creator_id, nodes.id, nodes.permissions, nodes.type, nodes.status, nodes.name, nodes.parent_id, nodes.position, nodes.created_at, nodes.edit_time as updated_at, nodes.meta->>'summary' as summary, nodes.meta->>'emoji' as emoji, nodes.meta->>'content_type' as content_type")
	if req.Search != "" {
		searchPattern := "%" + req.Search + "%"
		// Qualify the columns: the query joins `users` twice, so bare column
		// names may be ambiguous.
		query = query.Where("nodes.name LIKE ? OR nodes.content LIKE ?", searchPattern, searchPattern)
	}
	if req.NavId != "" {
		query = query.Where("nodes.nav_id = ?", req.NavId)
	}
	if err := query.Find(&nodes).Error; err != nil {
		return nil, err
	}
	return nodes, nil
}

// GetLatestNodeReleaseByNodeIDs returns, for each node in ids (within kbID),
// the most recently updated release (id, node_id, kb_id, doc_id only).
func (r *NodeRepository) GetLatestNodeReleaseByNodeIDs(ctx context.Context, kbID string, ids []string) ([]*domain.NodeRelease, error) {
	var nodeReleases []*domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("node_id IN ?", ids).
		Where("kb_id = ?", kbID).
		Select("DISTINCT ON (node_id) id, node_id, kb_id, doc_id").
		Order("node_id, updated_at DESC").
		Find(&nodeReleases).Error; err != nil {
		return nil, err
	}
	return nodeReleases, nil
}

// GetNodeReleasePublisherMap maps node_id to publisher_id for the releases in
// kbID that have a non-empty doc_id. Empty publisher IDs are skipped. If a node
// has several such releases, the last one scanned wins (row order is unspecified).
func (r *NodeRepository) GetNodeReleasePublisherMap(ctx context.Context, kbID string) (map[string]string, error) {
	type Result struct {
		NodeID      string `gorm:"column:node_id"`
		PublisherID string `gorm:"column:publisher_id"`
	}

	var results []Result
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Select("node_id, publisher_id").
		Where("kb_id = ?", kbID).
		Where("node_releases.doc_id != '' ").
		Find(&results).Error; err != nil {
		return nil, err
	}

	publisherMap := make(map[string]string, len(results))
	for _, result := range results {
		if result.PublisherID != "" {
			publisherMap[result.NodeID] = result.PublisherID
		}
	}
	return publisherMap, nil
}

// UpdateNodeContent applies the non-nil, changed fields of req to the node
// req.ID in req.KBID and records userId as the editor.
//
// The node row is locked (SELECT ... FOR UPDATE) for the whole transaction.
// meta.content_type can only be set while it is still empty. When any field
// actually changes on a node that is not unreleased, the node becomes a draft
// and edit_time is refreshed. A req.Position outside [0, domain.MaxPosition]
// is rejected.
func (r *NodeRepository) UpdateNodeContent(ctx context.Context, req *domain.UpdateNodeReq, userId string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var currentNode domain.Node
		if err := tx.Model(&domain.Node{}).
			Where("id = ?", req.ID).
			Where("kb_id = ?", req.KBID).
			// FOR UPDATE: hold the row lock until the transaction completes
			Clauses(clause.Locking{Strength: "UPDATE"}).
			First(&currentNode).Error; err != nil {
			return err
		}

		// editor_id is always written, so the UPDATE below always runs
		updateMap := map[string]any{"editor_id": userId}
		updateStatus := false

		if req.Name != nil && *req.Name != currentNode.Name {
			updateMap["name"] = *req.Name
			updateStatus = true
		}
		if req.Content != nil && *req.Content != currentNode.Content {
			updateMap["content"] = *req.Content
			updateStatus = true
		}
		if req.NavId != nil && *req.NavId != currentNode.NavId {
			updateMap["nav_id"] = *req.NavId
			updateStatus = true
		}
		if req.Position != nil && *req.Position != currentNode.Position {
			if *req.Position > domain.MaxPosition || *req.Position < 0 {
				return errors.New("specified position is out of range")
			}
			updateMap["position"] = *req.Position
			updateStatus = true
		}

		// Nest jsonb_set calls so several meta keys change in one UPDATE.
		// Each wrap adds exactly one placeholder, and metaArgs is appended in
		// the same order the placeholders appear in the final expression.
		metaExpr := "meta"
		var metaArgs []any
		if req.Emoji != nil && *req.Emoji != currentNode.Meta.Emoji {
			metaExpr = "jsonb_set(" + metaExpr + ", '{emoji}', to_jsonb(?::text))"
			metaArgs = append(metaArgs, *req.Emoji)
		}
		if req.Summary != nil && *req.Summary != currentNode.Meta.Summary {
			metaExpr = "jsonb_set(" + metaExpr + ", '{summary}', to_jsonb(?::text))"
			metaArgs = append(metaArgs, *req.Summary)
		}
		// content_type is write-once: only settable while it is still empty
		if currentNode.Meta.ContentType == "" && req.ContentType != nil && *req.ContentType != "" {
			metaExpr = "jsonb_set(" + metaExpr + ", '{content_type}', to_jsonb(?::text))"
			metaArgs = append(metaArgs, *req.ContentType)
		}
		if len(metaArgs) > 0 {
			updateMap["meta"] = gorm.Expr(metaExpr, metaArgs...)
			updateStatus = true
		}

		// a real change to an already released node turns it back into a draft
		if updateStatus && currentNode.Status != domain.NodeStatusUnreleased {
			updateMap["status"] = domain.NodeStatusDraft
			updateMap["edit_time"] = time.Now()
		}

		return tx.Model(&domain.Node{}).
			Where("id = ?", req.ID).
			Where("kb_id = ?", req.KBID).
			Updates(updateMap).Error
	})
}

// GetByID returns the node id in kbId together with its creator's and its
// editor's IDs and account names.
func (r *NodeRepository) GetByID(ctx context.Context, id, kbId string) (*v1.NodeDetailResp, error) {
	var node *v1.NodeDetailResp
	if err := r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Select("nodes.*, creator.id as creator_id, creator.account as creator_account, editor.id as editor_id, editor.account as editor_account").
		Joins("left join users creator on creator.id = nodes.creator_id").
		Joins("left join users editor on editor.id = nodes.editor_id").
		Where("nodes.id = ?", id).
		Where("nodes.kb_id = ?", kbId).
		First(&node).Error; err != nil {
		return nil, err
	}
	return node, nil
}

// Delete removes the nodes ids in kbID and all their descendants.
//
// Their releases are copied into the backup table and then deleted. It returns
// the de-duplicated, non-empty doc IDs of the deleted nodes and releases.
func (r *NodeRepository) Delete(ctx context.Context, kbID string, ids []string) ([]string, error) {
	docIDs := make([]string, 0)
	if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		allIDs, err := r.collectDescendantNodeIDs(tx, kbID, ids)
		if err != nil {
			return fmt.Errorf("collect child nodes: %w", err)
		}

		var nodes []*domain.Node
		if err := tx.Model(&domain.Node{}).
			Where("id IN ?", allIDs).
			Where("kb_id = ?", kbID).
			Clauses(clause.Returning{Columns: []clause.Column{{Name: "doc_id"}}}).
			Delete(&nodes).Error; err != nil {
			return err
		}

		// back up node releases before deleting them
		if err := r.backupNodeReleasesTx(tx, allIDs); err != nil {
			return err
		}

		var nodeReleases []*domain.NodeRelease
		if err := tx.Model(&domain.NodeRelease{}).
			Where("node_id IN ?", allIDs).
			Clauses(clause.Returning{Columns: []clause.Column{{Name: "doc_id"}}}).
			Delete(&nodeReleases).Error; err != nil {
			return err
		}

		for _, node := range nodes {
			if node.DocID != "" {
				docIDs = append(docIDs, node.DocID)
			}
		}
		for _, nodeRelease := range nodeReleases {
			if nodeRelease.DocID != "" {
				docIDs = append(docIDs, nodeRelease.DocID)
			}
		}
		return nil
	}); err != nil {
		return nil, err
	}
	return lo.Uniq(docIDs), nil
}

// backupNodeReleasesTx copies every release of nodeIDs into the backup table,
// stamping DeletedAt with the current time. Rows that already exist are left
// untouched (ON CONFLICT DO NOTHING). It is a no-op when nothing matches.
func (r *NodeRepository) backupNodeReleasesTx(tx *gorm.DB, nodeIDs []string) error {
	var nodeReleases []*domain.NodeRelease
	if err := tx.Model(&domain.NodeRelease{}).
		Where("node_id IN ?", nodeIDs).
		Find(&nodeReleases).Error; err != nil {
		return err
	}
	if len(nodeReleases) == 0 {
		return nil
	}

	now := time.Now()
	backups := make([]*domain.NodeReleaseBackup, len(nodeReleases))
	for i, nr := range nodeReleases {
		backups[i] = &domain.NodeReleaseBackup{
			ID:          nr.ID,
			KBID:        nr.KBID,
			PublisherId: nr.PublisherId,
			EditorId:    nr.EditorId,
			NodeID:      nr.NodeID,
			DocID:       nr.DocID,
			Type:        nr.Type,
			Name:        nr.Name,
			Meta:        nr.Meta,
			Content:     nr.Content,
			Position:    nr.Position,
			ParentID:    nr.ParentID,
			DeletedAt:   now,
			CreatedAt:   nr.CreatedAt,
			UpdatedAt:   nr.UpdatedAt,
		}
	}
	return tx.Clauses(clause.OnConflict{DoNothing: true}).CreateInBatches(&backups, 500).Error
}

// collectAllChildNodeIDs returns parentIDs plus the IDs of all their
// descendants in kbID, de-duplicated.
//
// It is best-effort: on a query error it returns whatever was collected so far.
// Use collectDescendantNodeIDs when the error matters.
func (r *NodeRepository) collectAllChildNodeIDs(tx *gorm.DB, kbID string, parentIDs []string) []string {
	ids, _ := r.collectDescendantNodeIDs(tx, kbID, parentIDs)
	return ids
}

// collectDescendantNodeIDs walks the tree breadth-first from parentIDs and
// returns parentIDs plus all descendant IDs in kbID, de-duplicated. On a query
// error it returns the IDs gathered so far together with the error. An ID is
// expanded at most once, so a parent_id cycle in the data cannot loop forever.
func (r *NodeRepository) collectDescendantNodeIDs(tx *gorm.DB, kbID string, parentIDs []string) ([]string, error) {
	allIDs := lo.Uniq(parentIDs)
	seen := make(map[string]struct{}, len(allIDs))
	for _, id := range allIDs {
		seen[id] = struct{}{}
	}

	frontier := allIDs
	for len(frontier) > 0 {
		var childIDs []string
		if err := tx.Model(&domain.Node{}).
			Where("parent_id IN ?", frontier).
			Where("kb_id = ?", kbID).
			Select("id").
			Find(&childIDs).Error; err != nil {
			return allIDs, err
		}

		next := make([]string, 0, len(childIDs))
		for _, id := range childIDs {
			if _, ok := seen[id]; ok {
				continue
			}
			seen[id] = struct{}{}
			next = append(next, id)
		}
		allIDs = append(allIDs, next...)
		frontier = next
	}
	return allIDs, nil
}

// GetNodeByID returns the node with the given id, in any knowledge base.
func (r *NodeRepository) GetNodeByID(ctx context.Context, id string) (*domain.Node, error) {
	var node *domain.Node
	if err := r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Where("id = ?", id).
		First(&node).Error; err != nil {
		return nil, err
	}
	return node, nil
}

// GetNodesByIDs retrieves the nodes with the given IDs, keyed by node ID.
// Missing IDs are simply absent from the map.
func (r *NodeRepository) GetNodesByIDs(ctx context.Context, ids []string) (map[string]*domain.Node, error) {
	if len(ids) == 0 {
		return make(map[string]*domain.Node), nil
	}

	var nodes []*domain.Node
	if err := r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Where("id IN ?", ids).
		Find(&nodes).Error; err != nil {
		return nil, err
	}

	nodesMap := make(map[string]*domain.Node, len(nodes))
	for _, node := range nodes {
		nodesMap[node.ID] = node
	}
	return nodesMap, nil
}

// buildNodePath builds the directory path of a node release, for example
// "/a/b/", or "/" when it has no folder ancestors.
//
// It walks up at most 5 parent levels. Each level uses the newest release (by
// created_at) of the parent node in kbID. Only folder ancestors contribute a
// path segment, and the walk stops early if a parent release is missing.
func (r *NodeRepository) buildNodePath(ctx context.Context, kbID string, nodeRelease *domain.NodeRelease) (string, error) {
	var pathParts []string
	currentParentNodeID := nodeRelease.ParentID

	for i := 0; i < 5 && currentParentNodeID != ""; i++ {
		var parentNodeRelease domain.NodeRelease
		if err := r.db.WithContext(ctx).
			Model(&domain.NodeRelease{}).
			Where("node_id = ? AND kb_id = ?", currentParentNodeID, kbID).
			Select("id, node_id, parent_id, name, type").
			Order("created_at DESC").
			First(&parentNodeRelease).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				break
			}
			return "", err
		}

		if parentNodeRelease.Type == domain.NodeTypeFolder {
			pathParts = append(pathParts, parentNodeRelease.Name)
		}
		currentParentNodeID = parentNodeRelease.ParentID
	}

	if len(pathParts) == 0 {
		return "/", nil
	}
	// collected leaf-to-root; reverse to root-to-leaf order
	mutable.Reverse(pathParts)
	return "/" + strings.Join(pathParts, "/") + "/", nil
}

// GetNodeNameByNodeIDs maps node ID to node name, querying in chunks of 1000 IDs.
func (r *NodeRepository) GetNodeNameByNodeIDs(ctx context.Context, ids []string) (map[string]string, error) {
	nodesMap := make(map[string]string, len(ids))
	for _, chunk := range lo.Chunk(ids, 1000) {
		var nodes []*domain.Node
		if err := r.db.WithContext(ctx).
			Model(&domain.Node{}).
			Where("id IN ?", chunk).
			Select("id, name").
			Find(&nodes).Error; err != nil {
			return nil, err
		}
		for _, node := range nodes {
			nodesMap[node.ID] = node.Name
		}
	}
	return nodesMap, nil
}

// GetNodeReleaseByID returns the node release with the given release ID.
func (r *NodeRepository) GetNodeReleaseByID(ctx context.Context, id string) (*domain.NodeRelease, error) {
	var nodeRelease *domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("id = ?", id).
		First(&nodeRelease).Error; err != nil {
		return nil, err
	}
	return nodeRelease, nil
}

// GetLatestNodeReleaseByNodeID returns the most recently updated release of nodeID.
func (r *NodeRepository) GetLatestNodeReleaseByNodeID(ctx context.Context, nodeID string) (*domain.NodeRelease, error) {
	var nodeRelease *domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("node_id = ?", nodeID).
		Order("updated_at DESC").
		First(&nodeRelease).Error; err != nil {
		return nil, err
	}
	return nodeRelease, nil
}

// GetLatestNodeReleaseWithPublishAccount returns the most recently updated
// release of nodeID with its publisher's account name.
//
// Unlike First, Find does not report gorm.ErrRecordNotFound when no release
// exists. Callers should not rely on an error for the "no release" case.
func (r *NodeRepository) GetLatestNodeReleaseWithPublishAccount(ctx context.Context, nodeID string) (*domain.NodeReleaseWithPublisher, error) {
	var nodeRelease *domain.NodeReleaseWithPublisher
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Select("node_releases.id, node_releases.publisher_id, users.account as publisher_account").
		Joins("left join users on users.id = node_releases.publisher_id").
		Where("node_releases.node_id = ?", nodeID).
		Order("node_releases.updated_at DESC").
		// only the first row is scanned, so don't fetch the whole history
		Limit(1).
		Find(&nodeRelease).Error; err != nil {
		return nil, err
	}
	return nodeRelease, nil
}

// GetNodeReleaseWithDirPathByID returns a node release together with its
// directory path (see buildNodePath).
//
// Folders get no path. A failure while building the path is logged and leaves
// Path empty instead of failing the call.
func (r *NodeRepository) GetNodeReleaseWithDirPathByID(ctx context.Context, id string) (*domain.NodeReleaseWithDirPath, error) {
	var nodeRelease *domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("id = ?", id).
		First(&nodeRelease).Error; err != nil {
		return nil, err
	}

	if nodeRelease != nil && nodeRelease.Type == domain.NodeTypeFolder {
		return &domain.NodeReleaseWithDirPath{
			NodeRelease: nodeRelease,
		}, nil
	}

	path, err := r.buildNodePath(ctx, nodeRelease.KBID, nodeRelease)
	if err != nil {
		r.logger.Error("failed to build node path", log.String("id", id), log.Error(err))
	}

	return &domain.NodeReleaseWithDirPath{
		NodeRelease: nodeRelease,
		Path:        path,
	}, nil
}

// GetNodeReleasesByDocIDs returns the node releases whose doc_id is in ids,
// keyed by doc_id.
func (r *NodeRepository) GetNodeReleasesByDocIDs(ctx context.Context, ids []string) (map[string]*domain.NodeRelease, error) {
	var nodeReleases []*domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("doc_id IN ?", ids).
		Find(&nodeReleases).Error; err != nil {
		return nil, err
	}

	nodesMap := make(map[string]*domain.NodeRelease, len(nodeReleases))
	for _, nodeRelease := range nodeReleases {
		nodesMap[nodeRelease.DocID] = nodeRelease
	}
	return nodesMap, nil
}

// NodeReleaseWithPath is a node release plus its root-to-node path.
type NodeReleaseWithPath struct {
	*domain.NodeRelease
	PathIDs   []string `json:"path_ids"`
	PathNames []string `json:"path_names"`
	Depth     int      `json:"depth"`
}

// GetNodeReleasesWithPathsByDocIDs returns the node releases for the given doc
// IDs, keyed by doc_id, each with its path (see getNodePathsBatch). Releases
// whose path could not be resolved are returned with empty path fields.
func (r *NodeRepository) GetNodeReleasesWithPathsByDocIDs(ctx context.Context, ids []string) (map[string]*NodeReleaseWithPath, error) {
	if len(ids) == 0 {
		return make(map[string]*NodeReleaseWithPath), nil
	}

	// 1. load the releases themselves
	var nodeReleases []*domain.NodeRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Where("doc_id IN ?", ids).
		Find(&nodeReleases).Error; err != nil {
		return nil, err
	}
	if len(nodeReleases) == 0 {
		return make(map[string]*NodeReleaseWithPath), nil
	}

	docIDs := lo.Map(nodeReleases, func(release *domain.NodeRelease, _ int) string {
		return release.DocID
	})

	// 2. resolve all paths in a single query
	paths, err := r.getNodePathsBatch(ctx, docIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to get paths: %w", err)
	}

	// 3. assemble the result
	result := make(map[string]*NodeReleaseWithPath, len(nodeReleases))
	for _, nr := range nodeReleases {
		nrWithPath := &NodeReleaseWithPath{
			NodeRelease: nr,
		}
		if path, ok := paths[nr.DocID]; ok {
			nrWithPath.PathIDs = path.PathIDs
			nrWithPath.PathNames = path.PathNames
			nrWithPath.Depth = path.Depth
		}
		result[nr.DocID] = nrWithPath
	}
	return result, nil
}

// NodePathInfo contains path information for a node.
type NodePathInfo struct {
	DocID     string
	PathIDs   []string
	PathNames []string
	Depth     int
}

// getNodePathsBatch resolves root-to-node paths for the releases with the
// given doc IDs in one recursive query, keyed by doc_id.
//
// Ancestors are followed through node_releases rows with a non-empty doc_id,
// up to depth 20. A doc only gets an entry when its chain reaches a root
// (parent_id NULL or '').
func (r *NodeRepository) getNodePathsBatch(ctx context.Context, docIDs []string) (map[string]*NodePathInfo, error) {
	type pathResult struct {
		DocID     string         `gorm:"column:doc_id"`
		PathIDs   pq.StringArray `gorm:"column:path_ids;type:text[]"`
		PathNames pq.StringArray `gorm:"column:path_names;type:text[]"`
		Depth     int            `gorm:"column:depth"`
	}

	var results []pathResult
	query := `
		WITH RECURSIVE node_paths AS (
			SELECT
				node_id,
				parent_id,
				name,
				doc_id as root_doc_id,
				ARRAY[node_id] as path_ids,
				ARRAY[name] as path_names,
				1 as depth
			FROM node_releases
			WHERE doc_id = ANY($1)

			UNION ALL

			SELECT
				n.node_id,
				n.parent_id,
				n.name,
				np.root_doc_id,
				n.node_id || np.path_ids,
				n.name || np.path_names,
				np.depth + 1
			FROM node_releases n
			INNER JOIN node_paths np ON n.node_id = np.parent_id
			WHERE np.depth < 20 AND n.doc_id != ''
		)
		SELECT
			root_doc_id as doc_id,
			path_ids,
			path_names,
			depth
		FROM node_paths
		WHERE parent_id IS NULL OR parent_id = ''
	`
	if err := r.db.WithContext(ctx).
		Raw(query, pq.Array(docIDs)).
		Scan(&results).Error; err != nil {
		return nil, err
	}

	pathMap := make(map[string]*NodePathInfo, len(results))
	for _, res := range results {
		pathMap[res.DocID] = &NodePathInfo{
			DocID:     res.DocID,
			PathIDs:   res.PathIDs,
			PathNames: res.PathNames,
			Depth:     res.Depth,
		}
	}
	return pathMap, nil
}

// GetRecommendNodeListByIDs lists the nodes in ids that belong to KB release
// releaseID of kbID.
func (r *NodeRepository) GetRecommendNodeListByIDs(ctx context.Context, kbID string, releaseID string, ids []string) ([]*domain.RecommendNodeListResp, error) {
	var nodes []*domain.RecommendNodeListResp
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
		Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
		Where("node_releases.kb_id = ?", kbID).
		Where("kb_release_node_releases.release_id = ?", releaseID).
		Where("node_releases.node_id IN ?", ids).
		Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions").
		Find(&nodes).Error; err != nil {
		return nil, err
	}
	return nodes, nil
}

// GetRecommendNodeListByNavIDs lists the nodes of KB release releaseID whose
// nav_id is in navIds, ordered by release position.
func (r *NodeRepository) GetRecommendNodeListByNavIDs(ctx context.Context, kbID string, releaseID string, navIds []string) ([]*domain.RecommendNodeListResp, error) {
	var nodes []*domain.RecommendNodeListResp
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
		Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
		Where("node_releases.kb_id = ?", kbID).
		Where("kb_release_node_releases.release_id = ?", releaseID).
		Where("nodes.nav_id IN ?", navIds).
		Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions, nodes.nav_id").
		Order("node_releases.position ASC").
		Find(&nodes).Error; err != nil {
		return nil, err
	}
	return nodes, nil
}

// GetRecommendNodeListByParentIDs returns the non-folder children of
// parentIDs in KB release releaseID, grouped by parent ID.
func (r *NodeRepository) GetRecommendNodeListByParentIDs(ctx context.Context, kbID string, releaseID string, parentIDs []string) (map[string][]*domain.RecommendNodeListResp, error) {
	var nodes []*domain.RecommendNodeListResp
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
		Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
		Where("node_releases.kb_id = ?", kbID).
		Where("kb_release_node_releases.release_id = ?", releaseID).
		Where("node_releases.parent_id IN ?", parentIDs).
		Where("node_releases.type != ?", domain.NodeTypeFolder).
		Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions").
		Find(&nodes).Error; err != nil {
		return nil, err
	}

	nodesMap := make(map[string][]*domain.RecommendNodeListResp)
	for _, node := range nodes {
		// appending to a missing key starts a fresh slice
		nodesMap[node.ParentID] = append(nodesMap[node.ParentID], node)
	}
	return nodesMap, nil
}

// GetNodeReleaseListByKBID lists the nodes in the latest KB release of kbID
// (newest by created_at), skipping nodes whose visible permission is closed.
// It returns nil, nil when the knowledge base has no release yet.
func (r *NodeRepository) GetNodeReleaseListByKBID(ctx context.Context, kbID string) ([]*domain.ShareNodeListItemResp, error) {
	var kbRelease *domain.KBRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.KBRelease{}).
		Where("kb_id = ?", kbID).
		Order("created_at DESC").
		First(&kbRelease).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}

	var nodes []*domain.ShareNodeListItemResp
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
		Joins("LEFT JOIN nodes ON nodes.id = kb_release_node_releases.node_id").
		Where("kb_release_node_releases.kb_id = ?", kbID).
		Where("kb_release_node_releases.release_id = ?", kbRelease.ID).
		Where("nodes.permissions->>'visible' != ?", consts.NodeAccessPermClosed).
		Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.parent_id, nodes.position, node_releases.meta->>'emoji' as emoji, node_releases.updated_at, nodes.permissions, nodes.meta, kb_release_node_releases.nav_id").
		Find(&nodes).Error; err != nil {
		return nil, err
	}
	return nodes, nil
}

// GetNodeReleaseDetailByKBIDAndID returns node id as published in the latest
// KB release of kbID (newest by created_at), including the live node's
// permissions and creator ID.
func (r *NodeRepository) GetNodeReleaseDetailByKBIDAndID(ctx context.Context, kbID, id string) (*shareV1.ShareNodeDetailResp, error) {
	var kbRelease *domain.KBRelease
	if err := r.db.WithContext(ctx).
		Model(&domain.KBRelease{}).
		Where("kb_id = ?", kbID).
		Order("created_at DESC").
		First(&kbRelease).Error; err != nil {
		return nil, err
	}

	var node *shareV1.ShareNodeDetailResp
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Select("node_releases.*, nodes.permissions, nodes.creator_id").
		Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
		Joins("LEFT JOIN nodes ON nodes.id = kb_release_node_releases.node_id").
		Where("kb_release_node_releases.release_id = ?", kbRelease.ID).
		Where("node_releases.node_id = ?", id).
		Where("node_releases.kb_id = ?", kbID).
		First(&node).Error; err != nil {
		return nil, err
	}
	return node, nil
}

// MoveNodeBetween moves node id in kbId under parentID, between siblings
// prevID and nextID. An empty prevID means "first" and an empty nextID means
// "last".
//
// The new position is the midpoint of the neighbours. If that gap is smaller
// than domain.MinPositionGap, the siblings are renumbered and the neighbours
// re-read before the midpoint is computed. A published node becomes a draft.
func (r *NodeRepository) MoveNodeBetween(ctx context.Context, id, parentID, prevID, nextID, kbId string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// read the node inside the transaction, scoped to the knowledge base
		var node domain.Node
		if err := tx.Model(&domain.Node{}).
			Where("id = ?", id).
			Where("kb_id = ?", kbId).
			First(&node).Error; err != nil {
			return err
		}

		prevPos, nextPos, err := r.neighborPositions(tx, kbId, parentID, prevID, nextID)
		if err != nil {
			return err
		}
		newPos := prevPos + (nextPos-prevPos)/2.0
		if newPos-prevPos < domain.MinPositionGap {
			if err := r.reorderPositionsByParentID(tx, kbId, parentID); err != nil {
				return err
			}
			// Renumbering changed the neighbours' positions. A midpoint from
			// the old values would put the node in the wrong slot.
			if prevPos, nextPos, err = r.neighborPositions(tx, kbId, parentID, prevID, nextID); err != nil {
				return err
			}
			newPos = prevPos + (nextPos-prevPos)/2.0
		}

		updates := map[string]any{
			"position":  newPos,
			"parent_id": parentID,
		}
		if node.Status == domain.NodeStatusPublished {
			updates["status"] = domain.NodeStatusDraft
		}
		return tx.Model(&domain.Node{}).
			Where("id = ?", id).
			Where("kb_id = ?", kbId).
			Updates(updates).Error
	})
}

// neighborPositions returns the positions of siblings prevID and nextID under
// parentID in kbID. An empty prevID yields 0 and an empty nextID yields
// domain.MaxPosition. A non-empty ID that is not a child of parentID yields
// gorm.ErrRecordNotFound.
func (r *NodeRepository) neighborPositions(tx *gorm.DB, kbID, parentID, prevID, nextID string) (float64, float64, error) {
	lookup := func(siblingID string) (float64, error) {
		var sibling domain.Node
		err := tx.Model(&domain.Node{}).
			Where("id = ?", siblingID).
			Where("kb_id = ?", kbID).
			Where("parent_id = ?", parentID).
			Select("position, parent_id").
			First(&sibling).Error
		return sibling.Position, err
	}

	var prevPos, nextPos float64 = 0, domain.MaxPosition
	if prevID != "" {
		pos, err := lookup(prevID)
		if err != nil {
			return 0, 0, err
		}
		prevPos = pos
	}
	if nextID != "" {
		pos, err := lookup(nextID)
		if err != nil {
			return 0, 0, err
		}
		nextPos = pos
	}
	return prevPos, nextPos, nil
}

// UpdateNodeDocID sets the doc_id of node id without touching updated_at.
func (r *NodeRepository) UpdateNodeDocID(ctx context.Context, id, docID string) error {
	return r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Omit("updated_at").
		Where("id = ?", id).
		Updates(map[string]any{
			"doc_id": docID,
		}).Error
}

// UpdateNodeReleaseDocID sets the doc_id of node release id without touching updated_at.
func (r *NodeRepository) UpdateNodeReleaseDocID(ctx context.Context, id, docID string) error {
	return r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Omit("updated_at").
		Where("id = ?", id).
		Updates(map[string]any{
			"doc_id": docID,
		}).Error
}

// UpdateNodeSummary sets meta.summary of node nodeID in kbID.
func (r *NodeRepository) UpdateNodeSummary(ctx context.Context, kbID, nodeID, summary string) error {
	return r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Where("kb_id = ? AND id = ?", kbID, nodeID).
		Updates(map[string]any{
			"meta": gorm.Expr("jsonb_set(meta, '{summary}', to_jsonb(?::text))", summary),
		}).Error
}

// UpdateNodeStatus sets the status of node nodeID in kbID.
func (r *NodeRepository) UpdateNodeStatus(ctx context.Context, kbID, nodeID string, nodeStatus domain.NodeStatus) error {
	return r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Where("kb_id = ? AND id = ?", kbID, nodeID).
		Updates(map[string]any{
			"status": nodeStatus,
		}).Error
}

// TraverseNodesByCursor streams the latest release (by updated_at) of every
// node across all knowledge bases and calls callback for each one.
// Only id, node_id and kb_id are loaded. Iteration stops at the first error
// from scanning or from callback.
func (r *NodeRepository) TraverseNodesByCursor(ctx context.Context, callback func(*domain.NodeRelease) error) error {
	rows, err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Select("DISTINCT ON (node_id) id, node_id, kb_id").
		Order("node_id, updated_at DESC").
		Rows()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var nodeRelease domain.NodeRelease
		if err := r.db.ScanRows(rows, &nodeRelease); err != nil {
			return err
		}
		if err := callback(&nodeRelease); err != nil {
			return err
		}
	}
	return rows.Err()
}

// CreateNodeReleases marks nodeIDs in kbID as published and snapshots each one
// into a new node_releases row published by userId. It returns the new release
// IDs, which are empty when no node matched.
func (r *NodeRepository) CreateNodeReleases(ctx context.Context, kbID, userId string, nodeIDs []string) ([]string, error) {
	releaseIDs := make([]string, 0, len(nodeIDs))
	if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ?", kbID).
			Where("id IN ?", nodeIDs).
			Update("status", domain.NodeStatusPublished).Error; err != nil {
			return err
		}

		// re-read the just-published nodes to snapshot them
		var updatedNodes []*domain.Node
		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ?", kbID).
			Where("id IN ?", nodeIDs).
			Find(&updatedNodes).Error; err != nil {
			return err
		}
		if len(updatedNodes) == 0 {
			return nil
		}

		now := time.Now()
		nodeReleases := make([]*domain.NodeRelease, len(updatedNodes))
		for i, updatedNode := range updatedNodes {
			nodeReleases[i] = &domain.NodeRelease{
				ID:          uuid.New().String(),
				KBID:        kbID,
				PublisherId: userId,
				EditorId:    updatedNode.EditorId,
				NodeID:      updatedNode.ID,
				Type:        updatedNode.Type,
				Name:        updatedNode.Name,
				Meta:        updatedNode.Meta,
				Content:     updatedNode.Content,
				ParentID:    updatedNode.ParentID,
				Position:    updatedNode.Position,
				CreatedAt:   updatedNode.CreatedAt,
				UpdatedAt:   now,
			}
			releaseIDs = append(releaseIDs, nodeReleases[i].ID)
		}

		return tx.CreateInBatches(&nodeReleases, 100).Error
	}); err != nil {
		return nil, err
	}
	return releaseIDs, nil
}

// GetOldNodeDocIDsByNodeID returns the non-empty doc IDs of every release of
// nodeID other than nodeReleaseID. It then clears those releases' doc_id
// (without touching updated_at) in the same transaction.
func (r *NodeRepository) GetOldNodeDocIDsByNodeID(ctx context.Context, nodeReleaseID, nodeID string) ([]string, error) {
	var docIDs []string
	if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&domain.NodeRelease{}).
			Where("node_id = ?", nodeID).
			Where("id != ?", nodeReleaseID).
			Where("doc_id != ''").
			Select("doc_id").
			Find(&docIDs).Error; err != nil {
			return err
		}

		return tx.Model(&domain.NodeRelease{}).
			Where("node_id = ?", nodeID).
			Where("id != ?", nodeReleaseID).
			Omit("updated_at").
			Update("doc_id", "").Error
	}); err != nil {
		return nil, err
	}
	return docIDs, nil
}

// MoveNodeNav moves nodeIDs and all their descendants in kbID to navigation
// navID.
//
// Moved nodes whose parent is not itself moved become roots (parent_id '').
// Published moved nodes become drafts.
func (r *NodeRepository) MoveNodeNav(ctx context.Context, kbID, navID string, nodeIDs []string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		allIDs, err := r.collectDescendantNodeIDs(tx, kbID, nodeIDs)
		if err != nil {
			return fmt.Errorf("collect child nodes: %w", err)
		}

		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ? AND id IN ?", kbID, allIDs).
			Update("nav_id", navID).Error; err != nil {
			return err
		}

		// detach from parents that stay behind in the old navigation
		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ? AND id IN ?", kbID, allIDs).
			Where("parent_id != ''").
			Where("parent_id NOT IN ?", allIDs).
			Update("parent_id", "").Error; err != nil {
			return err
		}

		return tx.Model(&domain.Node{}).
			Where("kb_id = ? AND id IN ?", kbID, allIDs).
			Where("status = ?", domain.NodeStatusPublished).
			Update("status", domain.NodeStatusDraft).Error
	})
}

// BatchMove sets parent_id of req.IDs in req.KBID to req.ParentID and turns
// published ones into drafts.
func (r *NodeRepository) BatchMove(ctx context.Context, req *domain.BatchMoveReq) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&domain.Node{}).
			Where("kb_id = ?", req.KBID).
			Where("id IN ?", req.IDs).
			Update("parent_id", req.ParentID).
			Error; err != nil {
			return err
		}

		return tx.Model(&domain.Node{}).
			Where("kb_id = ?", req.KBID).
			Where("id IN ?", req.IDs).
			Where("status = ?", domain.NodeStatusPublished).
			Update("status", domain.NodeStatusDraft).
			Error
	})
}

// reorderPositionsByParentID renumbers all children of parentID in kbID,
// keeping their current order. An empty parentID selects root nodes
// (parent_id NULL or '').
func (r *NodeRepository) reorderPositionsByParentID(tx *gorm.DB, kbID, parentID string) error {
	query := tx.Model(&domain.Node{}).Where("kb_id = ?", kbID)
	if parentID == "" {
		query = query.Where("parent_id IS NULL OR parent_id = ''")
	} else {
		query = query.Where("parent_id = ?", parentID)
	}

	var nodes []*domain.Node
	if err := query.Order("position").Find(&nodes).Error; err != nil {
		return err
	}
	return r.reorderPositions(tx, nodes)
}

// reorderPositions assigns positions 1000, 2000, 3000, ... to nodes in slice
// order. It issues one UPDATE ... FROM (VALUES ...) per 300 nodes.
func (r *NodeRepository) reorderPositions(tx *gorm.DB, nodes []*domain.Node) error {
	const (
		basePosition = int64(1000) // position of the first node
		interval     = int64(1000) // gap between consecutive nodes
		batchSize    = 300
	)

	for start := 0; start < len(nodes); start += batchSize {
		end := start + batchSize
		if end > len(nodes) {
			end = len(nodes)
		}
		batch := nodes[start:end]

		// Bind IDs and positions as parameters rather than formatting them into
		// the SQL text. The CAST is needed because untyped parameters in a
		// VALUES list resolve to text, which cannot be assigned to position.
		rowsSQL := make([]string, 0, len(batch))
		args := make([]any, 0, 2*len(batch))
		for i, node := range batch {
			rowsSQL = append(rowsSQL, "(?, CAST(? AS double precision))")
			args = append(args, node.ID, float64(basePosition+int64(start+i)*interval))
		}

		sql := "UPDATE nodes SET position = new_values.new_value FROM (VALUES " +
			strings.Join(rowsSQL, ", ") +
			") AS new_values(id, new_value) WHERE nodes.id = new_values.id"
		if err := tx.Exec(sql, args...).Error; err != nil {
			return err
		}
	}
	return nil
}

// GetNodeIDsByReleaseID returns the node IDs included in KB release releaseID.
func (r *NodeRepository) GetNodeIDsByReleaseID(ctx context.Context, releaseID string) ([]string, error) {
	var nodeIDs []string
	if err := r.db.WithContext(ctx).
		Model(&domain.KBReleaseNodeRelease{}).
		Where("release_id = ?", releaseID).
		Select("node_id").
		Find(&nodeIDs).Error; err != nil {
		return nil, err
	}
	return nodeIDs, nil
}

// UpdateNodeByKbID applies updateMap to node id in kbId.
func (r *NodeRepository) UpdateNodeByKbID(ctx context.Context, id, kbId string, updateMap map[string]interface{}) error {
	return r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Where("id = ?", id).
		Where("kb_id = ?", kbId).
		Updates(updateMap).Error
}

// UpdateNodesByKbID applies updateMap to nodes ids in kbId, 500 IDs per UPDATE
// so the IN clause stays bounded. The batches are not wrapped in a transaction.
func (r *NodeRepository) UpdateNodesByKbID(ctx context.Context, ids []string, kbId string, updateMap map[string]interface{}) error {
	const batchSize = 500 // keeps the IN clause short

	for i := 0; i < len(ids); i += batchSize {
		end := i + batchSize
		if end > len(ids) {
			end = len(ids)
		}
		batch := ids[i:end]

		if err := r.db.WithContext(ctx).
			Model(&domain.Node{}).
			Where("id in (?)", batch).
			Where("kb_id = ?", kbId).
			Updates(updateMap).Error; err != nil {
			return err
		}
	}
	return nil
}

// UpdateNodeGroupByKbIDAndNodeIds replaces the perm auth-group assignments of
// nodeIds with groupIds, in one transaction.
//
// Existing rows are deleted in batches of 1000 node IDs. A zero group ID is
// skipped. With no groupIds the call only deletes.
func (r *NodeRepository) UpdateNodeGroupByKbIDAndNodeIds(ctx context.Context, nodeIds []string, groupIds []int, perm consts.NodePermName) error {
	const batchSize = 1000 // keeps the IN clause short

	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// delete the existing permission rows in bounded batches
		for i := 0; i < len(nodeIds); i += batchSize {
			end := i + batchSize
			if end > len(nodeIds) {
				end = len(nodeIds)
			}
			batch := nodeIds[i:end]
			if err := tx.Model(&domain.NodeAuthGroup{}).
				Where("node_id in (?) AND perm = ?", batch, perm).
				Delete(&domain.NodeAuthGroup{}).Error; err != nil {
				return err
			}
		}

		if len(groupIds) == 0 {
			return nil
		}

		// cross product of nodes and (non-zero) groups, inserted in bulk
		nodeGroups := make([]domain.NodeAuthGroup, 0, len(nodeIds)*len(groupIds))
		for _, nodeID := range nodeIds {
			for _, groupID := range groupIds {
				if groupID == 0 {
					continue
				}
				nodeGroups = append(nodeGroups, domain.NodeAuthGroup{
					NodeID:      nodeID,
					AuthGroupID: groupID,
					Perm:        perm,
				})
			}
		}
		if len(nodeGroups) == 0 {
			return nil
		}
		return tx.Model(&domain.NodeAuthGroup{}).CreateInBatches(&nodeGroups, 100).Error
	})
}

// GetNodeGroupByNodeId returns the auth-group assignments of nodeId joined with
// each group's name, kb_id and auth_ids.
func (r *NodeRepository) GetNodeGroupByNodeId(ctx context.Context, nodeId string) ([]domain.NodeGroupDetail, error) {
	nodeGroup := make([]domain.NodeGroupDetail, 0)
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeAuthGroup{}).
		Select("node_auth_groups.node_id, node_auth_groups.auth_group_id, node_auth_groups.perm, auth_groups.name, auth_groups.kb_id, auth_groups.auth_ids").
		Joins("left join auth_groups on auth_groups.id = node_auth_groups.auth_group_id").
		Where("node_auth_groups.node_id = ?", nodeId).
		Scan(&nodeGroup).Error; err != nil {
		return nil, err
	}
	return nodeGroup, nil
}

// Update applies m to node id, in any knowledge base.
func (r *NodeRepository) Update(ctx context.Context, id string, m map[string]interface{}) error {
	return r.db.WithContext(ctx).Model(domain.Node{}).Where("id = ?", id).Updates(m).Error
}

// GetNodeIdByDocId returns the node_id of a release with the given doc_id, or
// an error when there is none.
func (r *NodeRepository) GetNodeIdByDocId(ctx context.Context, docId string) (string, error) {
	nodeIds := make([]string, 0, 1)
	if err := r.db.WithContext(ctx).Model(domain.NodeRelease{}).
		Where("doc_id = ?", docId).
		// only the first match is used
		Limit(1).
		Pluck("node_id", &nodeIds).Error; err != nil {
		return "", err
	}
	if len(nodeIds) < 1 {
		return "", fmt.Errorf("node not found for doc_id: %s", docId)
	}
	return nodeIds[0], nil
}

// GetNodeIdsWithoutStatusByKbId returns the release doc IDs of document nodes
// in kbId whose rag_info.status is missing or empty. Despite its name, it
// returns doc IDs, not node IDs.
func (r *NodeRepository) GetNodeIdsWithoutStatusByKbId(ctx context.Context, kbId string) ([]string, error) {
	docIds := make([]string, 0)
	if err := r.db.WithContext(ctx).
		Model(&domain.Node{}).
		Joins("left join node_releases on node_releases.node_id = nodes.id").
		Where("(nodes.rag_info ->> 'status' IS NULL OR nodes.rag_info ->> 'status' = '')").
		Where("nodes.kb_id = ? ", kbId).
		Where("nodes.type = ? ", domain.NodeTypeDocument).
		Where("node_releases.doc_id != '' ").
		Pluck("node_releases.doc_id", &docIds).Error; err != nil {
		return nil, err
	}
	return docIds, nil
}

// GetNodeIdsByDocIds batch-maps doc_id to node_id.
func (r *NodeRepository) GetNodeIdsByDocIds(ctx context.Context, docIds []string) (map[string]string, error) {
	if len(docIds) == 0 {
		return make(map[string]string), nil
	}
	type Result struct {
		DocID  string `gorm:"column:doc_id"`
		NodeID string `gorm:"column:node_id"`
	}
	results := make([]Result, 0)
	if err := r.db.WithContext(ctx).
		Model(&domain.NodeRelease{}).
		Select("doc_id, node_id").
		Where("doc_id IN (?)", docIds).
Find(&results).Error; err != nil { return nil, err } // 构建 doc_id -> node_id 的映射 docToNodeMap := make(map[string]string, len(results)) for _, result := range results { docToNodeMap[result.DocID] = result.NodeID } return docToNodeMap, nil } func (r *NodeRepository) DeleteOldNodeReleaseBackups(ctx context.Context, before time.Time) error { return r.db.WithContext(ctx). Where("deleted_at < ?", before). Delete(&domain.NodeReleaseBackup{}).Error } func (r *NodeRepository) GetNodeCount(ctx context.Context) (int, error) { var count int64 err := r.db.WithContext(ctx). Model(&domain.Node{}). Count(&count).Error if err != nil { return 0, err } return int(count), nil } func (r *NodeRepository) CountNodeByNavId(ctx context.Context, kbId, navId string) (int64, error) { var count int64 if err := r.db.WithContext(ctx). Model(&domain.Node{}). Where("kb_id = ?", kbId). Where("nav_id = ?", navId). Count(&count).Error; err != nil { return 0, err } return count, nil } func (r *NodeRepository) GetNodeIDsByNavId(ctx context.Context, kbId, navId string) ([]string, error) { var ids []string if err := r.db.WithContext(ctx). Model(&domain.Node{}). Where("kb_id = ? AND nav_id = ?", kbId, navId). Pluck("id", &ids).Error; err != nil { return nil, err } return ids, nil } func (r *NodeRepository) GetNodeListByStatus(ctx context.Context, kbId, status, search string) ([]*domain.NodeListItemResp, error) { var nodes []*domain.NodeListItemResp query := r.db.WithContext(ctx). Model(&domain.Node{}). Joins("LEFT JOIN users cu ON nodes.creator_id = cu.id"). Joins("LEFT JOIN users eu ON nodes.editor_id = eu.id"). Where("nodes.kb_id = ?", kbId). 
Select("cu.account AS creator, eu.account AS editor, nodes.editor_id, nodes.nav_id, nodes.rag_info, nodes.creator_id, nodes.id, nodes.permissions, nodes.type, nodes.status, nodes.name, nodes.parent_id, nodes.position, nodes.created_at, nodes.edit_time as updated_at, nodes.meta->>'summary' as summary, nodes.meta->>'emoji' as emoji, nodes.meta->>'content_type' as content_type") if search != "" { searchPattern := "%" + search + "%" query = query.Where("name LIKE ? OR content LIKE ?", searchPattern, searchPattern) } switch status { // 发布后允许可配置的 case "released": query = query.Where("nodes.status IN ?", []domain.NodeStatus{domain.NodeStatusDraft, domain.NodeStatusPublished}) case "unpublished": query = query.Where("nodes.status IN ?", []domain.NodeStatus{domain.NodeStatusUnreleased, domain.NodeStatusDraft}) case "unstudied": query = query.Where("nodes.type = ?", domain.NodeTypeDocument). Where("nodes.rag_info->>'status' NOT IN ? OR nodes.rag_info->>'status' IS NULL", []string{string(consts.NodeRagStatusSucceeded), string(consts.NodeRagStatusRunning), string(consts.NodeRagStatusReindexing)}) } if err := query.Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } func (r *NodeRepository) GetNodeStats(ctx context.Context, kbId string) (*v1.NodeStatsResp, error) { var stats v1.NodeStatsResp // Count unpublished documents (status = 0 or 1) unpublishedQuery := r.db.WithContext(ctx). Model(&domain.Node{}). Where("kb_id = ? AND status IN ?", kbId, []domain.NodeStatus{domain.NodeStatusUnreleased, domain.NodeStatusDraft}) if err := unpublishedQuery.Count(&stats.UnpublishedCount).Error; err != nil { return nil, err } studiedStatuses := []consts.NodeRagInfoStatus{ consts.NodeRagStatusSucceeded, consts.NodeRagStatusRunning, consts.NodeRagStatusReindexing, } unstudiedQuery := r.db.WithContext(ctx). Model(&domain.Node{}). Where("kb_id = ?", kbId). Where("nodes.type = ?", domain.NodeTypeDocument). Where("rag_info->>'status' NOT IN ? 
OR rag_info->>'status' IS NULL", studiedStatuses) if err := unstudiedQuery.Count(&stats.UnstudiedCount).Error; err != nil { return nil, err } return &stats, nil }