1412 lines
43 KiB
Go
1412 lines
43 KiB
Go
package pg
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/lib/pq"
|
||
"github.com/samber/lo"
|
||
"github.com/samber/lo/mutable"
|
||
|
||
v1 "github.com/chaitin/panda-wiki/api/node/v1"
|
||
shareV1 "github.com/chaitin/panda-wiki/api/share/v1"
|
||
"github.com/chaitin/panda-wiki/consts"
|
||
"github.com/chaitin/panda-wiki/domain"
|
||
"github.com/chaitin/panda-wiki/log"
|
||
"github.com/chaitin/panda-wiki/store/pg"
|
||
)
|
||
|
||
type NodeRepository struct {
|
||
db *pg.DB
|
||
logger *log.Logger
|
||
}
|
||
|
||
func NewNodeRepository(db *pg.DB, logger *log.Logger) *NodeRepository {
|
||
return &NodeRepository{db: db, logger: logger.WithModule("repo.pg.node")}
|
||
}
|
||
|
||
func (r *NodeRepository) Create(ctx context.Context, req *domain.CreateNodeReq, userId string) (string, error) {
|
||
nodeID, err := uuid.NewV7()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
nodeIDStr := nodeID.String()
|
||
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// check count
|
||
var count int64
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ?", req.KBID).
|
||
Count(&count).Error; err != nil {
|
||
return err
|
||
}
|
||
if count >= int64(req.MaxNode) {
|
||
return domain.ErrMaxNodeLimitReached
|
||
}
|
||
var maxPos float64
|
||
query := tx.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ?", req.KBID)
|
||
|
||
if req.ParentID == "" {
|
||
query = query.Where("parent_id IS NULL OR parent_id = ''")
|
||
} else {
|
||
query = query.Where("parent_id = ?", req.ParentID)
|
||
}
|
||
|
||
if err := query.
|
||
Select("COALESCE(MAX(position::float), 0)").
|
||
Scan(&maxPos).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
var newPos float64
|
||
if req.Position != nil { // user specify position
|
||
if *req.Position > domain.MaxPosition || *req.Position < 0 {
|
||
return errors.New("specified position is out of range")
|
||
}
|
||
newPos = *req.Position
|
||
} else { // default the last
|
||
newPos = maxPos + (domain.MaxPosition-maxPos)/2.0
|
||
if newPos-maxPos < domain.MinPositionGap {
|
||
if err := r.reorderPositionsByParentID(tx, req.KBID, req.ParentID); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
now := time.Now()
|
||
meta := domain.NodeMeta{Emoji: req.Emoji}
|
||
if req.Summary != nil {
|
||
meta.Summary = *req.Summary
|
||
}
|
||
if req.ContentType != nil {
|
||
meta.ContentType = *req.ContentType
|
||
}
|
||
|
||
node := &domain.Node{
|
||
ID: nodeIDStr,
|
||
KBID: req.KBID,
|
||
NavId: req.NavId,
|
||
Name: req.Name,
|
||
Content: req.Content,
|
||
Meta: meta,
|
||
Type: req.Type,
|
||
ParentID: req.ParentID,
|
||
Position: newPos,
|
||
Status: domain.NodeStatusUnreleased,
|
||
CreatorId: userId,
|
||
EditorId: userId,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
EditTime: now,
|
||
RagInfo: domain.RagInfo{
|
||
Status: consts.NodeRagStatusPending,
|
||
Message: "",
|
||
},
|
||
Permissions: domain.NodePermissions{
|
||
Answerable: consts.NodeAccessPermOpen,
|
||
Visitable: consts.NodeAccessPermOpen,
|
||
Visible: consts.NodeAccessPermOpen,
|
||
},
|
||
}
|
||
|
||
return tx.Create(node).Error
|
||
})
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return nodeIDStr, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetList(ctx context.Context, req *domain.GetNodeListReq) ([]*domain.NodeListItemResp, error) {
|
||
var nodes []*domain.NodeListItemResp
|
||
query := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Joins("LEFT JOIN users cu ON nodes.creator_id = cu.id").
|
||
Joins("LEFT JOIN users eu ON nodes.editor_id = eu.id").
|
||
Where("nodes.kb_id = ?", req.KBID).
|
||
Select("cu.account AS creator, eu.account AS editor, nodes.editor_id, nodes.nav_id, nodes.rag_info, nodes.creator_id, nodes.id, nodes.permissions, nodes.type, nodes.status, nodes.name, nodes.parent_id, nodes.position, nodes.created_at, nodes.edit_time as updated_at, nodes.meta->>'summary' as summary, nodes.meta->>'emoji' as emoji, nodes.meta->>'content_type' as content_type")
|
||
if req.Search != "" {
|
||
searchPattern := "%" + req.Search + "%"
|
||
query = query.Where("name LIKE ? OR content LIKE ?", searchPattern, searchPattern)
|
||
}
|
||
if req.NavId != "" {
|
||
query = query.Where("nodes.nav_id = ?", req.NavId)
|
||
}
|
||
if err := query.Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodes, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetLatestNodeReleaseByNodeIDs(ctx context.Context, kbID string, ids []string) ([]*domain.NodeRelease, error) {
|
||
var nodeReleases []*domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("node_id IN ?", ids).
|
||
Where("kb_id = ?", kbID).
|
||
Select("DISTINCT ON (node_id) id, node_id, kb_id, doc_id").
|
||
Order("node_id, updated_at DESC").
|
||
Find(&nodeReleases).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeReleases, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeReleasePublisherMap(ctx context.Context, kbID string) (map[string]string, error) {
|
||
type Result struct {
|
||
NodeID string `gorm:"column:node_id"`
|
||
PublisherID string `gorm:"column:publisher_id"`
|
||
}
|
||
|
||
var results []Result
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Select("node_id, publisher_id").
|
||
Where("kb_id = ?", kbID).
|
||
Where("node_releases.doc_id != '' ").
|
||
Find(&results).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
publisherMap := make(map[string]string)
|
||
for _, result := range results {
|
||
if result.PublisherID != "" {
|
||
publisherMap[result.NodeID] = result.PublisherID
|
||
}
|
||
}
|
||
|
||
return publisherMap, nil
|
||
}
|
||
|
||
func (r *NodeRepository) UpdateNodeContent(ctx context.Context, req *domain.UpdateNodeReq, userId string) error {
|
||
// Use transaction to ensure data consistency
|
||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// Get current node data with row-level lock
|
||
var currentNode domain.Node
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("id = ?", req.ID).
|
||
Where("kb_id = ?", req.KBID).
|
||
// Use FOR UPDATE to lock the row until the transaction is complete
|
||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||
First(¤tNode).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
updateMap := make(map[string]any)
|
||
updateStatus := false
|
||
|
||
updateMap["editor_id"] = userId
|
||
|
||
// Compare and update Name
|
||
if req.Name != nil && *req.Name != currentNode.Name {
|
||
updateMap["name"] = *req.Name
|
||
updateStatus = true
|
||
}
|
||
|
||
// Compare and update Content
|
||
if req.Content != nil && *req.Content != currentNode.Content {
|
||
updateMap["content"] = *req.Content
|
||
updateStatus = true
|
||
}
|
||
|
||
if req.NavId != nil && *req.NavId != currentNode.NavId {
|
||
updateMap["nav_id"] = *req.NavId
|
||
updateStatus = true
|
||
}
|
||
|
||
if req.Position != nil && *req.Position != currentNode.Position { // user specify position
|
||
updateMap["position"] = *req.Position
|
||
if *req.Position > domain.MaxPosition || *req.Position < 0 {
|
||
return errors.New("specified position is out of range")
|
||
}
|
||
updateStatus = true
|
||
}
|
||
|
||
// Handle multiple meta field updates
|
||
if req.Emoji != nil || req.Summary != nil || req.ContentType != nil {
|
||
metaExpr := "meta"
|
||
var args []any
|
||
metaUpdated := false
|
||
|
||
// Compare and update Emoji
|
||
if req.Emoji != nil && *req.Emoji != currentNode.Meta.Emoji {
|
||
// First jsonb_set: jsonb_set(meta, '{emoji}', to_jsonb(?::text))
|
||
metaExpr = "jsonb_set(" + metaExpr + ", '{emoji}', to_jsonb(?::text))"
|
||
args = append(args, *req.Emoji) // First parameter for emoji
|
||
metaUpdated = true
|
||
}
|
||
|
||
// Compare and update Summary
|
||
if req.Summary != nil && *req.Summary != currentNode.Meta.Summary {
|
||
// Second jsonb_set: jsonb_set(previous_expr, '{summary}', to_jsonb(?::text))
|
||
metaExpr = "jsonb_set(" + metaExpr + ", '{summary}', to_jsonb(?::text))"
|
||
args = append(args, *req.Summary) // Second parameter for summary
|
||
metaUpdated = true
|
||
}
|
||
|
||
// Compare and update ContentType
|
||
if currentNode.Meta.ContentType == "" { // can only modify content_type if it was empty before
|
||
if req.ContentType != nil && *req.ContentType != currentNode.Meta.ContentType {
|
||
// Second jsonb_set: jsonb_set(previous_expr, '{content_type}', to_jsonb(?::text))
|
||
metaExpr = "jsonb_set(" + metaExpr + ", '{content_type}', to_jsonb(?::text))"
|
||
args = append(args, *req.ContentType) // Second parameter for content_type
|
||
metaUpdated = true
|
||
}
|
||
}
|
||
|
||
if metaUpdated {
|
||
updateMap["meta"] = gorm.Expr(metaExpr, args...)
|
||
updateStatus = true
|
||
}
|
||
}
|
||
|
||
// If any field is updated and node released, set status to draft
|
||
if updateStatus && currentNode.Status != domain.NodeStatusUnreleased {
|
||
updateMap["status"] = domain.NodeStatusDraft
|
||
updateMap["edit_time"] = time.Now()
|
||
}
|
||
|
||
// Perform update if there are changes
|
||
if len(updateMap) > 0 {
|
||
// Use the transaction's DB instance for the update
|
||
return tx.Model(&domain.Node{}).
|
||
Where("id = ?", req.ID).
|
||
Where("kb_id = ?", req.KBID).
|
||
Updates(updateMap).Error
|
||
}
|
||
return nil
|
||
})
|
||
|
||
// Return any error from the transaction
|
||
return err
|
||
}
|
||
|
||
func (r *NodeRepository) GetByID(ctx context.Context, id, kbId string) (*v1.NodeDetailResp, error) {
|
||
var node *v1.NodeDetailResp
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Select("nodes.*, creator.id as creator_id, creator.account as creator_account, editor.id as editor_id, editor.account as editor_account").
|
||
Joins("left join users creator on creator.id = nodes.creator_id").
|
||
Joins("left join users editor on editor.id = nodes.editor_id").
|
||
Where("nodes.id = ?", id).
|
||
Where("nodes.kb_id = ?", kbId).
|
||
First(&node).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return node, nil
|
||
}
|
||
|
||
func (r *NodeRepository) Delete(ctx context.Context, kbID string, ids []string) ([]string, error) {
|
||
docIDs := make([]string, 0)
|
||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// recursively collect all child node IDs
|
||
allIDs := r.collectAllChildNodeIDs(tx, kbID, ids)
|
||
|
||
var nodes []*domain.Node
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("id IN ?", allIDs).
|
||
Where("kb_id = ?", kbID).
|
||
Clauses(clause.Returning{Columns: []clause.Column{{Name: "doc_id"}}}).
|
||
Delete(&nodes).Error; err != nil {
|
||
return err
|
||
}
|
||
// backup node releases before deletion
|
||
if err := r.backupNodeReleasesTx(tx, allIDs); err != nil {
|
||
return err
|
||
}
|
||
|
||
// delete node release
|
||
var nodeReleases []*domain.NodeRelease
|
||
if err := tx.Model(&domain.NodeRelease{}).
|
||
Where("node_id IN ?", allIDs).
|
||
Clauses(clause.Returning{Columns: []clause.Column{{Name: "doc_id"}}}).
|
||
Delete(&nodeReleases).Error; err != nil {
|
||
return err
|
||
}
|
||
for _, node := range nodes {
|
||
if node.DocID != "" {
|
||
docIDs = append(docIDs, node.DocID)
|
||
}
|
||
}
|
||
for _, nodeRelease := range nodeReleases {
|
||
if nodeRelease.DocID != "" {
|
||
docIDs = append(docIDs, nodeRelease.DocID)
|
||
}
|
||
}
|
||
return nil
|
||
}); err != nil {
|
||
return nil, err
|
||
}
|
||
return lo.Uniq(docIDs), nil
|
||
}
|
||
|
||
func (r *NodeRepository) backupNodeReleasesTx(tx *gorm.DB, nodeIDs []string) error {
|
||
var nodeReleases []*domain.NodeRelease
|
||
if err := tx.Model(&domain.NodeRelease{}).
|
||
Where("node_id IN ?", nodeIDs).
|
||
Find(&nodeReleases).Error; err != nil {
|
||
return err
|
||
}
|
||
if len(nodeReleases) == 0 {
|
||
return nil
|
||
}
|
||
now := time.Now()
|
||
backups := make([]*domain.NodeReleaseBackup, len(nodeReleases))
|
||
for i, nr := range nodeReleases {
|
||
backups[i] = &domain.NodeReleaseBackup{
|
||
ID: nr.ID,
|
||
KBID: nr.KBID,
|
||
PublisherId: nr.PublisherId,
|
||
EditorId: nr.EditorId,
|
||
NodeID: nr.NodeID,
|
||
DocID: nr.DocID,
|
||
Type: nr.Type,
|
||
Name: nr.Name,
|
||
Meta: nr.Meta,
|
||
Content: nr.Content,
|
||
Position: nr.Position,
|
||
ParentID: nr.ParentID,
|
||
DeletedAt: now,
|
||
CreatedAt: nr.CreatedAt,
|
||
UpdatedAt: nr.UpdatedAt,
|
||
}
|
||
}
|
||
return tx.Clauses(clause.OnConflict{DoNothing: true}).CreateInBatches(&backups, 500).Error
|
||
}
|
||
|
||
// collectAllChildNodeIDs recursively collects all child node IDs for the given parent IDs
|
||
func (r *NodeRepository) collectAllChildNodeIDs(tx *gorm.DB, kbID string, parentIDs []string) []string {
|
||
allIDs := make([]string, 0)
|
||
allIDs = append(allIDs, parentIDs...)
|
||
|
||
currentParentIDs := parentIDs
|
||
for len(currentParentIDs) > 0 {
|
||
var childIDs []string
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("parent_id IN ?", currentParentIDs).
|
||
Where("kb_id = ?", kbID).
|
||
Select("id").
|
||
Find(&childIDs).Error; err != nil {
|
||
break
|
||
}
|
||
|
||
if len(childIDs) == 0 {
|
||
break
|
||
}
|
||
|
||
allIDs = append(allIDs, childIDs...)
|
||
currentParentIDs = childIDs
|
||
}
|
||
|
||
return lo.Uniq(allIDs)
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeByID(ctx context.Context, id string) (*domain.Node, error) {
|
||
var node *domain.Node
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("id = ?", id).
|
||
First(&node).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return node, nil
|
||
}
|
||
|
||
// GetNodesByIDs retrieves nodes by their IDs
|
||
func (r *NodeRepository) GetNodesByIDs(ctx context.Context, ids []string) (map[string]*domain.Node, error) {
|
||
if len(ids) == 0 {
|
||
return make(map[string]*domain.Node), nil
|
||
}
|
||
var nodes []*domain.Node
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("id IN ?", ids).
|
||
Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
nodesMap := make(map[string]*domain.Node, len(nodes))
|
||
for _, node := range nodes {
|
||
nodesMap[node.ID] = node
|
||
}
|
||
return nodesMap, nil
|
||
}
|
||
|
||
// buildNodePath builds the directory path for a node release by traversing up the parent hierarchy (max 5 levels)
|
||
func (r *NodeRepository) buildNodePath(ctx context.Context, kbID string, nodeRelease *domain.NodeRelease) (string, error) {
|
||
// Build path by traversing up max 5 levels
|
||
var pathParts []string
|
||
currentParentNodeID := nodeRelease.ParentID
|
||
|
||
// Traverse up the parent hierarchy, max 5 levels
|
||
for i := 0; i < 5 && currentParentNodeID != ""; i++ {
|
||
// Get the parent node release (ordered by created time to get the latest)
|
||
var parentNodeRelease domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("node_id = ? AND kb_id = ?", currentParentNodeID, kbID).
|
||
Select("id, node_id, parent_id, name, type").
|
||
Order("created_at DESC").
|
||
First(&parentNodeRelease).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
break
|
||
}
|
||
return "", err
|
||
}
|
||
|
||
// Prepend current node name to path if it's a folder
|
||
if parentNodeRelease.Type == domain.NodeTypeFolder {
|
||
pathParts = append(pathParts, parentNodeRelease.Name)
|
||
}
|
||
|
||
// Move to parent's parent
|
||
currentParentNodeID = parentNodeRelease.ParentID
|
||
}
|
||
|
||
// Build the final path
|
||
if len(pathParts) == 0 {
|
||
return "/", nil
|
||
}
|
||
|
||
mutable.Reverse(pathParts)
|
||
path := "/" + strings.Join(pathParts, "/") + "/"
|
||
return path, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeNameByNodeIDs(ctx context.Context, ids []string) (map[string]string, error) {
|
||
nodesMap := make(map[string]string)
|
||
for _, chunk := range lo.Chunk(ids, 1000) {
|
||
var nodes []*domain.Node
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("id IN ?", chunk).
|
||
Select("id, name").
|
||
Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
for _, node := range nodes {
|
||
nodesMap[node.ID] = node.Name
|
||
}
|
||
}
|
||
return nodesMap, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeReleaseByID(ctx context.Context, id string) (*domain.NodeRelease, error) {
|
||
var nodeRelease *domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("id = ?", id).
|
||
First(&nodeRelease).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeRelease, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetLatestNodeReleaseByNodeID(ctx context.Context, nodeID string) (*domain.NodeRelease, error) {
|
||
var nodeRelease *domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("node_id = ?", nodeID).
|
||
Order("updated_at DESC").
|
||
First(&nodeRelease).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeRelease, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetLatestNodeReleaseWithPublishAccount(ctx context.Context, nodeID string) (*domain.NodeReleaseWithPublisher, error) {
|
||
var nodeRelease *domain.NodeReleaseWithPublisher
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Select("node_releases.id, node_releases.publisher_id, users.account as publisher_account").
|
||
Joins("left join users on users.id = node_releases.publisher_id").
|
||
Where("node_releases.node_id = ?", nodeID).
|
||
Order("node_releases.updated_at DESC").
|
||
Find(&nodeRelease).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeRelease, nil
|
||
}
|
||
|
||
// GetNodeReleaseWithDirPathByID gets a node release by ID and includes its directory path
|
||
func (r *NodeRepository) GetNodeReleaseWithDirPathByID(ctx context.Context, id string) (*domain.NodeReleaseWithDirPath, error) {
|
||
// First get the node release
|
||
var nodeRelease *domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("id = ?", id).
|
||
First(&nodeRelease).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
// don't build path for folders
|
||
if nodeRelease != nil && nodeRelease.Type == domain.NodeTypeFolder {
|
||
return &domain.NodeReleaseWithDirPath{
|
||
NodeRelease: nodeRelease,
|
||
}, nil
|
||
}
|
||
|
||
// Build the directory path
|
||
path, err := r.buildNodePath(ctx, nodeRelease.KBID, nodeRelease)
|
||
if err != nil {
|
||
r.logger.Error("failed to build node path", log.String("id", id), log.Error(err))
|
||
}
|
||
|
||
// Return the extended struct with path information
|
||
return &domain.NodeReleaseWithDirPath{
|
||
NodeRelease: nodeRelease,
|
||
Path: path,
|
||
}, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeReleasesByDocIDs(ctx context.Context, ids []string) (map[string]*domain.NodeRelease, error) {
|
||
var nodeReleases []*domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("doc_id IN ?", ids).
|
||
Find(&nodeReleases).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
nodesMap := make(map[string]*domain.NodeRelease)
|
||
for _, nodeRelease := range nodeReleases {
|
||
nodesMap[nodeRelease.DocID] = nodeRelease
|
||
}
|
||
return nodesMap, nil
|
||
}
|
||
|
||
// NodeReleaseWithPath represents a node release with path information
|
||
type NodeReleaseWithPath struct {
|
||
*domain.NodeRelease
|
||
PathIDs []string `json:"path_ids"`
|
||
PathNames []string `json:"path_names"`
|
||
Depth int `json:"depth"`
|
||
}
|
||
|
||
// GetNodeReleasesWithPathsByDocIDs retrieving node releases with path information
|
||
func (r *NodeRepository) GetNodeReleasesWithPathsByDocIDs(ctx context.Context, ids []string) (map[string]*NodeReleaseWithPath, error) {
|
||
if len(ids) == 0 {
|
||
return make(map[string]*NodeReleaseWithPath), nil
|
||
}
|
||
|
||
// 1. 查询节点基本信息
|
||
var nodeReleases []*domain.NodeRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Where("doc_id IN ?", ids).
|
||
Find(&nodeReleases).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if len(nodeReleases) == 0 {
|
||
return make(map[string]*NodeReleaseWithPath), nil
|
||
}
|
||
|
||
docIDs := lo.Map(nodeReleases, func(release *domain.NodeRelease, i int) string {
|
||
return release.DocID
|
||
})
|
||
|
||
// 2. 批量查询路径
|
||
paths, err := r.getNodePathsBatch(ctx, docIDs)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get paths: %w", err)
|
||
}
|
||
|
||
// 3. 组装结果
|
||
result := make(map[string]*NodeReleaseWithPath, len(nodeReleases))
|
||
for _, nr := range nodeReleases {
|
||
nrWithPath := &NodeReleaseWithPath{
|
||
NodeRelease: nr,
|
||
}
|
||
|
||
if path, ok := paths[nr.DocID]; ok {
|
||
nrWithPath.PathIDs = path.PathIDs
|
||
nrWithPath.PathNames = path.PathNames
|
||
nrWithPath.Depth = path.Depth
|
||
}
|
||
|
||
result[nr.DocID] = nrWithPath
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// NodePathInfo contains path information for a node
|
||
type NodePathInfo struct {
|
||
DocID string
|
||
PathIDs []string
|
||
PathNames []string
|
||
Depth int
|
||
}
|
||
|
||
// getNodePathsBatch batch query node paths
|
||
func (r *NodeRepository) getNodePathsBatch(ctx context.Context, docIDs []string) (map[string]*NodePathInfo, error) {
|
||
type pathResult struct {
|
||
DocID string `gorm:"column:doc_id"`
|
||
PathIDs pq.StringArray `gorm:"column:path_ids;type:text[]"`
|
||
PathNames pq.StringArray `gorm:"column:path_names;type:text[]"`
|
||
Depth int `gorm:"column:depth"`
|
||
}
|
||
|
||
var results []pathResult
|
||
|
||
query := `
|
||
WITH RECURSIVE node_paths AS (
|
||
SELECT
|
||
node_id,
|
||
parent_id,
|
||
name,
|
||
doc_id as root_doc_id,
|
||
ARRAY[node_id] as path_ids,
|
||
ARRAY[name] as path_names,
|
||
1 as depth
|
||
FROM node_releases
|
||
WHERE doc_id = ANY($1)
|
||
|
||
UNION ALL
|
||
|
||
SELECT
|
||
n.node_id,
|
||
n.parent_id,
|
||
n.name,
|
||
np.root_doc_id,
|
||
n.node_id || np.path_ids,
|
||
n.name || np.path_names,
|
||
np.depth + 1
|
||
FROM node_releases n
|
||
INNER JOIN node_paths np ON n.node_id = np.parent_id
|
||
WHERE np.depth < 20 AND n.doc_id != ''
|
||
)
|
||
SELECT
|
||
root_doc_id as doc_id,
|
||
path_ids,
|
||
path_names,
|
||
depth
|
||
FROM node_paths
|
||
WHERE parent_id IS NULL OR parent_id = ''
|
||
`
|
||
|
||
if err := r.db.WithContext(ctx).
|
||
Raw(query, pq.Array(docIDs)).
|
||
Scan(&results).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 转换为map
|
||
pathMap := make(map[string]*NodePathInfo, len(results))
|
||
for _, res := range results {
|
||
pathMap[res.DocID] = &NodePathInfo{
|
||
DocID: res.DocID,
|
||
PathIDs: res.PathIDs,
|
||
PathNames: res.PathNames,
|
||
Depth: res.Depth,
|
||
}
|
||
}
|
||
|
||
return pathMap, nil
|
||
}
|
||
|
||
// GetRecommendNodeListByIDs get node list by ids
|
||
func (r *NodeRepository) GetRecommendNodeListByIDs(ctx context.Context, kbID string, releaseID string, ids []string) ([]*domain.RecommendNodeListResp, error) {
|
||
var nodes []*domain.RecommendNodeListResp
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
|
||
Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
|
||
Where("node_releases.kb_id = ?", kbID).
|
||
Where("kb_release_node_releases.release_id = ?", releaseID).
|
||
Where("node_releases.node_id IN ?", ids).
|
||
Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions").
|
||
Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodes, nil
|
||
}
|
||
|
||
// GetRecommendNodeListByNavIDs get node list by nav ids
|
||
func (r *NodeRepository) GetRecommendNodeListByNavIDs(ctx context.Context, kbID string, releaseID string, navIds []string) ([]*domain.RecommendNodeListResp, error) {
|
||
var nodes []*domain.RecommendNodeListResp
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
|
||
Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
|
||
Where("node_releases.kb_id = ?", kbID).
|
||
Where("kb_release_node_releases.release_id = ?", releaseID).
|
||
Where("nodes.nav_id IN ?", navIds).
|
||
Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions, nodes.nav_id").
|
||
Order("node_releases.position ASC").
|
||
Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodes, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetRecommendNodeListByParentIDs(ctx context.Context, kbID string, releaseID string, parentIDs []string) (map[string][]*domain.RecommendNodeListResp, error) {
|
||
var nodes []*domain.RecommendNodeListResp
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
|
||
Joins("LEFT JOIN nodes ON nodes.id = node_releases.node_id").
|
||
Where("node_releases.kb_id = ?", kbID).
|
||
Where("kb_release_node_releases.release_id = ?", releaseID).
|
||
Where("node_releases.parent_id IN ?", parentIDs).
|
||
Where("node_releases.type != ?", domain.NodeTypeFolder).
|
||
Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.meta->>'summary' as summary, node_releases.meta->>'emoji' as emoji, node_releases.parent_id, node_releases.position, nodes.permissions").
|
||
Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
nodesMap := make(map[string][]*domain.RecommendNodeListResp)
|
||
for _, node := range nodes {
|
||
if _, ok := nodesMap[node.ParentID]; !ok {
|
||
nodesMap[node.ParentID] = make([]*domain.RecommendNodeListResp, 0)
|
||
}
|
||
nodesMap[node.ParentID] = append(nodesMap[node.ParentID], node)
|
||
}
|
||
return nodesMap, nil
|
||
}
|
||
|
||
// GetNodeReleaseListByKBID get node list by kb id
|
||
func (r *NodeRepository) GetNodeReleaseListByKBID(ctx context.Context, kbID string) ([]*domain.ShareNodeListItemResp, error) {
|
||
// get kb release
|
||
var kbRelease *domain.KBRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBRelease{}).
|
||
Where("kb_id = ?", kbID).
|
||
Order("created_at DESC").
|
||
First(&kbRelease).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
var nodes []*domain.ShareNodeListItemResp
|
||
qs := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
|
||
Joins("LEFT JOIN nodes ON nodes.id = kb_release_node_releases.node_id").
|
||
Where("kb_release_node_releases.kb_id = ?", kbID).
|
||
Where("kb_release_node_releases.release_id = ?", kbRelease.ID).
|
||
Where("nodes.permissions->>'visible' != ?", consts.NodeAccessPermClosed).
|
||
Select("node_releases.node_id as id, node_releases.name, node_releases.type, node_releases.parent_id, nodes.position, node_releases.meta->>'emoji' as emoji, node_releases.updated_at, nodes.permissions, nodes.meta, kb_release_node_releases.nav_id")
|
||
|
||
if err := qs.Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodes, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeReleaseDetailByKBIDAndID(ctx context.Context, kbID, id string) (*shareV1.ShareNodeDetailResp, error) {
|
||
// get kb release
|
||
var kbRelease *domain.KBRelease
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBRelease{}).
|
||
Where("kb_id = ?", kbID).
|
||
Order("created_at DESC").
|
||
First(&kbRelease).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var node *shareV1.ShareNodeDetailResp
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Select("node_releases.*, nodes.permissions, nodes.creator_id").
|
||
Joins("LEFT JOIN node_releases ON node_releases.id = kb_release_node_releases.node_release_id").
|
||
Joins("LEFT JOIN nodes ON nodes.id = kb_release_node_releases.node_id").
|
||
Where("kb_release_node_releases.release_id = ?", kbRelease.ID).
|
||
Where("node_releases.node_id = ?", id).
|
||
Where("node_releases.kb_id = ?", kbID).
|
||
First(&node).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return node, nil
|
||
}
|
||
|
||
func (r *NodeRepository) MoveNodeBetween(ctx context.Context, id, parentID, prevID, nextID, kbId string) error {
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
var prevPos, maxPos float64 = 0, domain.MaxPosition
|
||
if prevID != "" {
|
||
var prevNode *domain.Node
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("id = ?", prevID).
|
||
Where("kb_id = ?", kbId).
|
||
Where("parent_id = ?", parentID).
|
||
Select("position, parent_id").
|
||
First(&prevNode).Error; err != nil {
|
||
return err
|
||
}
|
||
prevPos = prevNode.Position
|
||
}
|
||
if nextID != "" {
|
||
var nextNode *domain.Node
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("id = ?", nextID).
|
||
Where("parent_id = ?", parentID).
|
||
Where("kb_id = ?", kbId).
|
||
Select("position, parent_id").
|
||
First(&nextNode).Error; err != nil {
|
||
return err
|
||
}
|
||
maxPos = nextNode.Position
|
||
}
|
||
|
||
node, err := r.GetNodeByID(ctx, id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
newPos := prevPos + (maxPos-prevPos)/2.0
|
||
if newPos-prevPos < domain.MinPositionGap {
|
||
if err := r.reorderPositionsByParentID(tx, node.KBID, parentID); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
querySet := tx.Model(&domain.Node{}).Where("id = ?", id).Update("position", newPos).Update("parent_id", parentID)
|
||
|
||
if node.Status == domain.NodeStatusPublished {
|
||
querySet = querySet.Update("status", domain.NodeStatusDraft)
|
||
}
|
||
|
||
return querySet.Error
|
||
})
|
||
}
|
||
|
||
// UpdateNodeDocID update node doc id
|
||
func (r *NodeRepository) UpdateNodeDocID(ctx context.Context, id, docID string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Omit("updated_at").
|
||
Where("id = ?", id).
|
||
Updates(map[string]any{
|
||
"doc_id": docID,
|
||
}).Error
|
||
}
|
||
|
||
// UpdateNodeReleaseDocID update node release doc id
|
||
func (r *NodeRepository) UpdateNodeReleaseDocID(ctx context.Context, id, docID string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Omit("updated_at").
|
||
Where("id = ?", id).
|
||
Updates(map[string]any{
|
||
"doc_id": docID,
|
||
}).Error
|
||
}
|
||
|
||
func (r *NodeRepository) UpdateNodeSummary(ctx context.Context, kbID, nodeID, summary string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ? AND id = ?", kbID, nodeID).
|
||
Updates(map[string]any{
|
||
"meta": gorm.Expr("jsonb_set(meta, '{summary}', to_jsonb(?::text))", summary),
|
||
}).Error
|
||
}
|
||
|
||
func (r *NodeRepository) UpdateNodeStatus(ctx context.Context, kbID, nodeID string, nodeStatus domain.NodeStatus) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ? AND id = ?", kbID, nodeID).
|
||
Updates(map[string]any{
|
||
"status": nodeStatus,
|
||
}).Error
|
||
}
|
||
|
||
// traverse all nodes by pg cursor
|
||
func (r *NodeRepository) TraverseNodesByCursor(ctx context.Context, callback func(*domain.NodeRelease) error) error {
|
||
rows, err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Select("DISTINCT ON (node_id) id, node_id, kb_id").
|
||
Order("node_id, updated_at DESC").
|
||
Rows()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer rows.Close()
|
||
|
||
for rows.Next() {
|
||
var nodeRelease domain.NodeRelease
|
||
if err := r.db.ScanRows(rows, &nodeRelease); err != nil {
|
||
return err
|
||
}
|
||
if err := callback(&nodeRelease); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// CreateNodeReleases create node releases
|
||
func (r *NodeRepository) CreateNodeReleases(ctx context.Context, kbID, userId string, nodeIDs []string) ([]string, error) {
|
||
releaseIDs := make([]string, 0)
|
||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// update node status to published and return node ids
|
||
var updatedNodes []*domain.Node
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ?", kbID).
|
||
Where("id IN ?", nodeIDs).
|
||
Update("status", domain.NodeStatusPublished).
|
||
Find(&updatedNodes).Error; err != nil {
|
||
return err
|
||
}
|
||
if len(updatedNodes) == 0 {
|
||
return nil
|
||
}
|
||
nodeReleases := make([]*domain.NodeRelease, len(updatedNodes))
|
||
for i, updatedNode := range updatedNodes {
|
||
// create node release
|
||
nodeRelease := &domain.NodeRelease{
|
||
ID: uuid.New().String(),
|
||
KBID: kbID,
|
||
PublisherId: userId,
|
||
EditorId: updatedNode.EditorId,
|
||
NodeID: updatedNode.ID,
|
||
Type: updatedNode.Type,
|
||
Name: updatedNode.Name,
|
||
Meta: updatedNode.Meta,
|
||
Content: updatedNode.Content,
|
||
ParentID: updatedNode.ParentID,
|
||
Position: updatedNode.Position,
|
||
CreatedAt: updatedNode.CreatedAt,
|
||
UpdatedAt: time.Now(),
|
||
}
|
||
nodeReleases[i] = nodeRelease
|
||
releaseIDs = append(releaseIDs, nodeRelease.ID)
|
||
}
|
||
|
||
if err := tx.CreateInBatches(&nodeReleases, 100).Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}); err != nil {
|
||
return nil, err
|
||
}
|
||
return releaseIDs, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetOldNodeDocIDsByNodeID(ctx context.Context, nodeReleaseID, nodeID string) ([]string, error) {
|
||
var docIDs []string
|
||
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// get old doc_ids by node_id
|
||
if err := tx.Model(&domain.NodeRelease{}).
|
||
Where("node_id = ?", nodeID).
|
||
Where("id != ?", nodeReleaseID).
|
||
Where("doc_id != ''").
|
||
Select("doc_id").
|
||
Find(&docIDs).Error; err != nil {
|
||
return err
|
||
}
|
||
// update node_release.doc_id to ""
|
||
if err := tx.Model(&domain.NodeRelease{}).
|
||
Where("node_id = ?", nodeID).
|
||
Where("id != ?", nodeReleaseID).
|
||
Omit("updated_at").
|
||
Update("doc_id", "").Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}); err != nil {
|
||
return nil, err
|
||
}
|
||
return docIDs, nil
|
||
}
|
||
|
||
func (r *NodeRepository) MoveNodeNav(ctx context.Context, kbID, navID string, nodeIDs []string) error {
|
||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
allIDs := r.collectAllChildNodeIDs(tx, kbID, nodeIDs)
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ? AND id IN ?", kbID, allIDs).
|
||
Update("nav_id", navID).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ? AND id IN ?", kbID, allIDs).
|
||
Where("parent_id != ''").
|
||
Where("parent_id NOT IN ?", allIDs).
|
||
Update("parent_id", "").Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ? AND id IN ?", kbID, allIDs).
|
||
Where("status = ?", domain.NodeStatusPublished).
|
||
Update("status", domain.NodeStatusDraft).Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (r *NodeRepository) BatchMove(ctx context.Context, req *domain.BatchMoveReq) error {
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
// update node parent_id
|
||
if err := tx.WithContext(ctx).Model(&domain.Node{}).
|
||
Where("kb_id = ?", req.KBID).
|
||
Where("id IN ?", req.IDs).
|
||
Update("parent_id", req.ParentID).
|
||
Error; err != nil {
|
||
return err
|
||
}
|
||
if err := tx.WithContext(ctx).Model(&domain.Node{}).
|
||
Where("kb_id = ?", req.KBID).
|
||
Where("id IN ?", req.IDs).
|
||
Where("status = ?", domain.NodeStatusPublished).
|
||
Update("status", domain.NodeStatusDraft).
|
||
Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// reorderPositionsByParentID 重排所给父节点下的所有子节点
|
||
func (r *NodeRepository) reorderPositionsByParentID(tx *gorm.DB, kbID, parentID string) error {
|
||
var nodes []*domain.Node
|
||
if parentID == "" {
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ?", kbID).
|
||
Where("parent_id IS NULL OR parent_id = ''").
|
||
Order("position").
|
||
Find(&nodes).Error; err != nil {
|
||
return err
|
||
}
|
||
} else {
|
||
if err := tx.Model(&domain.Node{}).
|
||
Where("kb_id = ?", kbID).
|
||
Where("parent_id = ?", parentID).
|
||
Order("position").
|
||
Find(&nodes).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return r.reorderPositions(tx, nodes)
|
||
}
|
||
|
||
// reorderPositions 重排所给节点
|
||
func (r *NodeRepository) reorderPositions(tx *gorm.DB, nodes []*domain.Node) error {
|
||
if len(nodes) == 0 {
|
||
return nil
|
||
}
|
||
|
||
basePosition := int64(1000) // 起始位置
|
||
interval := int64(1000) // 间隔
|
||
|
||
updates := make([]map[string]interface{}, len(nodes))
|
||
for i, node := range nodes {
|
||
newPosition := float64(basePosition + int64(i)*interval)
|
||
updates[i] = map[string]interface{}{
|
||
"id": node.ID,
|
||
"position": newPosition,
|
||
}
|
||
}
|
||
|
||
batchSize := 300
|
||
for i := 0; i < len(updates); i += batchSize {
|
||
end := i + batchSize
|
||
if end > len(updates) {
|
||
end = len(updates)
|
||
}
|
||
batch := updates[i:end]
|
||
|
||
values := make([]string, 0, len(batch))
|
||
for _, update := range batch {
|
||
id := update["id"]
|
||
pos := update["position"]
|
||
values = append(values, fmt.Sprintf("('%v', %v)", id, pos))
|
||
}
|
||
|
||
sql := fmt.Sprintf("UPDATE nodes SET position = new_values.new_value FROM (VALUES %s) AS new_values(id, new_value) WHERE nodes.id = new_values.id", strings.Join(values, ", "))
|
||
|
||
if err := tx.Exec(sql).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetNodeIDsByReleaseID get node IDs by release ID
|
||
func (r *NodeRepository) GetNodeIDsByReleaseID(ctx context.Context, releaseID string) ([]string, error) {
|
||
var nodeIDs []string
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.KBReleaseNodeRelease{}).
|
||
Where("release_id = ?", releaseID).
|
||
Select("node_id").
|
||
Find(&nodeIDs).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeIDs, nil
|
||
}
|
||
func (r *NodeRepository) UpdateNodeByKbID(ctx context.Context, id, kbId string, updateMap map[string]interface{}) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("id = ?", id).
|
||
Where("kb_id = ?", kbId).
|
||
Updates(updateMap).Error
|
||
}
|
||
|
||
func (r *NodeRepository) UpdateNodesByKbID(ctx context.Context, ids []string, kbId string, updateMap map[string]interface{}) error {
|
||
const batchSize = 500 // 批处理大小,避免IN子句过长
|
||
|
||
// 如果没有ID需要更新,直接返回
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 分批处理
|
||
for i := 0; i < len(ids); i += batchSize {
|
||
end := i + batchSize
|
||
if end > len(ids) {
|
||
end = len(ids)
|
||
}
|
||
|
||
batch := ids[i:end]
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("id in (?)", batch).
|
||
Where("kb_id = ?", kbId).
|
||
Updates(updateMap).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *NodeRepository) UpdateNodeGroupByKbIDAndNodeIds(ctx context.Context, nodeIds []string, groupIds []int, perm consts.NodePermName) error {
|
||
const batchSize = 1000 // 批处理大小,避免IN子句过长
|
||
|
||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
// 分批删除现有的权限记录,防止nodeIds过长
|
||
for i := 0; i < len(nodeIds); i += batchSize {
|
||
end := i + batchSize
|
||
if end > len(nodeIds) {
|
||
end = len(nodeIds)
|
||
}
|
||
|
||
batch := nodeIds[i:end]
|
||
if err := tx.Model(&domain.NodeAuthGroup{}).
|
||
Where("node_id in (?) AND perm = ?", batch, perm).
|
||
Delete(&domain.NodeAuthGroup{}).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 如果 groupIds 为空,则只执行删除操作
|
||
if len(groupIds) == 0 {
|
||
return nil
|
||
}
|
||
|
||
nodeGroups := make([]domain.NodeAuthGroup, 0)
|
||
for i := range nodeIds {
|
||
// 批量插入新的数据
|
||
for index := range groupIds {
|
||
if groupIds[index] == 0 {
|
||
continue
|
||
}
|
||
nodeGroups = append(nodeGroups, domain.NodeAuthGroup{
|
||
NodeID: nodeIds[i],
|
||
AuthGroupID: groupIds[index],
|
||
Perm: perm,
|
||
})
|
||
}
|
||
}
|
||
|
||
if len(nodeGroups) != 0 {
|
||
if err := tx.Model(&domain.NodeAuthGroup{}).CreateInBatches(&nodeGroups, 100).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeGroupByNodeId(ctx context.Context, nodeId string) ([]domain.NodeGroupDetail, error) {
|
||
nodeGroup := make([]domain.NodeGroupDetail, 0)
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeAuthGroup{}).
|
||
Select("node_auth_groups.node_id, node_auth_groups.auth_group_id, node_auth_groups.perm, auth_groups.name, auth_groups.kb_id, auth_groups.auth_ids").
|
||
Joins("left join auth_groups on auth_groups.id = node_auth_groups.auth_group_id").
|
||
Where("node_auth_groups.node_id = ?", nodeId).
|
||
Scan(&nodeGroup).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodeGroup, nil
|
||
}
|
||
|
||
func (r *NodeRepository) Update(ctx context.Context, id string, m map[string]interface{}) error {
|
||
return r.db.WithContext(ctx).Model(domain.Node{}).Where("id = ?", id).Updates(m).Error
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeIdByDocId(ctx context.Context, docId string) (string, error) {
|
||
nodeIds := make([]string, 0)
|
||
if err := r.db.WithContext(ctx).Model(domain.NodeRelease{}).
|
||
Where("doc_id = ?", docId).
|
||
Pluck("node_id", &nodeIds).Error; err != nil {
|
||
return "", err
|
||
}
|
||
if len(nodeIds) < 1 {
|
||
return "", fmt.Errorf("node not found for doc_id: %s", docId)
|
||
}
|
||
return nodeIds[0], nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeIdsWithoutStatusByKbId(ctx context.Context, kbId string) ([]string, error) {
|
||
docIds := make([]string, 0)
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Joins("left join node_releases on node_releases.node_id = nodes.id").
|
||
Where("(nodes.rag_info ->> 'status' IS NULL OR nodes.rag_info ->> 'status' = '')").
|
||
Where("nodes.kb_id = ? ", kbId).
|
||
Where("nodes.type = ? ", domain.NodeTypeDocument).
|
||
Where("node_releases.doc_id != '' ").
|
||
Pluck("node_releases.doc_id", &docIds).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return docIds, nil
|
||
}
|
||
|
||
// GetNodeIdsByDocIds 批量获取 doc_id 到 node_id 的映射
|
||
func (r *NodeRepository) GetNodeIdsByDocIds(ctx context.Context, docIds []string) (map[string]string, error) {
|
||
if len(docIds) == 0 {
|
||
return make(map[string]string), nil
|
||
}
|
||
|
||
type Result struct {
|
||
DocID string `gorm:"column:doc_id"`
|
||
NodeID string `gorm:"column:node_id"`
|
||
}
|
||
|
||
results := make([]Result, 0)
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.NodeRelease{}).
|
||
Select("doc_id, node_id").
|
||
Where("doc_id IN (?)", docIds).
|
||
Find(&results).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 构建 doc_id -> node_id 的映射
|
||
docToNodeMap := make(map[string]string, len(results))
|
||
for _, result := range results {
|
||
docToNodeMap[result.DocID] = result.NodeID
|
||
}
|
||
|
||
return docToNodeMap, nil
|
||
}
|
||
|
||
func (r *NodeRepository) DeleteOldNodeReleaseBackups(ctx context.Context, before time.Time) error {
|
||
return r.db.WithContext(ctx).
|
||
Where("deleted_at < ?", before).
|
||
Delete(&domain.NodeReleaseBackup{}).Error
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeCount(ctx context.Context) (int, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(count), nil
|
||
}
|
||
|
||
func (r *NodeRepository) CountNodeByNavId(ctx context.Context, kbId, navId string) (int64, error) {
|
||
var count int64
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ?", kbId).
|
||
Where("nav_id = ?", navId).
|
||
Count(&count).Error; err != nil {
|
||
return 0, err
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeIDsByNavId(ctx context.Context, kbId, navId string) ([]string, error) {
|
||
var ids []string
|
||
if err := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ? AND nav_id = ?", kbId, navId).
|
||
Pluck("id", &ids).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return ids, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeListByStatus(ctx context.Context, kbId, status, search string) ([]*domain.NodeListItemResp, error) {
|
||
var nodes []*domain.NodeListItemResp
|
||
query := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Joins("LEFT JOIN users cu ON nodes.creator_id = cu.id").
|
||
Joins("LEFT JOIN users eu ON nodes.editor_id = eu.id").
|
||
Where("nodes.kb_id = ?", kbId).
|
||
Select("cu.account AS creator, eu.account AS editor, nodes.editor_id, nodes.nav_id, nodes.rag_info, nodes.creator_id, nodes.id, nodes.permissions, nodes.type, nodes.status, nodes.name, nodes.parent_id, nodes.position, nodes.created_at, nodes.edit_time as updated_at, nodes.meta->>'summary' as summary, nodes.meta->>'emoji' as emoji, nodes.meta->>'content_type' as content_type")
|
||
|
||
if search != "" {
|
||
searchPattern := "%" + search + "%"
|
||
query = query.Where("name LIKE ? OR content LIKE ?", searchPattern, searchPattern)
|
||
}
|
||
|
||
switch status {
|
||
// 发布后允许可配置的
|
||
case "released":
|
||
query = query.Where("nodes.status IN ?", []domain.NodeStatus{domain.NodeStatusDraft, domain.NodeStatusPublished})
|
||
case "unpublished":
|
||
query = query.Where("nodes.status IN ?", []domain.NodeStatus{domain.NodeStatusUnreleased, domain.NodeStatusDraft})
|
||
case "unstudied":
|
||
query = query.Where("nodes.type = ?", domain.NodeTypeDocument).
|
||
Where("nodes.rag_info->>'status' NOT IN ? OR nodes.rag_info->>'status' IS NULL",
|
||
[]string{string(consts.NodeRagStatusSucceeded), string(consts.NodeRagStatusRunning), string(consts.NodeRagStatusReindexing)})
|
||
}
|
||
|
||
if err := query.Find(&nodes).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return nodes, nil
|
||
}
|
||
|
||
func (r *NodeRepository) GetNodeStats(ctx context.Context, kbId string) (*v1.NodeStatsResp, error) {
|
||
var stats v1.NodeStatsResp
|
||
|
||
// Count unpublished documents (status = 0 or 1)
|
||
unpublishedQuery := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ? AND status IN ?", kbId, []domain.NodeStatus{domain.NodeStatusUnreleased, domain.NodeStatusDraft})
|
||
|
||
if err := unpublishedQuery.Count(&stats.UnpublishedCount).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
studiedStatuses := []consts.NodeRagInfoStatus{
|
||
consts.NodeRagStatusSucceeded,
|
||
consts.NodeRagStatusRunning,
|
||
consts.NodeRagStatusReindexing,
|
||
}
|
||
|
||
unstudiedQuery := r.db.WithContext(ctx).
|
||
Model(&domain.Node{}).
|
||
Where("kb_id = ?", kbId).
|
||
Where("nodes.type = ?", domain.NodeTypeDocument).
|
||
Where("rag_info->>'status' NOT IN ? OR rag_info->>'status' IS NULL", studiedStatuses)
|
||
|
||
if err := unstudiedQuery.Count(&stats.UnstudiedCount).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &stats, nil
|
||
}
|