init push

This commit is contained in:
2026-05-21 19:52:45 +08:00
commit e3f75311ab
1280 changed files with 179173 additions and 0 deletions

370
backend/usecase/model.go Normal file
View File

@@ -0,0 +1,370 @@
package usecase
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudwego/eino/schema"
modelkitDomain "github.com/chaitin/ModelKit/v2/domain"
modelkit "github.com/chaitin/ModelKit/v2/usecase"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/mq"
"github.com/chaitin/panda-wiki/repo/pg"
"github.com/chaitin/panda-wiki/store/rag"
)
type ModelUsecase struct {
modelRepo *pg.ModelRepository
logger *log.Logger
config *config.Config
nodeRepo *pg.NodeRepository
ragRepo *mq.RAGRepository
ragStore rag.RAGService
kbRepo *pg.KnowledgeBaseRepository
systemSettingRepo *pg.SystemSettingRepo
modelkit *modelkit.ModelKit
}
func NewModelUsecase(modelRepo *pg.ModelRepository, nodeRepo *pg.NodeRepository, ragRepo *mq.RAGRepository, ragStore rag.RAGService, logger *log.Logger, config *config.Config, kbRepo *pg.KnowledgeBaseRepository, settingRepo *pg.SystemSettingRepo) *ModelUsecase {
modelkit := modelkit.NewModelKit(logger.Logger)
u := &ModelUsecase{
modelRepo: modelRepo,
logger: logger.WithModule("usecase.model"),
config: config,
nodeRepo: nodeRepo,
ragRepo: ragRepo,
ragStore: ragStore,
kbRepo: kbRepo,
systemSettingRepo: settingRepo,
modelkit: modelkit,
}
return u
}
func (u *ModelUsecase) Create(ctx context.Context, model *domain.Model) error {
var updatedEmbeddingModel bool
if model.Type == domain.ModelTypeEmbedding {
updatedEmbeddingModel = true
}
if err := u.modelRepo.Create(ctx, model); err != nil {
return err
}
// 模型更新成功后,如果更新嵌入模型,则触发记录更新
if updatedEmbeddingModel {
if _, err := u.updateModeSettingConfig(ctx, "", "", "", true); err != nil {
return err
}
}
return nil
}
func (u *ModelUsecase) GetList(ctx context.Context) ([]*domain.ModelListItem, error) {
return u.modelRepo.GetList(ctx)
}
// trigger upsert records after embedding model is updated or created
func (u *ModelUsecase) TriggerUpsertRecords(ctx context.Context) error {
// update to new dataset
kbList, err := u.kbRepo.GetKnowledgeBaseList(ctx)
if err != nil {
return fmt.Errorf("get knowledge base list failed: %w", err)
}
for _, kb := range kbList {
newDatasetID, err := u.ragStore.CreateKnowledgeBase(ctx)
if err != nil {
return fmt.Errorf("create new dataset failed: %w", err)
}
if err := u.ragStore.DeleteKnowledgeBase(ctx, kb.DatasetID); err != nil {
return fmt.Errorf("delete old dataset failed: %w", err)
}
if err := u.kbRepo.UpdateDatasetID(ctx, kb.ID, newDatasetID); err != nil {
return fmt.Errorf("update knowledge base dataset id failed: %w", err)
}
}
// traverse all nodes
err = u.nodeRepo.TraverseNodesByCursor(ctx, func(nodeRelease *domain.NodeRelease) error {
// async upsert vector content via mq
nodeContentVectorRequests := []*domain.NodeReleaseVectorRequest{
{
KBID: nodeRelease.KBID,
NodeReleaseID: nodeRelease.ID,
Action: "upsert",
},
}
if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, nodeContentVectorRequests); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
return nil
}
func (u *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) error {
var updatedEmbeddingModel bool
if req.Type == domain.ModelTypeEmbedding {
updatedEmbeddingModel = true
}
if err := u.modelRepo.Update(ctx, req); err != nil {
return err
}
data := &domain.Model{
Provider: req.Provider,
Model: req.Model,
Type: req.Type,
APIKey: req.APIKey,
BaseURL: req.BaseURL,
APIHeader: req.APIHeader,
APIVersion: req.APIVersion,
}
if req.IsActive != nil {
data.IsActive = *req.IsActive
}
if req.Parameters != nil {
data.Parameters = *req.Parameters
}
if err := u.ragStore.UpsertModel(ctx, data); err != nil {
return err
}
// 模型更新成功后,如果更新嵌入模型,则触发记录更新
if updatedEmbeddingModel {
if _, err := u.updateModeSettingConfig(ctx, "", "", "", true); err != nil {
return err
}
}
return nil
}
func (u *ModelUsecase) GetChatModel(ctx context.Context) (*domain.Model, error) {
var model *domain.Model
modelModeSetting, err := u.GetModelModeSetting(ctx)
// 获取不到模型模式时,使用手动模式, 不返回错误
if err != nil {
u.logger.Error("get model mode setting failed, use manual mode", log.Error(err))
}
if err == nil && modelModeSetting.Mode == consts.ModelSettingModeAuto && modelModeSetting.AutoModeAPIKey != "" {
modelName := modelModeSetting.ChatModel
if modelName == "" {
modelName = string(consts.AutoModeDefaultChatModel)
}
model = &domain.Model{
Model: modelName,
Type: domain.ModelTypeChat,
IsActive: true,
BaseURL: consts.AutoModeBaseURL,
APIKey: modelModeSetting.AutoModeAPIKey,
Provider: domain.ModelProviderBrandBaiZhiCloud,
}
return model, nil
}
model, err = u.modelRepo.GetChatModel(ctx)
if err != nil {
return nil, err
}
return model, nil
}
func (u *ModelUsecase) GetModelByType(ctx context.Context, modelType domain.ModelType) (*domain.Model, error) {
return u.modelRepo.GetModelByType(ctx, modelType)
}
func (u *ModelUsecase) UpdateUsage(ctx context.Context, modelID string, usage *schema.TokenUsage) error {
return u.modelRepo.UpdateUsage(ctx, modelID, usage)
}
func (u *ModelUsecase) SwitchMode(ctx context.Context, req *domain.SwitchModeReq) error {
switch consts.ModelSettingMode(req.Mode) {
case consts.ModelSettingModeAuto:
if req.AutoModeAPIKey == "" {
return fmt.Errorf("auto mode api key is required")
}
modelName := req.ChatModel
if modelName == "" {
modelName = consts.GetAutoModeDefaultModel(string(domain.ModelTypeChat))
}
// 检查 API Key 是否有效
check, err := u.modelkit.CheckModel(ctx, &modelkitDomain.CheckModelReq{
Provider: string(domain.ModelProviderBrandBaiZhiCloud),
Model: modelName,
BaseURL: consts.AutoModeBaseURL,
APIKey: req.AutoModeAPIKey,
Type: string(domain.ModelTypeChat),
})
if err != nil {
return fmt.Errorf("百智云模型 API Key 检查失败: %w", err)
}
if check.Error != "" {
return fmt.Errorf("百智云模型 API Key 检查失败: %s", check.Error)
}
case consts.ModelSettingModeManual:
needModelTypes := []domain.ModelType{
domain.ModelTypeChat,
domain.ModelTypeEmbedding,
domain.ModelTypeRerank,
domain.ModelTypeAnalysis,
}
for _, modelType := range needModelTypes {
model, err := u.modelRepo.GetModelByType(ctx, modelType)
if err != nil {
return fmt.Errorf("需要配置 %s 模型", modelType)
}
if !model.IsActive {
if err := u.modelRepo.Updates(ctx, model.ID, map[string]any{
"is_active": true,
}); err != nil {
return err
}
}
}
default:
return fmt.Errorf("invalid req mode: %s", req.Mode)
}
oldModelModeSetting, err := u.GetModelModeSetting(ctx)
if err != nil {
return err
}
var isResetEmbeddingUpdateFlag = true
// 只有切换手动模式时重置isManualEmbeddingUpdated为false
if req.Mode == string(consts.ModelSettingModeManual) {
isResetEmbeddingUpdateFlag = false
}
modelModeSetting, err := u.updateModeSettingConfig(ctx, req.Mode, req.AutoModeAPIKey, req.ChatModel, isResetEmbeddingUpdateFlag)
if err != nil {
return err
}
if err := u.updateRAGModelsByMode(ctx, req.Mode, modelModeSetting.AutoModeAPIKey, oldModelModeSetting); err != nil {
return err
}
return nil
}
// updateModeSettingConfig 读取当前设置并更新,然后持久化
func (u *ModelUsecase) updateModeSettingConfig(ctx context.Context, mode, apiKey, chatModel string, isManualEmbeddingUpdated bool) (*domain.ModelModeSetting, error) {
// 读取当前设置
setting, err := u.systemSettingRepo.GetSystemSetting(ctx, consts.SystemSettingModelMode)
if err != nil {
return nil, fmt.Errorf("failed to get current model setting: %w", err)
}
var config domain.ModelModeSetting
if err := json.Unmarshal(setting.Value, &config); err != nil {
return nil, fmt.Errorf("failed to parse current model setting: %w", err)
}
// 更新设置
if apiKey != "" {
config.AutoModeAPIKey = apiKey
}
if chatModel != "" {
config.ChatModel = chatModel
}
if mode != "" {
config.Mode = consts.ModelSettingMode(mode)
}
config.IsManualEmbeddingUpdated = isManualEmbeddingUpdated
// 持久化设置
updatedValue, err := json.Marshal(config)
if err != nil {
return nil, fmt.Errorf("failed to marshal updated model setting: %w", err)
}
if err := u.systemSettingRepo.UpdateSystemSetting(ctx, string(consts.SystemSettingModelMode), string(updatedValue)); err != nil {
return nil, fmt.Errorf("failed to update model setting: %w", err)
}
return &config, nil
}
func (u *ModelUsecase) GetModelModeSetting(ctx context.Context) (domain.ModelModeSetting, error) {
setting, err := u.systemSettingRepo.GetSystemSetting(ctx, consts.SystemSettingModelMode)
if err != nil {
return domain.ModelModeSetting{}, fmt.Errorf("failed to get model mode setting: %w", err)
}
var config domain.ModelModeSetting
if err := json.Unmarshal(setting.Value, &config); err != nil {
return domain.ModelModeSetting{}, fmt.Errorf("failed to parse model mode setting: %w", err)
}
// 无效设置检查
if config == (domain.ModelModeSetting{}) || config.Mode == "" {
return domain.ModelModeSetting{}, fmt.Errorf("model mode setting is invalid")
}
return config, nil
}
// updateRAGModelsByMode 根据模式更新 RAG 模型
func (u *ModelUsecase) updateRAGModelsByMode(ctx context.Context, mode, autoModeAPIKey string, oldModelModeSetting domain.ModelModeSetting) error {
var isTriggerUpsertRecords = true
// 手动切换到手动模式, 根据IsManualEmbeddingUpdated字段决定
if oldModelModeSetting.Mode == consts.ModelSettingModeManual && mode == string(consts.ModelSettingModeManual) {
isTriggerUpsertRecords = oldModelModeSetting.IsManualEmbeddingUpdated
}
ragModelTypes := []domain.ModelType{
domain.ModelTypeEmbedding,
domain.ModelTypeRerank,
domain.ModelTypeAnalysis,
domain.ModelTypeAnalysisVL,
domain.ModelTypeChat,
}
for _, modelType := range ragModelTypes {
var model *domain.Model
if mode == string(consts.ModelSettingModeManual) {
// 获取该类型的活跃模型
m, err := u.modelRepo.GetModelByType(ctx, modelType)
if err != nil {
u.logger.Warn("failed to get model by type", log.String("type", string(modelType)), log.Any("error", err))
continue
}
if m == nil || !m.IsActive {
u.logger.Warn("no active model found for type", log.String("type", string(modelType)))
continue
}
model = m
} else {
modelName := consts.GetAutoModeDefaultModel(string(modelType))
model = &domain.Model{
Model: modelName,
Type: modelType,
IsActive: true,
BaseURL: consts.AutoModeBaseURL,
APIKey: autoModeAPIKey,
Provider: domain.ModelProviderBrandBaiZhiCloud,
}
}
// 更新RAG存储中的模型
if model != nil {
// rag store中更新失败不影响其他模型更新
if err := u.ragStore.UpsertModel(ctx, model); err != nil {
u.logger.Error("failed to update model in RAG store", log.String("model_id", model.ID), log.String("type", string(modelType)), log.Any("error", err))
return fmt.Errorf("failed to update model in RAG store: %s", model.Type)
}
u.logger.Info("successfully updated RAG model", log.String("model name: ", string(model.Model)))
}
}
// 触发记录更新
if isTriggerUpsertRecords {
u.logger.Info("embedding model updated, triggering upsert records")
return u.TriggerUpsertRecords(ctx)
}
return nil
}