849 lines
23 KiB
Go
849 lines
23 KiB
Go
package pg
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/uuid"
|
|
"github.com/samber/lo"
|
|
"gorm.io/gorm"
|
|
|
|
v1 "github.com/chaitin/panda-wiki/api/kb/v1"
|
|
"github.com/chaitin/panda-wiki/config"
|
|
"github.com/chaitin/panda-wiki/consts"
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/store/pg"
|
|
"github.com/chaitin/panda-wiki/store/rag"
|
|
)
|
|
|
|
type KnowledgeBaseRepository struct {
|
|
db *pg.DB
|
|
config *config.Config
|
|
logger *log.Logger
|
|
rag rag.RAGService
|
|
}
|
|
|
|
func NewKnowledgeBaseRepository(db *pg.DB, config *config.Config, logger *log.Logger, rag rag.RAGService) *KnowledgeBaseRepository {
|
|
r := &KnowledgeBaseRepository{
|
|
db: db,
|
|
config: config,
|
|
logger: logger.WithModule("repo.pg.knowledge_base"),
|
|
rag: rag,
|
|
}
|
|
ctx := context.Background()
|
|
kbList, err := r.GetKnowledgeBaseList(ctx)
|
|
if err != nil {
|
|
r.logger.Error("failed to get knowledge base list", "error", err)
|
|
return r
|
|
}
|
|
if len(kbList) > 0 {
|
|
if err := r.SyncKBAccessSettingsToCaddy(ctx, kbList); err != nil {
|
|
r.logger.Error("failed to sync kb access settings to caddy", "error", err)
|
|
}
|
|
}
|
|
return r
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) SyncKBAccessSettingsToCaddy(ctx context.Context, kbList []*domain.KnowledgeBaseListItem) error {
|
|
if len(kbList) == 0 {
|
|
return nil
|
|
}
|
|
firstKB := kbList[0]
|
|
firstHost := ""
|
|
if len(firstKB.AccessSettings.Hosts) > 0 {
|
|
firstHost = firstKB.AccessSettings.Hosts[0]
|
|
}
|
|
certs := make([]map[string]any, 0)
|
|
portHostKBMap := make(map[string]map[string]*domain.KnowledgeBaseListItem)
|
|
httpPorts := make(map[string]struct{})
|
|
for _, kb := range kbList {
|
|
for _, port := range kb.AccessSettings.Ports {
|
|
httpPorts[fmt.Sprintf(":%d", port)] = struct{}{}
|
|
if _, ok := portHostKBMap[fmt.Sprintf(":%d", port)]; !ok {
|
|
portHostKBMap[fmt.Sprintf(":%d", port)] = make(map[string]*domain.KnowledgeBaseListItem)
|
|
}
|
|
for _, host := range kb.AccessSettings.Hosts {
|
|
portHostKBMap[fmt.Sprintf(":%d", port)][host] = kb
|
|
}
|
|
}
|
|
for _, sslPort := range kb.AccessSettings.SSLPorts {
|
|
if _, ok := portHostKBMap[fmt.Sprintf(":%d", sslPort)]; !ok {
|
|
portHostKBMap[fmt.Sprintf(":%d", sslPort)] = make(map[string]*domain.KnowledgeBaseListItem)
|
|
}
|
|
for _, host := range kb.AccessSettings.Hosts {
|
|
portHostKBMap[fmt.Sprintf(":%d", sslPort)][host] = kb
|
|
}
|
|
}
|
|
if len(kb.AccessSettings.PublicKey) > 0 && len(kb.AccessSettings.PrivateKey) > 0 {
|
|
certs = append(certs, map[string]any{
|
|
"certificate": kb.AccessSettings.PublicKey,
|
|
"key": kb.AccessSettings.PrivateKey,
|
|
"tags": []string{kb.ID},
|
|
})
|
|
}
|
|
}
|
|
socketPath := r.config.CaddyAPI
|
|
// sync kb to caddy
|
|
// create server for each port
|
|
subnetPrefix := r.config.SubnetPrefix
|
|
if subnetPrefix == "" {
|
|
subnetPrefix = "169.254.15"
|
|
}
|
|
api := fmt.Sprintf("%s.2:8000", subnetPrefix)
|
|
app := fmt.Sprintf("%s.112:3010", subnetPrefix)
|
|
staticFile := fmt.Sprintf("%s.12:9000", subnetPrefix) // minio
|
|
servers := make(map[string]any, 0)
|
|
for port, hostKBMap := range portHostKBMap {
|
|
trustProxies := make([]string, 0)
|
|
for _, kb := range hostKBMap {
|
|
trustProxies = append(trustProxies, kb.AccessSettings.TrustedProxies...)
|
|
}
|
|
server := map[string]any{
|
|
"listen": []string{port},
|
|
"routes": []map[string]any{},
|
|
}
|
|
if len(trustProxies) != 0 {
|
|
trustProxies = lo.Uniq(trustProxies)
|
|
server["trusted_proxies"] = map[string]any{
|
|
"source": "static",
|
|
"ranges": trustProxies,
|
|
}
|
|
}
|
|
if _, ok := httpPorts[port]; ok {
|
|
server["automatic_https"] = map[string]any{
|
|
"disable": true,
|
|
}
|
|
} else {
|
|
server["automatic_https"] = map[string]any{
|
|
"disable_certificates": true,
|
|
"disable_redirects": true,
|
|
}
|
|
// SSL port: collect certificate tags for tls_connection_policies
|
|
certTags := make([]string, 0)
|
|
for _, kb := range hostKBMap {
|
|
if len(kb.AccessSettings.PublicKey) > 0 && len(kb.AccessSettings.PrivateKey) > 0 {
|
|
certTags = append(certTags, kb.ID)
|
|
}
|
|
}
|
|
if len(certTags) > 0 {
|
|
server["tls_connection_policies"] = []map[string]any{
|
|
{
|
|
"certificate_selection": map[string]any{
|
|
"any_tag": certTags,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
}
|
|
routes := make([]map[string]any, 0)
|
|
var defaultRoute map[string]any
|
|
for host, kb := range hostKBMap {
|
|
route := map[string]any{
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "subroute",
|
|
"routes": []map[string]any{
|
|
{
|
|
"match": []map[string]any{
|
|
{
|
|
"path": []string{"/share/v1/chat/message"},
|
|
},
|
|
},
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "headers",
|
|
"request": map[string]any{
|
|
"set": map[string][]any{
|
|
"X-KB-ID": {kb.ID},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"handler": "reverse_proxy",
|
|
"upstreams": []map[string]any{
|
|
{"dial": api},
|
|
},
|
|
"flush_interval": -1,
|
|
"transport": map[string]any{
|
|
"protocol": "http",
|
|
"read_timeout": "10m",
|
|
"write_timeout": "10m",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"match": []map[string]any{
|
|
{
|
|
"path": []string{"/share/v1/chat/completions", "/share/v1/app/wechat/app", "/share/v1/app/wechat/service", "/sitemap.xml", "/share/v1/app/wechat/official_account", "/share/v1/app/wechat/service/answer", "/mcp"},
|
|
},
|
|
},
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "headers",
|
|
"request": map[string]any{
|
|
"set": map[string][]any{
|
|
"X-KB-ID": {kb.ID},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"handler": "reverse_proxy",
|
|
"upstreams": []map[string]any{
|
|
{"dial": api},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"match": []map[string]any{
|
|
{
|
|
"path": []string{"/static-file/*"},
|
|
},
|
|
},
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "subroute",
|
|
"routes": []map[string]any{
|
|
{
|
|
"match": []map[string]any{
|
|
{
|
|
"not": []map[string]any{
|
|
{"path_regexp": map[string]string{"pattern": `(?i)\.pdf($|\?)`}},
|
|
},
|
|
},
|
|
},
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "headers",
|
|
"response": map[string]any{
|
|
"set": map[string][]string{
|
|
"Content-Disposition": {"attachment"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "reverse_proxy",
|
|
"upstreams": []map[string]any{
|
|
{"dial": staticFile},
|
|
},
|
|
"flush_interval": -1,
|
|
"transport": map[string]any{
|
|
"protocol": "http",
|
|
"read_timeout": "10m",
|
|
"write_timeout": "10m",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"handle": []map[string]any{
|
|
{
|
|
"handler": "headers",
|
|
"request": map[string]any{
|
|
"set": map[string][]any{
|
|
"X-KB-ID": {kb.ID},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
"handler": "reverse_proxy",
|
|
"upstreams": []map[string]any{
|
|
{"dial": app},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
if host == firstHost {
|
|
// first host as default host
|
|
// copy route without the host match
|
|
defaultRoute = maps.Clone(route)
|
|
}
|
|
if host != "*" {
|
|
route["match"] = []map[string]any{
|
|
{
|
|
"host": []string{host},
|
|
},
|
|
}
|
|
}
|
|
routes = append(routes, route)
|
|
}
|
|
// add default route if exists
|
|
if defaultRoute != nil {
|
|
routes = append(routes, defaultRoute)
|
|
}
|
|
server["routes"] = routes
|
|
servers[port] = server
|
|
}
|
|
apps := map[string]any{
|
|
"http": map[string]any{
|
|
"servers": servers,
|
|
},
|
|
}
|
|
if len(certs) > 0 {
|
|
apps["tls"] = map[string]any{
|
|
"certificates": map[string]any{
|
|
"load_pem": certs,
|
|
},
|
|
}
|
|
}
|
|
config := map[string]any{
|
|
"apps": apps,
|
|
}
|
|
newBody, _ := json.Marshal(config)
|
|
tr := &http.Transport{
|
|
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
|
return net.Dial("unix", socketPath)
|
|
},
|
|
}
|
|
client := &http.Client{
|
|
Transport: tr,
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
req, err := http.NewRequest("POST", "http://unix/load", bytes.NewBuffer(newBody))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
r.logger.Error("failed to update caddy config", "error", string(body))
|
|
return domain.ErrSyncCaddyConfigFailed
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) CreateKnowledgeBase(ctx context.Context, maxKB int, kb *domain.KnowledgeBase) error {
|
|
authInfo := domain.GetAuthInfoFromCtx(ctx)
|
|
if authInfo == nil {
|
|
return fmt.Errorf("authInfo not found in context")
|
|
}
|
|
if authInfo.IsToken {
|
|
return fmt.Errorf("this api not support token call")
|
|
}
|
|
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Create(kb).Error; err != nil {
|
|
return err
|
|
}
|
|
// get all kb list
|
|
var kbs []*domain.KnowledgeBaseListItem
|
|
if err := tx.Model(&domain.KnowledgeBase{}).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return err
|
|
}
|
|
if len(kbs) > maxKB {
|
|
return errors.New("kb is too many")
|
|
}
|
|
|
|
if err := r.checkUniquePortHost(kbs); err != nil {
|
|
return err
|
|
}
|
|
if err := r.SyncKBAccessSettingsToCaddy(ctx, kbs); err != nil {
|
|
r.logger.Error("failed to sync kb access settings to caddy", "error", err)
|
|
return err
|
|
}
|
|
type AppBtn struct {
|
|
ID string `json:"id"`
|
|
Icon string `json:"icon"`
|
|
ShowIcon bool `json:"showIcon"`
|
|
Target string `json:"target"`
|
|
Text string `json:"text"`
|
|
URL string `json:"url"`
|
|
Variant string `json:"variant"`
|
|
}
|
|
if err := tx.Create(&domain.App{
|
|
ID: uuid.New().String(),
|
|
KBID: kb.ID,
|
|
Name: kb.Name,
|
|
Type: domain.AppTypeWeb,
|
|
Settings: domain.AppSettings{
|
|
Title: kb.Name,
|
|
Desc: kb.Name,
|
|
Keyword: kb.Name,
|
|
Icon: domain.DefaultPandaWikiIconB64,
|
|
WelcomeStr: fmt.Sprintf("欢迎使用%s", kb.Name),
|
|
Btns: []any{
|
|
AppBtn{
|
|
ID: uuid.New().String(),
|
|
Icon: domain.DefaultGitHubIconB64,
|
|
ShowIcon: true,
|
|
Target: "_blank",
|
|
Text: "GitHub",
|
|
URL: "https://ly.safepoint.cloud/XEyeWqL",
|
|
Variant: "contained",
|
|
},
|
|
AppBtn{
|
|
ID: uuid.New().String(),
|
|
Icon: "",
|
|
ShowIcon: false,
|
|
Target: "_blank",
|
|
Text: "PandaWiki",
|
|
URL: "https://pandawiki.docs.baizhi.cloud",
|
|
Variant: "outlined",
|
|
},
|
|
},
|
|
},
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
var user domain.User
|
|
err := r.db.WithContext(ctx).
|
|
Where("id = ?", authInfo.UserId).
|
|
First(&user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 非管理员用户需要user到kb创建映射关系
|
|
if user.Role != consts.UserRoleAdmin {
|
|
if err := r.CreateKBUser(ctx, &domain.KBUsers{
|
|
KBId: kb.ID,
|
|
UserId: authInfo.UserId,
|
|
Perm: consts.UserKBPermissionFullControl,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) checkUniquePortHost(kbList []*domain.KnowledgeBaseListItem) error {
|
|
uniqPortHost := make(map[string]bool)
|
|
for _, kb := range kbList {
|
|
for _, port := range kb.AccessSettings.Ports {
|
|
for _, host := range kb.AccessSettings.Hosts {
|
|
portHostStr := fmt.Sprintf("%d%s", port, host)
|
|
if _, ok := uniqPortHost[portHostStr]; !ok {
|
|
uniqPortHost[portHostStr] = true
|
|
} else {
|
|
r.logger.Error("port and host already exists", "port", port, "host", host)
|
|
return domain.ErrPortHostAlreadyExists
|
|
}
|
|
}
|
|
}
|
|
for _, sslPort := range kb.AccessSettings.SSLPorts {
|
|
for _, host := range kb.AccessSettings.Hosts {
|
|
portHostStr := fmt.Sprintf("%d%s", sslPort, host)
|
|
if _, ok := uniqPortHost[portHostStr]; !ok {
|
|
uniqPortHost[portHostStr] = true
|
|
} else {
|
|
r.logger.Error("port and host already exists", "port", sslPort, "host", host)
|
|
return domain.ErrPortHostAlreadyExists
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKnowledgeBaseList(ctx context.Context) ([]*domain.KnowledgeBaseListItem, error) {
|
|
var kbs []*domain.KnowledgeBaseListItem
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return kbs, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKnowledgeBaseIds(ctx context.Context) ([]string, error) {
|
|
var ids []string
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Pluck("id", &ids).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKnowledgeBaseListByUserId(ctx context.Context) ([]*domain.KnowledgeBaseListItem, error) {
|
|
kbs := make([]*domain.KnowledgeBaseListItem, 0)
|
|
authInfo := domain.GetAuthInfoFromCtx(ctx)
|
|
if authInfo == nil {
|
|
return nil, fmt.Errorf("authInfo not found in context")
|
|
}
|
|
|
|
if authInfo.IsToken {
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Where("id = ?", authInfo.KBId).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
var user domain.User
|
|
err := r.db.WithContext(ctx).
|
|
Where("id = ?", authInfo.UserId).
|
|
First(&user).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.Role == consts.UserRoleAdmin {
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
var kbIDs []string
|
|
if err := r.db.WithContext(ctx).
|
|
Table("kb_users").
|
|
Where("user_id = ?", authInfo.UserId).
|
|
Pluck("kb_id", &kbIDs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
if len(kbIDs) > 0 {
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Where("id IN ?", kbIDs).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return kbs, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) UpdateDatasetID(ctx context.Context, kbID, datasetID string) error {
|
|
return r.db.WithContext(ctx).
|
|
Model(&domain.KnowledgeBase{}).
|
|
Where("id = ?", kbID).
|
|
Update("dataset_id", datasetID).Error
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) UpdateKnowledgeBase(ctx context.Context, req *domain.UpdateKnowledgeBaseReq) (bool, error) {
|
|
var isChanged bool
|
|
kb, err := r.GetKnowledgeBaseByID(ctx, req.ID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
updateMap := map[string]any{}
|
|
if req.Name != nil {
|
|
updateMap["name"] = req.Name
|
|
}
|
|
if req.AccessSettings != nil {
|
|
updateMap["access_settings"] = req.AccessSettings
|
|
}
|
|
|
|
if err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Model(&domain.KnowledgeBase{}).Where("id = ?", req.ID).Updates(updateMap).Error; err != nil {
|
|
return err
|
|
}
|
|
// get all kb list
|
|
var kbs []*domain.KnowledgeBaseListItem
|
|
if err := tx.Model(&domain.KnowledgeBase{}).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := r.checkUniquePortHost(kbs); err != nil {
|
|
return err
|
|
}
|
|
if err := r.SyncKBAccessSettingsToCaddy(ctx, kbs); err != nil {
|
|
return fmt.Errorf("failed to sync kb access settings to caddy: %w", err)
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
kbNew, err := r.GetKnowledgeBaseByID(ctx, req.ID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if !cmp.Equal(kbNew.AccessSettings, kb.AccessSettings) {
|
|
isChanged = true
|
|
}
|
|
|
|
return isChanged, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKnowledgeBaseByID(ctx context.Context, kbID string) (*domain.KnowledgeBase, error) {
|
|
var kb domain.KnowledgeBase
|
|
if err := r.db.WithContext(ctx).Where("id = ?", kbID).First(&kb).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &kb, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) DeleteKnowledgeBase(ctx context.Context, kbID string) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Where("kb_id = ?", kbID).Delete(&domain.Node{}).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Where("kb_id = ?", kbID).Delete(&domain.App{}).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := tx.Where("id = ?", kbID).Delete(&domain.KnowledgeBase{}).Error; err != nil {
|
|
return err
|
|
}
|
|
// get all kb list
|
|
var kbs []*domain.KnowledgeBaseListItem
|
|
if err := tx.Model(&domain.KnowledgeBase{}).
|
|
Order("created_at ASC").
|
|
Find(&kbs).Error; err != nil {
|
|
return err
|
|
}
|
|
if err := r.SyncKBAccessSettingsToCaddy(ctx, kbs); err != nil {
|
|
return fmt.Errorf("failed to sync kb access settings to caddy: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) CreateKBRelease(ctx context.Context, release *domain.KBRelease) error {
|
|
if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// create new release
|
|
if err := tx.Create(release).Error; err != nil {
|
|
return err
|
|
}
|
|
// create release node for all released nodes
|
|
var nodeReleases []*domain.NodeRelease
|
|
if err := tx.Where("kb_id = ?", release.KBID).
|
|
Select("DISTINCT ON (node_id) id, node_id").
|
|
Order("node_id, updated_at DESC").
|
|
Find(&nodeReleases).Error; err != nil {
|
|
return err
|
|
}
|
|
if len(nodeReleases) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// build node_id -> nav_id map from current nodes
|
|
type nodeNavID struct {
|
|
ID string `gorm:"column:id"`
|
|
NavID string `gorm:"column:nav_id"`
|
|
}
|
|
var nodeNavIDs []nodeNavID
|
|
nodeIDs := make([]string, len(nodeReleases))
|
|
for i, nr := range nodeReleases {
|
|
nodeIDs[i] = nr.NodeID
|
|
}
|
|
if err := tx.Model(&domain.Node{}).
|
|
Where("id IN ?", nodeIDs).
|
|
Select("id, nav_id").
|
|
Find(&nodeNavIDs).Error; err != nil {
|
|
return err
|
|
}
|
|
navIDMap := make(map[string]string, len(nodeNavIDs))
|
|
for _, n := range nodeNavIDs {
|
|
navIDMap[n.ID] = n.NavID
|
|
}
|
|
|
|
kbReleaseNodeReleases := make([]*domain.KBReleaseNodeRelease, len(nodeReleases))
|
|
for i, nodeRelease := range nodeReleases {
|
|
kbReleaseNodeReleases[i] = &domain.KBReleaseNodeRelease{
|
|
ID: uuid.New().String(),
|
|
KBID: release.KBID,
|
|
ReleaseID: release.ID,
|
|
NodeID: nodeRelease.NodeID,
|
|
NodeReleaseID: nodeRelease.ID,
|
|
NavID: navIDMap[nodeRelease.NodeID],
|
|
CreatedAt: time.Now(),
|
|
}
|
|
}
|
|
if err := tx.CreateInBatches(&kbReleaseNodeReleases, 2000).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// snapshot current navs into nav_releases
|
|
var navs []*domain.Nav
|
|
if err := tx.Where("kb_id = ?", release.KBID).
|
|
Order("position ASC").
|
|
Find(&navs).Error; err != nil {
|
|
return err
|
|
}
|
|
if len(navs) > 0 {
|
|
navReleases := make([]*domain.NavRelease, len(navs))
|
|
now := time.Now()
|
|
for i, nav := range navs {
|
|
navReleases[i] = &domain.NavRelease{
|
|
ID: uuid.New().String(),
|
|
NavID: nav.ID,
|
|
ReleaseID: release.ID,
|
|
KbID: release.KBID,
|
|
Name: nav.Name,
|
|
Position: nav.Position,
|
|
CreatedAt: now,
|
|
}
|
|
}
|
|
if err := tx.CreateInBatches(&navReleases, 2000).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKBReleaseList(ctx context.Context, kbID string, offset, limit int) (int64, []domain.KBReleaseListItemResp, error) {
|
|
var total int64
|
|
if err := r.db.Model(&domain.KBRelease{}).Where("kb_id = ?", kbID).Count(&total).Error; err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
var releases []domain.KBReleaseListItemResp
|
|
if err := r.db.WithContext(ctx).Model(&domain.KBRelease{}).
|
|
Select("publish.account as publisher_account, kb_releases.*").
|
|
Joins("left join users publish on kb_releases.publisher_id = publish.id").
|
|
Where("kb_id = ?", kbID).
|
|
Order("created_at DESC").
|
|
Offset(offset).
|
|
Limit(limit).
|
|
Find(&releases).Error; err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
return total, releases, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetLatestRelease(ctx context.Context, kbID string) (*domain.KBRelease, error) {
|
|
var release domain.KBRelease
|
|
if err := r.db.WithContext(ctx).
|
|
Where("kb_id = ?", kbID).
|
|
Order("created_at DESC").
|
|
First(&release).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &release, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKBUserlist(ctx context.Context, kbID string) ([]v1.KBUserListItemResp, error) {
|
|
var users []v1.KBUserListItemResp
|
|
err := r.db.WithContext(ctx).
|
|
Model(&domain.User{}).
|
|
Select("users.id, users.account, users.role, kbu.perm, kbu.created_at").
|
|
Joins("INNER JOIN kb_users kbu ON users.id = kbu.user_id").
|
|
Where("kbu.kb_id = ?", kbID).
|
|
Where("users.role = ?", consts.UserRoleUser).
|
|
Order("kbu.created_at DESC").
|
|
Scan(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var adminUsers []v1.KBUserListItemResp
|
|
err = r.db.WithContext(ctx).
|
|
Model(&domain.User{}).
|
|
Select("users.id, users.account, users.role").
|
|
Where("users.role = ?", consts.UserRoleAdmin).
|
|
Order("Users.id DESC").
|
|
Scan(&adminUsers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for index := range adminUsers {
|
|
adminUsers[index].Perm = consts.UserKBPermissionFullControl
|
|
}
|
|
|
|
users = append(users, adminUsers...)
|
|
return users, nil
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) CreateKBUser(ctx context.Context, kbUser *domain.KBUsers) error {
|
|
|
|
return r.db.WithContext(ctx).Create(kbUser).Error
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) UpdateKBUserPerm(ctx context.Context, kbId, userId string, perm consts.UserKBPermission) error {
|
|
return r.db.WithContext(ctx).
|
|
Model(&domain.KBUsers{}).
|
|
Where("kb_id = ? AND user_id = ?", kbId, userId).
|
|
Update("perm", perm).Error
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) DeleteKBUser(ctx context.Context, kbId, userId string) error {
|
|
return r.db.WithContext(ctx).
|
|
Where("kb_id = ? AND user_id = ?", kbId, userId).
|
|
Delete(&domain.KBUsers{}).Error
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKBUser(ctx context.Context, kbId, userId string) (*domain.KBUsers, error) {
|
|
var users domain.KBUsers
|
|
err := r.db.WithContext(ctx).
|
|
Where("kb_id = ? AND user_id = ?", kbId, userId).
|
|
First(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &users, err
|
|
}
|
|
|
|
func (r *KnowledgeBaseRepository) GetKBPermByUserId(ctx context.Context, kbId string) (consts.UserKBPermission, error) {
|
|
authInfo := domain.GetAuthInfoFromCtx(ctx)
|
|
if authInfo == nil {
|
|
return "", fmt.Errorf("authInfo not found in context")
|
|
}
|
|
|
|
var (
|
|
user domain.User
|
|
perm consts.UserKBPermission
|
|
)
|
|
|
|
if authInfo.IsToken {
|
|
if authInfo.KBId != kbId {
|
|
return "", errors.New("token kb permission denied")
|
|
}
|
|
|
|
return authInfo.Permission, nil
|
|
} else {
|
|
if err := r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", authInfo.UserId).First(&user).Error; err != nil {
|
|
return perm, err
|
|
}
|
|
if user.Role == consts.UserRoleAdmin {
|
|
return consts.UserKBPermissionFullControl, nil
|
|
}
|
|
kbUser, err := r.GetKBUser(ctx, kbId, authInfo.UserId)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return consts.UserKBPermissionNull, nil
|
|
}
|
|
return perm, err
|
|
}
|
|
|
|
return kbUser.Perm, nil
|
|
}
|
|
}
|