290 lines
7.8 KiB
Go
290 lines
7.8 KiB
Go
package rag
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/JohannesKaufmann/html-to-markdown/v2/converter"
|
|
raglite "github.com/chaitin/raglite-go-sdk"
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/chaitin/panda-wiki/config"
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/utils"
|
|
)
|
|
|
|
type CTRAG struct {
|
|
client *raglite.Client
|
|
logger *log.Logger
|
|
mdConv *converter.Converter
|
|
}
|
|
|
|
func NewCTRAG(config *config.Config, logger *log.Logger) (*CTRAG, error) {
|
|
client, err := raglite.NewClient(
|
|
config.RAG.CTRAG.BaseURL,
|
|
raglite.WithAPIKey(config.RAG.CTRAG.APIKey),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create raglite client: %w", err)
|
|
}
|
|
return &CTRAG{
|
|
client: client,
|
|
logger: logger.WithModule("store.vector.ct"),
|
|
mdConv: NewHTML2MDConverter(),
|
|
}, nil
|
|
}
|
|
|
|
func (s *CTRAG) CreateKnowledgeBase(ctx context.Context) (string, error) {
|
|
dataset, err := s.client.Datasets.Create(ctx, &raglite.CreateDatasetRequest{
|
|
Name: uuid.New().String(),
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return dataset.ID, nil
|
|
}
|
|
|
|
func (s *CTRAG) QueryRecords(ctx context.Context, req *QueryRecordsRequest) (string, []*domain.NodeContentChunk, error) {
|
|
var chatMsgs []raglite.ChatMessage
|
|
for _, msg := range req.HistoryMsgs {
|
|
switch msg.Role {
|
|
case schema.User:
|
|
chatMsgs = append(chatMsgs, raglite.ChatMessage{
|
|
Role: string(msg.Role),
|
|
Content: msg.Content,
|
|
})
|
|
case schema.Assistant:
|
|
chatMsgs = append(chatMsgs, raglite.ChatMessage{
|
|
Role: string(msg.Role),
|
|
Content: msg.Content,
|
|
})
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
s.logger.Debug("retrieving by history msgs", log.Any("history_msgs", req.HistoryMsgs), log.Any("chat_msgs", chatMsgs))
|
|
data := &raglite.RetrieveRequest{
|
|
DatasetID: req.DatasetID,
|
|
Query: req.Query,
|
|
TopK: 10,
|
|
Metadata: map[string]interface{}{
|
|
"group_ids": req.GroupIDs,
|
|
},
|
|
Tags: req.Tags,
|
|
SimilarityThreshold: req.SimilarityThreshold,
|
|
ChatHistory: chatMsgs,
|
|
MaxChunksPerDoc: req.MaxChunksPerDoc,
|
|
}
|
|
res, err := s.client.Search.Retrieve(ctx, data)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
s.logger.Info("retrieve chunks result", log.Int("chunks count", len(res.Results)), log.String("query", res.Query))
|
|
nodeChunks := make([]*domain.NodeContentChunk, len(res.Results))
|
|
for i, chunk := range res.Results {
|
|
nodeChunks[i] = &domain.NodeContentChunk{
|
|
ID: chunk.ChunkID,
|
|
Content: chunk.Content,
|
|
DocID: chunk.DocumentID,
|
|
}
|
|
}
|
|
return res.Query, nodeChunks, nil
|
|
}
|
|
|
|
func (s *CTRAG) UpsertRecords(ctx context.Context, req *UpsertRecordsRequest) (string, error) {
|
|
markdown := req.Content
|
|
// if the content is html, convert it to markdown first
|
|
if utils.IsLikelyHTML(req.Content) {
|
|
var err error
|
|
markdown, err = s.mdConv.ConvertString(req.Content)
|
|
if err != nil {
|
|
return "", fmt.Errorf("convert html to markdown failed: %w", err)
|
|
}
|
|
}
|
|
data := &raglite.UploadDocumentRequest{
|
|
DatasetID: req.DatasetID,
|
|
DocumentID: req.DocID,
|
|
Title: req.Title,
|
|
File: strings.NewReader(markdown),
|
|
Filename: fmt.Sprintf("%s.md", req.ID),
|
|
Metadata: make(map[string]interface{}),
|
|
}
|
|
if req.GroupIDs != nil {
|
|
data.Metadata["group_ids"] = req.GroupIDs
|
|
}
|
|
if req.Tags != nil {
|
|
data.Tags = req.Tags
|
|
}
|
|
res, err := s.client.Documents.Upload(ctx, data)
|
|
if err != nil {
|
|
return "", fmt.Errorf("upload document text failed: %w", err)
|
|
}
|
|
return res.DocumentID, nil
|
|
}
|
|
|
|
func (s *CTRAG) DeleteRecords(ctx context.Context, datasetID string, docIDs []string) error {
|
|
if err := s.client.Documents.BatchDelete(ctx, &raglite.BatchDeleteDocumentsRequest{
|
|
DatasetID: datasetID,
|
|
DocumentIDs: docIDs,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) DeleteKnowledgeBase(ctx context.Context, datasetID string) error {
|
|
if err := s.client.Datasets.Delete(ctx, datasetID); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) AddModel(ctx context.Context, model *domain.Model) (string, error) {
|
|
maxTokens := model.Parameters.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 8192
|
|
}
|
|
modelConfig, err := s.client.Models.Create(ctx, &raglite.CreateModelRequest{
|
|
Name: model.Model,
|
|
Provider: string(model.Provider),
|
|
ModelType: string(model.Type),
|
|
ModelName: model.Model,
|
|
Config: raglite.AIModelConfig{
|
|
APIBase: model.BaseURL,
|
|
APIKey: model.APIKey,
|
|
APIHeader: model.APIHeader,
|
|
APIVersion: model.APIVersion,
|
|
MaxTokens: raglite.Ptr(maxTokens),
|
|
ExtraParameters: model.Parameters.Map(),
|
|
},
|
|
IsDefault: true,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return modelConfig.ID, nil
|
|
}
|
|
|
|
func (s *CTRAG) UpsertModel(ctx context.Context, model *domain.Model) error {
|
|
maxTokens := model.Parameters.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 8192
|
|
}
|
|
data := raglite.UpsertModelRequest{
|
|
Name: model.Model,
|
|
Provider: string(model.Provider),
|
|
ModelName: model.Model,
|
|
ModelType: string(model.Type),
|
|
Config: raglite.AIModelConfig{
|
|
APIBase: model.BaseURL,
|
|
APIKey: model.APIKey,
|
|
APIHeader: model.APIHeader,
|
|
APIVersion: model.APIVersion,
|
|
MaxTokens: raglite.Ptr(maxTokens),
|
|
ExtraParameters: model.Parameters.Map(),
|
|
},
|
|
IsDefault: true,
|
|
IsActive: model.IsActive,
|
|
}
|
|
_, err := s.client.Models.Upsert(ctx, &data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) UpdateModel(ctx context.Context, model *domain.Model) error {
|
|
maxTokens := model.Parameters.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 8192
|
|
}
|
|
data := raglite.UpdateModelRequest{
|
|
Name: raglite.Ptr(model.Model),
|
|
Provider: raglite.Ptr(string(model.Provider)),
|
|
ModelName: raglite.Ptr(model.Model),
|
|
Config: &raglite.AIModelConfig{
|
|
APIBase: model.BaseURL,
|
|
APIKey: model.APIKey,
|
|
APIHeader: model.APIHeader,
|
|
APIVersion: model.APIVersion,
|
|
MaxTokens: raglite.Ptr(maxTokens),
|
|
ExtraParameters: model.Parameters.Map(),
|
|
},
|
|
IsDefault: raglite.Ptr(true),
|
|
IsActive: raglite.Ptr(model.IsActive),
|
|
}
|
|
_, err := s.client.Models.Update(ctx, model.ID, &data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) DeleteModel(ctx context.Context, model *domain.Model) error {
|
|
err := s.client.Models.Delete(ctx, model.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) GetModelList(ctx context.Context) ([]*domain.Model, error) {
|
|
res, err := s.client.Models.List(ctx, &raglite.ListModelsRequest{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
models := make([]*domain.Model, len(res.Models))
|
|
for i, model := range res.Models {
|
|
models[i] = &domain.Model{
|
|
ID: model.ID,
|
|
Model: model.Name,
|
|
BaseURL: model.Config.APIBase,
|
|
APIKey: model.Config.APIKey,
|
|
Type: domain.ModelType(model.ModelType),
|
|
}
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func (s *CTRAG) UpdateDocumentGroupIDs(ctx context.Context, datasetID string, docID string, groupIds []int) error {
|
|
req := &raglite.UpdateDocumentRequest{
|
|
DatasetID: datasetID,
|
|
DocumentID: docID,
|
|
Metadata: map[string]interface{}{},
|
|
}
|
|
if groupIds != nil {
|
|
req.Metadata["group_ids"] = groupIds
|
|
}
|
|
_, err := s.client.Documents.Update(ctx, req)
|
|
if err != nil {
|
|
return fmt.Errorf("update document group IDs failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *CTRAG) ListDocuments(ctx context.Context, datasetID string, documentIDs []string) ([]Document, error) {
|
|
res, err := s.client.Documents.List(ctx, &raglite.ListDocumentsRequest{
|
|
DocumentIDs: documentIDs,
|
|
DatasetID: datasetID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
documents := make([]Document, len(res.Documents))
|
|
for i, document := range res.Documents {
|
|
documents[i] = Document{
|
|
ID: document.ID,
|
|
Name: document.Filename,
|
|
DatasetID: document.DatasetID,
|
|
Status: document.Status,
|
|
ProgressMsg: document.ProgressMsg,
|
|
Tags: document.Tags,
|
|
MetaData: raglite.Decode[DocumentMetadata](document.Metadata),
|
|
}
|
|
}
|
|
return documents, nil
|
|
}
|