package rag import ( "context" "fmt" "strings" "github.com/JohannesKaufmann/html-to-markdown/v2/converter" raglite "github.com/chaitin/raglite-go-sdk" "github.com/cloudwego/eino/schema" "github.com/google/uuid" "github.com/chaitin/panda-wiki/config" "github.com/chaitin/panda-wiki/domain" "github.com/chaitin/panda-wiki/log" "github.com/chaitin/panda-wiki/utils" ) type CTRAG struct { client *raglite.Client logger *log.Logger mdConv *converter.Converter } func NewCTRAG(config *config.Config, logger *log.Logger) (*CTRAG, error) { client, err := raglite.NewClient( config.RAG.CTRAG.BaseURL, raglite.WithAPIKey(config.RAG.CTRAG.APIKey), ) if err != nil { return nil, fmt.Errorf("failed to create raglite client: %w", err) } return &CTRAG{ client: client, logger: logger.WithModule("store.vector.ct"), mdConv: NewHTML2MDConverter(), }, nil } func (s *CTRAG) CreateKnowledgeBase(ctx context.Context) (string, error) { dataset, err := s.client.Datasets.Create(ctx, &raglite.CreateDatasetRequest{ Name: uuid.New().String(), }) if err != nil { return "", err } return dataset.ID, nil } func (s *CTRAG) QueryRecords(ctx context.Context, req *QueryRecordsRequest) (string, []*domain.NodeContentChunk, error) { var chatMsgs []raglite.ChatMessage for _, msg := range req.HistoryMsgs { switch msg.Role { case schema.User: chatMsgs = append(chatMsgs, raglite.ChatMessage{ Role: string(msg.Role), Content: msg.Content, }) case schema.Assistant: chatMsgs = append(chatMsgs, raglite.ChatMessage{ Role: string(msg.Role), Content: msg.Content, }) default: continue } } s.logger.Debug("retrieving by history msgs", log.Any("history_msgs", req.HistoryMsgs), log.Any("chat_msgs", chatMsgs)) data := &raglite.RetrieveRequest{ DatasetID: req.DatasetID, Query: req.Query, TopK: 10, Metadata: map[string]interface{}{ "group_ids": req.GroupIDs, }, Tags: req.Tags, SimilarityThreshold: req.SimilarityThreshold, ChatHistory: chatMsgs, MaxChunksPerDoc: req.MaxChunksPerDoc, } res, err := s.client.Search.Retrieve(ctx, data) if err != nil { return "", nil, err } s.logger.Info("retrieve chunks result", log.Int("chunks count", len(res.Results)), log.String("query", res.Query)) nodeChunks := make([]*domain.NodeContentChunk, len(res.Results)) for i, chunk := range res.Results { nodeChunks[i] = &domain.NodeContentChunk{ ID: chunk.ChunkID, Content: chunk.Content, DocID: chunk.DocumentID, } } return res.Query, nodeChunks, nil } func (s *CTRAG) UpsertRecords(ctx context.Context, req *UpsertRecordsRequest) (string, error) { markdown := req.Content // if the content is html, convert it to markdown first if utils.IsLikelyHTML(req.Content) { var err error markdown, err = s.mdConv.ConvertString(req.Content) if err != nil { return "", fmt.Errorf("convert html to markdown failed: %w", err) } } data := &raglite.UploadDocumentRequest{ DatasetID: req.DatasetID, DocumentID: req.DocID, Title: req.Title, File: strings.NewReader(markdown), Filename: fmt.Sprintf("%s.md", req.ID), Metadata: make(map[string]interface{}), } if req.GroupIDs != nil { data.Metadata["group_ids"] = req.GroupIDs } if req.Tags != nil { data.Tags = req.Tags } res, err := s.client.Documents.Upload(ctx, data) if err != nil { return "", fmt.Errorf("upload document text failed: %w", err) } return res.DocumentID, nil } func (s *CTRAG) DeleteRecords(ctx context.Context, datasetID string, docIDs []string) error { if err := s.client.Documents.BatchDelete(ctx, &raglite.BatchDeleteDocumentsRequest{ DatasetID: datasetID, DocumentIDs: docIDs, }); err != nil { return err } return nil } func (s *CTRAG) DeleteKnowledgeBase(ctx context.Context, datasetID string) error { if err := s.client.Datasets.Delete(ctx, datasetID); err != nil { return err } return nil } func (s *CTRAG) AddModel(ctx context.Context, model *domain.Model) (string, error) { maxTokens := model.Parameters.MaxTokens if maxTokens == 0 { maxTokens = 8192 } modelConfig, err := s.client.Models.Create(ctx, &raglite.CreateModelRequest{ Name: model.Model, Provider: string(model.Provider), ModelType: string(model.Type), ModelName: model.Model, Config: raglite.AIModelConfig{ APIBase: model.BaseURL, APIKey: model.APIKey, APIHeader: model.APIHeader, APIVersion: model.APIVersion, MaxTokens: raglite.Ptr(maxTokens), ExtraParameters: model.Parameters.Map(), }, IsDefault: true, }) if err != nil { return "", err } return modelConfig.ID, nil } func (s *CTRAG) UpsertModel(ctx context.Context, model *domain.Model) error { maxTokens := model.Parameters.MaxTokens if maxTokens == 0 { maxTokens = 8192 } data := raglite.UpsertModelRequest{ Name: model.Model, Provider: string(model.Provider), ModelName: model.Model, ModelType: string(model.Type), Config: raglite.AIModelConfig{ APIBase: model.BaseURL, APIKey: model.APIKey, APIHeader: model.APIHeader, APIVersion: model.APIVersion, MaxTokens: raglite.Ptr(maxTokens), ExtraParameters: model.Parameters.Map(), }, IsDefault: true, IsActive: model.IsActive, } _, err := s.client.Models.Upsert(ctx, &data) if err != nil { return err } return nil } func (s *CTRAG) UpdateModel(ctx context.Context, model *domain.Model) error { maxTokens := model.Parameters.MaxTokens if maxTokens == 0 { maxTokens = 8192 } data := raglite.UpdateModelRequest{ Name: raglite.Ptr(model.Model), Provider: raglite.Ptr(string(model.Provider)), ModelName: raglite.Ptr(model.Model), Config: &raglite.AIModelConfig{ APIBase: model.BaseURL, APIKey: model.APIKey, APIHeader: model.APIHeader, APIVersion: model.APIVersion, MaxTokens: raglite.Ptr(maxTokens), ExtraParameters: model.Parameters.Map(), }, IsDefault: raglite.Ptr(true), IsActive: raglite.Ptr(model.IsActive), } _, err := s.client.Models.Update(ctx, model.ID, &data) if err != nil { return err } return nil } func (s *CTRAG) DeleteModel(ctx context.Context, model *domain.Model) error { err := s.client.Models.Delete(ctx, model.ID) if err != nil { return err } return nil } func (s *CTRAG) GetModelList(ctx context.Context) ([]*domain.Model, error) { res, err := s.client.Models.List(ctx, &raglite.ListModelsRequest{}) if err != nil { return nil, err } models := make([]*domain.Model, len(res.Models)) for i, model := range res.Models { models[i] = &domain.Model{ ID: model.ID, Model: model.Name, BaseURL: model.Config.APIBase, APIKey: model.Config.APIKey, Type: domain.ModelType(model.ModelType), } } return models, nil } func (s *CTRAG) UpdateDocumentGroupIDs(ctx context.Context, datasetID string, docID string, groupIds []int) error { req := &raglite.UpdateDocumentRequest{ DatasetID: datasetID, DocumentID: docID, Metadata: map[string]interface{}{}, } if groupIds != nil { req.Metadata["group_ids"] = groupIds } _, err := s.client.Documents.Update(ctx, req) if err != nil { return fmt.Errorf("update document group IDs failed: %w", err) } return nil } func (s *CTRAG) ListDocuments(ctx context.Context, datasetID string, documentIDs []string) ([]Document, error) { res, err := s.client.Documents.List(ctx, &raglite.ListDocumentsRequest{ DocumentIDs: documentIDs, DatasetID: datasetID, }) if err != nil { return nil, err } documents := make([]Document, len(res.Documents)) for i, document := range res.Documents { documents[i] = Document{ ID: document.ID, Name: document.Filename, DatasetID: document.DatasetID, Status: document.Status, ProgressMsg: document.ProgressMsg, Tags: document.Tags, MetaData: raglite.Decode[DocumentMetadata](document.Metadata), } } return documents, nil }