init push
This commit is contained in:
189
backend/domain/model.go
Normal file
189
backend/domain/model.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
modelkitConsts "github.com/chaitin/ModelKit/v2/consts"
|
||||
modelkitDomain "github.com/chaitin/ModelKit/v2/domain"
|
||||
)
|
||||
|
||||
type ModelProvider string
|
||||
|
||||
const (
|
||||
ModelProviderBrandBaiZhiCloud ModelProvider = "BaiZhiCloud"
|
||||
)
|
||||
|
||||
type ModelType string
|
||||
|
||||
const (
|
||||
ModelTypeChat ModelType = "chat"
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
ModelTypeRerank ModelType = "rerank"
|
||||
ModelTypeAnalysis ModelType = "analysis"
|
||||
ModelTypeAnalysisVL ModelType = "analysis-vl"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIHeader string `json:"api_header"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIVersion string `json:"api_version"` // for azure openai
|
||||
Type ModelType `json:"type" gorm:"default:chat;uniqueIndex"`
|
||||
|
||||
IsActive bool `json:"is_active" gorm:"default:false"`
|
||||
|
||||
PromptTokens uint64 `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens uint64 `json:"completion_tokens" gorm:"default:0"`
|
||||
TotalTokens uint64 `json:"total_tokens" gorm:"default:0"`
|
||||
|
||||
Parameters ModelParam `json:"parameters" gorm:"column:parameters;type:jsonb"` // 高级参数
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ToModelkitModel converts domain.Model to modelkitDomain.PandaModel
|
||||
func (m *Model) ToModelkitModel() (*modelkitDomain.ModelMetadata, error) {
|
||||
provider := modelkitConsts.ParseModelProvider(string(m.Provider))
|
||||
modelType := modelkitConsts.ParseModelType(string(m.Type))
|
||||
|
||||
return &modelkitDomain.ModelMetadata{
|
||||
Provider: provider,
|
||||
ModelName: m.Model,
|
||||
APIKey: m.APIKey,
|
||||
BaseURL: m.BaseURL,
|
||||
APIVersion: m.APIVersion,
|
||||
APIHeader: m.APIHeader,
|
||||
ModelType: modelType,
|
||||
Temperature: m.Parameters.Temperature,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ModelListItem struct {
|
||||
ID string `json:"id"`
|
||||
Provider ModelProvider `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIHeader string `json:"api_header"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIVersion string `json:"api_version"` // for azure openai
|
||||
Type ModelType `json:"type"`
|
||||
|
||||
IsActive bool `json:"is_active" gorm:"default:false"`
|
||||
|
||||
PromptTokens uint64 `json:"prompt_tokens"`
|
||||
CompletionTokens uint64 `json:"completion_tokens"`
|
||||
TotalTokens uint64 `json:"total_tokens"`
|
||||
Parameters ModelParam `json:"parameters" gorm:"column:parameters;type:jsonb"`
|
||||
}
|
||||
|
||||
type CreateModelReq struct {
|
||||
BaseModelInfo
|
||||
Parameters *ModelParam `json:"parameters"`
|
||||
}
|
||||
|
||||
type UpdateModelReq struct {
|
||||
ID string `json:"id" validate:"required"`
|
||||
BaseModelInfo
|
||||
Parameters *ModelParam `json:"parameters"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
|
||||
type CheckModelReq struct {
|
||||
BaseModelInfo
|
||||
Parameters *ModelParam `json:"parameters"`
|
||||
}
|
||||
|
||||
type ModelParam struct {
|
||||
ContextWindow int `json:"context_window"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
R1Enabled bool `json:"r1_enabled"`
|
||||
SupportComputerUse bool `json:"support_computer_use"`
|
||||
SupportImages bool `json:"support_images"`
|
||||
SupportPromptCache bool `json:"support_prompt_cache"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
}
|
||||
|
||||
func (p ModelParam) Map() map[string]any {
|
||||
return map[string]any{
|
||||
"context_window": p.ContextWindow,
|
||||
"max_tokens": p.MaxTokens,
|
||||
"r1_enabled": p.R1Enabled,
|
||||
"support_computer_use": p.SupportComputerUse,
|
||||
"support_images": p.SupportImages,
|
||||
"support_prompt_cache": p.SupportPromptCache,
|
||||
"temperature": p.Temperature,
|
||||
}
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface for GORM
|
||||
func (p ModelParam) Value() (driver.Value, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface for GORM
|
||||
func (p *ModelParam) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(v, p)
|
||||
case string:
|
||||
return json.Unmarshal([]byte(v), p)
|
||||
default:
|
||||
return fmt.Errorf("cannot scan %T into ModelParam", value)
|
||||
}
|
||||
}
|
||||
|
||||
type BaseModelInfo struct {
|
||||
Provider ModelProvider `json:"provider" validate:"required"`
|
||||
Model string `json:"model" validate:"required"`
|
||||
BaseURL string `json:"base_url" validate:"required"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIHeader string `json:"api_header"`
|
||||
APIVersion string `json:"api_version"` // for azure openai
|
||||
Type ModelType `json:"type" validate:"required,oneof=chat embedding rerank analysis analysis-vl"`
|
||||
}
|
||||
|
||||
type CheckModelResp struct {
|
||||
Error string `json:"error"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type GetProviderModelListReq struct {
|
||||
Provider string `json:"provider" query:"provider" validate:"required"`
|
||||
BaseURL string `json:"base_url" query:"base_url" validate:"required"`
|
||||
APIKey string `json:"api_key" query:"api_key"`
|
||||
APIHeader string `json:"api_header" query:"api_header"`
|
||||
Type ModelType `json:"type" query:"type" validate:"required,oneof=chat embedding rerank analysis analysis-vl"`
|
||||
}
|
||||
|
||||
type GetProviderModelListResp struct {
|
||||
Models []ProviderModelListItem `json:"models"`
|
||||
}
|
||||
|
||||
type ProviderModelListItem struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type ActivateModelReq struct {
|
||||
ModelID string `json:"model_id" validate:"required"`
|
||||
}
|
||||
|
||||
type SwitchModeReq struct {
|
||||
Mode string `json:"mode" validate:"required,oneof=manual auto"`
|
||||
AutoModeAPIKey string `json:"auto_mode_api_key"` // 百智云 API Key
|
||||
ChatModel string `json:"chat_model"` // 自定义对话模型名称
|
||||
}
|
||||
|
||||
type SwitchModeResp struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
Reference in New Issue
Block a user