802 lines
26 KiB
Go
802 lines
26 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package service
|
|
|
|
import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
	"golang.org/x/sync/errgroup"

	knowledgeModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/consts"
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/convert"
	"github.com/coze-dev/coze-studio/backend/domain/knowledge/internal/dal/model"
	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
	"github.com/coze-dev/coze-studio/backend/infra/contract/document"
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
	"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
	"github.com/coze-dev/coze-studio/backend/infra/contract/rdb"
	sqlparsercontract "github.com/coze-dev/coze-studio/backend/infra/contract/sqlparser"
	"github.com/coze-dev/coze-studio/backend/infra/impl/sqlparser"
	"github.com/coze-dev/coze-studio/backend/pkg/errorx"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/sets"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
	"github.com/coze-dev/coze-studio/backend/pkg/sonic"
	"github.com/coze-dev/coze-studio/backend/types/errno"
)
|
|
|
|
// Retrieve runs the knowledge recall pipeline for request.Query.
//
// The pipeline is: query rewrite -> parallel recall (vector, full-text and
// NL2SQL channels, plus a branch that passes the RetrieveContext through) ->
// rerank -> result packing. An empty query, or a request that resolves to no
// enabled documents, returns an empty response without building the pipeline.
// A nil request is rejected with ErrKnowledgeInvalidParamCode.
func (k *knowledgeSVC) Retrieve(ctx context.Context, request *RetrieveRequest) (response *RetrieveResponse, err error) {
	switch {
	case request == nil:
		return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is nil"))
	case len(request.Query) == 0:
		return &knowledgeModel.RetrieveResponse{}, nil
	}

	rc, err := k.newRetrieveContext(ctx, request)
	if err != nil {
		return nil, err
	}
	if len(rc.Documents) == 0 {
		return &knowledgeModel.RetrieveResponse{}, nil
	}

	// The branch keys below are read back by reRankNode; keep them in sync.
	recall := compose.NewParallel().
		AddLambda("vectorRetrieveNode", compose.InvokableLambda(k.vectorRetrieveNode)).
		AddLambda("esRetrieveNode", compose.InvokableLambda(k.esRetrieveNode)).
		AddLambda("nl2SqlRetrieveNode", compose.InvokableLambda(k.nl2SqlRetrieveNode)).
		AddLambda("passRequestContext", compose.InvokableLambda(k.passRequestContext))

	runnable, err := compose.NewChain[*RetrieveContext, []*knowledgeModel.RetrieveSlice]().
		AppendLambda(compose.InvokableLambda(k.queryRewriteNode)).
		AppendParallel(recall).
		AppendLambda(compose.InvokableLambda(k.reRankNode)).
		AppendLambda(compose.InvokableLambda(k.packResults)).
		Compile(ctx)
	if err != nil {
		logs.CtxErrorf(ctx, "compile chain failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeBuildRetrieveChainFailCode, errorx.KV("msg", err.Error()))
	}

	retrieved, err := runnable.Invoke(ctx, rc)
	if err != nil {
		logs.CtxErrorf(ctx, "invoke chain failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeRetrieveExecFailCode, errorx.KV("msg", err.Error()))
	}

	return &RetrieveResponse{RetrieveSlices: retrieved}, nil
}
|
|
|
|
// newRetrieveContext validates req and resolves it into the RetrieveContext
// consumed by the retrieve chain.
//
// It deduplicates the requested knowledge/document IDs, keeps only the enabled
// knowledge bases and documents (prepareRAGDocuments), groups the enabled
// documents per knowledge base, and, when the request carries both a chat model
// protocol and config, builds the chat model used by query rewrite and NL2SQL.
// If no enabled document remains, an empty RetrieveContext is returned so the
// caller can short-circuit.
//
// Errors: ErrKnowledgeInvalidParamCode when Strategy is missing or the chat
// model cannot be created; errors from prepareRAGDocuments are returned as is.
func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequest) (*RetrieveContext, error) {
	if req.Strategy == nil {
		return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "strategy is required"))
	}

	knowledgeIDSets := sets.FromSlice(req.KnowledgeIDs)
	docIDSets := sets.FromSlice(req.DocumentIDs)
	enableDocs, enableKnowledge, err := k.prepareRAGDocuments(ctx, docIDSets.ToSlice(), knowledgeIDSets.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "prepare rag documents failed: %v", err)
		return nil, err
	}
	if len(enableDocs) == 0 {
		return &RetrieveContext{}, nil
	}

	knowledgeInfoMap := make(map[int64]*KnowledgeInfo, len(enableKnowledge))
	for _, kn := range enableKnowledge {
		if _, exists := knowledgeInfoMap[kn.ID]; exists {
			continue
		}
		knowledgeInfoMap[kn.ID] = &KnowledgeInfo{
			DocumentType:  knowledgeModel.DocumentType(kn.FormatType),
			DocumentIDs:   []int64{},
			KnowledgeName: kn.Name,
		}
	}
	for _, doc := range enableDocs {
		info, found := knowledgeInfoMap[doc.KnowledgeID]
		if !found {
			continue
		}
		info.DocumentIDs = append(info.DocumentIDs, doc.ID)
		// Table columns are taken from the first table document that has table info.
		if info.DocumentType == knowledgeModel.DocumentTypeTable && info.TableColumns == nil && doc.TableInfo != nil {
			info.TableColumns = doc.TableInfo.Columns
		}
	}

	var cm chatmodel.BaseChatModel
	if req.ChatModelProtocol != nil && req.ChatModelConfig != nil {
		cm, err = k.modelFactory.CreateChatModel(ctx, ptr.From(req.ChatModelProtocol), req.ChatModelConfig)
		if err != nil {
			logs.CtxErrorf(ctx, "create chat model failed: %v", err)
			return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode,
				errorx.KV("msg", "invalid retriever chat model protocol or config"))
		}
	}

	// Copy before appending the query: appending directly to req.ChatHistory
	// could write into spare capacity of the caller's backing array.
	chatHistory := make([]*schema.Message, 0, len(req.ChatHistory)+1)
	chatHistory = append(chatHistory, req.ChatHistory...)
	chatHistory = append(chatHistory, schema.UserMessage(req.Query))

	return &RetrieveContext{
		Ctx:              ctx,
		OriginQuery:      req.Query,
		ChatHistory:      chatHistory,
		KnowledgeIDs:     knowledgeIDSets,
		KnowledgeInfoMap: knowledgeInfoMap,
		Strategy:         req.Strategy,
		Documents:        enableDocs,
		ChatModel:        cm,
	}, nil
}
|
|
|
|
func (k *knowledgeSVC) prepareRAGDocuments(ctx context.Context, documentIDs []int64, knowledgeIDs []int64) ([]*model.KnowledgeDocument, []*model.Knowledge, error) {
|
|
enableKnowledge, err := k.knowledgeRepo.FilterEnableKnowledge(ctx, knowledgeIDs)
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "filter enable knowledge failed: %v", err)
|
|
return nil, nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
|
|
}
|
|
if len(enableKnowledge) == 0 {
|
|
return nil, nil, nil
|
|
}
|
|
var enableKnowledgeIDs []int64
|
|
for _, kn := range enableKnowledge {
|
|
enableKnowledgeIDs = append(enableKnowledgeIDs, kn.ID)
|
|
}
|
|
enableDocs, _, err := k.documentRepo.FindDocumentByCondition(ctx, &entity.WhereDocumentOpt{
|
|
IDs: documentIDs,
|
|
KnowledgeIDs: enableKnowledgeIDs,
|
|
StatusIn: []int32{int32(entity.DocumentStatusEnable)},
|
|
SelectAll: true,
|
|
})
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "find document by condition failed: %v", err)
|
|
return nil, nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
|
|
}
|
|
return enableDocs, enableKnowledge, nil
|
|
}
|
|
|
|
// queryRewriteNode is the first chain step. It rewrites the latest user query
// using the preceding chat history and stores the result in req.RewrittenQuery.
//
// Rewriting is skipped when the history holds only the current query, when the
// strategy disables it, or when no rewriter is configured. Rewriting is
// best-effort: a rewriter error is logged and req is passed on unchanged. The
// node never returns an error.
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, req *RetrieveContext) (newRetrieveContext *RetrieveContext, err error) {
	// Short-circuit order matters: Strategy is only consulted when there is history.
	if len(req.ChatHistory) == 1 || !req.Strategy.EnableQueryRewrite || k.rewriter == nil {
		return req, nil
	}

	var opts []messages2query.Option
	if cm := req.ChatModel; cm != nil {
		opts = append(opts, messages2query.WithChatModel(cm))
	}

	rewritten, rewriteErr := k.rewriter.MessagesToQuery(ctx, req.ChatHistory, opts...)
	if rewriteErr != nil {
		logs.CtxErrorf(ctx, "rewrite query failed: %v", rewriteErr)
		return req, nil
	}

	req.RewrittenQuery = &rewritten
	return req, nil
}
|
|
|
|
// vectorRetrieveNode is the vector recall branch of the parallel step.
//
// It is a no-op for full-text-only searches. Otherwise it uses the first
// configured search store manager of type TypeVectorStore and queries it through
// retrieveChannels. The branch is best-effort: a missing manager or a recall
// error is logged and yields no documents, never an error.
func (k *knowledgeSVC) vectorRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
	if req.Strategy.SearchType == knowledgeModel.SearchTypeFullText {
		return nil, nil
	}

	var vectorManager searchstore.Manager
	for _, candidate := range k.searchStoreManagers {
		if candidate == nil || candidate.GetType() != searchstore.TypeVectorStore {
			continue
		}
		vectorManager = candidate
		break
	}
	if vectorManager == nil {
		logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现vectorStore")).Error())
		return nil, nil
	}

	docs, retrieveErr := k.retrieveChannels(ctx, req, vectorManager)
	if retrieveErr != nil {
		logs.CtxErrorf(ctx, "retrieveChannels err:%s", retrieveErr.Error())
	}
	return docs, nil
}
|
|
|
|
// esRetrieveNode is the full-text recall branch of the parallel step.
//
// It is a no-op for semantic-only searches. Otherwise it uses the first
// configured search store manager of type TypeTextStore and queries it through
// retrieveChannels. The branch is best-effort: a missing manager or a recall
// error is logged and yields no documents, never an error.
func (k *knowledgeSVC) esRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
	if req.Strategy.SearchType == knowledgeModel.SearchTypeSemantic {
		return nil, nil
	}

	var textManager searchstore.Manager
	for _, candidate := range k.searchStoreManagers {
		if candidate == nil || candidate.GetType() != searchstore.TypeTextStore {
			continue
		}
		textManager = candidate
		break
	}
	if textManager == nil {
		logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现esStore")).Error())
		return nil, nil
	}

	docs, retrieveErr := k.retrieveChannels(ctx, req, textManager)
	if retrieveErr != nil {
		logs.CtxErrorf(ctx, "retrieveChannels err:%s", retrieveErr.Error())
	}
	return docs, nil
}
|
|
|
|
// retrieveChannels queries manager's search store once per knowledge base in
// req.KnowledgeInfoMap and concatenates the hits.
//
// The rewritten query is used when rewriting is enabled and produced a result;
// otherwise the original query is used. Each knowledge base is searched in its
// own collection (getCollectionName), restricted to its documents through both
// document_id partitions and an "in" DSL filter. For table knowledge bases,
// unless compact tables are enabled, the query is also multi-matched against
// the indexed columns. At most two searches run concurrently. If any search
// fails, the first error is returned and no documents are returned.
func (k *knowledgeSVC) retrieveChannels(ctx context.Context, req *RetrieveContext, manager searchstore.Manager) (result []*schema.Document, err error) {
	query := req.OriginQuery
	if req.Strategy.EnableQueryRewrite && req.RewrittenQuery != nil {
		query = *req.RewrittenQuery
	}

	var mu sync.Mutex
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(2)

	for kid, info := range req.KnowledgeInfoMap {
		collectionName := getCollectionName(kid)
		filter := &searchstore.DSL{
			Op:    searchstore.OpIn,
			Field: "document_id",
			Value: info.DocumentIDs,
		}

		partitions := make([]string, 0, len(req.Documents))
		for _, d := range req.Documents {
			if d.KnowledgeID != kid {
				continue
			}
			partitions = append(partitions, strconv.FormatInt(d.ID, 10))
		}
		if len(partitions) == 0 {
			continue
		}

		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey(fieldNameDocumentID),
			searchstore.WithPartitions(partitions),
			retriever.WithDSLInfo(filter.DSL()),
		}
		if info.DocumentType == knowledgeModel.DocumentTypeTable && !k.enableCompactTable {
			var indexedCols []string
			for _, col := range info.TableColumns {
				if !col.Indexing {
					continue
				}
				indexedCols = append(indexedCols, getColName(col.ID))
			}
			opts = append(opts, searchstore.WithMultiMatch(indexedCols, query))
		}

		// The closure only captures per-iteration locals (collectionName, opts).
		group.Go(func() error {
			store, getErr := manager.GetSearchStore(groupCtx, collectionName)
			if getErr != nil {
				return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", getErr.Error()))
			}
			docs, retrieveErr := store.Retrieve(groupCtx, query, opts...)
			if retrieveErr != nil {
				return errorx.New(errno.ErrKnowledgeRetrieveExecFailCode, errorx.KV("msg", retrieveErr.Error()))
			}
			mu.Lock()
			defer mu.Unlock()
			result = append(result, docs...)
			return nil
		})
	}

	if err = group.Wait(); err != nil {
		return nil, err
	}
	return result, nil
}
|
|
|
|
// nl2SqlRetrieveNode is the NL2SQL recall branch of the parallel step.
//
// It runs only when the strategy enables NL2SQL and at least one table
// document is present. In that case it executes nl2SqlExec for every table
// document concurrently and concatenates the produced documents. The branch is
// best-effort: if any document fails, the failure is logged and the branch
// yields no documents instead of an error.
func (k *knowledgeSVC) nl2SqlRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
	var tableDocs []*model.KnowledgeDocument
	for _, doc := range req.Documents {
		if doc.DocumentType == int32(knowledgeModel.DocumentTypeTable) {
			tableDocs = append(tableDocs, doc)
		}
	}
	// Strategy is only consulted when table documents exist.
	if len(tableDocs) == 0 || !req.Strategy.EnableNL2SQL {
		return nil, nil
	}

	var opts []nl2sql.Option
	if req.ChatModel != nil {
		opts = append(opts, nl2sql.WithChatModel(req.ChatModel))
	}

	var mu sync.Mutex
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(len(tableDocs))
	collected := make([]*schema.Document, 0)
	for _, tableDoc := range tableDocs {
		tableDoc := tableDoc
		group.Go(func() error {
			docs, execErr := k.nl2SqlExec(groupCtx, tableDoc, req, opts)
			if execErr != nil {
				logs.CtxErrorf(groupCtx, "nl2sql exec failed: %v", execErr)
				return errorx.New(errno.ErrKnowledgeNL2SqlExecFailCode, errorx.KV("msg", execErr.Error()))
			}
			mu.Lock()
			collected = append(collected, docs...)
			mu.Unlock()
			return nil
		})
	}

	if waitErr := group.Wait(); waitErr != nil {
		logs.CtxErrorf(groupCtx, "nl2sql exec failed: %v", waitErr)
		return nil, nil
	}
	return collected, nil
}
|
|
|
|
func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocument, retrieveCtx *RetrieveContext, opts []nl2sql.Option) (
|
|
retrieveResult []*schema.Document, err error) {
|
|
sql, err := k.nl2Sql.NL2SQL(ctx, retrieveCtx.ChatHistory, []*document.TableSchema{packNL2SqlRequest(doc)}, opts...)
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "nl2sql failed: %v", err)
|
|
return nil, err
|
|
}
|
|
sql = addSliceIdColumn(sql)
|
|
// Execute sql
|
|
replaceMap := map[string]sqlparsercontract.TableColumn{}
|
|
replaceMap[doc.Name] = sqlparsercontract.TableColumn{
|
|
NewTableName: ptr.Of(doc.TableInfo.PhysicalTableName),
|
|
ColumnMap: map[string]string{
|
|
pkID: consts.RDBFieldID,
|
|
},
|
|
}
|
|
for i := range doc.TableInfo.Columns {
|
|
if doc.TableInfo.Columns[i] == nil {
|
|
continue
|
|
}
|
|
if doc.TableInfo.Columns[i].Name == consts.RDBFieldID {
|
|
continue
|
|
}
|
|
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
|
|
}
|
|
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
|
|
return nil, err
|
|
}
|
|
// Execute sql
|
|
resp, err := k.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
|
|
SQL: parsedSQL,
|
|
})
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "execute sql failed: %v", err)
|
|
return nil, err
|
|
}
|
|
for i := range resp.ResultSet.Rows {
|
|
d := &schema.Document{
|
|
Content: "",
|
|
MetaData: map[string]any{
|
|
"document_id": doc.ID,
|
|
"document_name": doc.Name,
|
|
"knowledge_id": doc.KnowledgeID,
|
|
"knowledge_name": retrieveCtx.KnowledgeInfoMap[doc.KnowledgeID].KnowledgeName,
|
|
},
|
|
}
|
|
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
|
|
if !ok {
|
|
byteData, err := sonic.Marshal(resp.ResultSet.Rows)
|
|
if err != nil {
|
|
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
|
|
return nil, err
|
|
}
|
|
prefix := "sql:" + sql + ";result:"
|
|
d.Content = prefix + string(byteData)
|
|
} else {
|
|
d.ID = strconv.FormatInt(id, 10)
|
|
}
|
|
d.WithScore(1)
|
|
retrieveResult = append(retrieveResult, d)
|
|
}
|
|
return retrieveResult, nil
|
|
}
|
|
|
|
const pkID = "_knowledge_slice_id"
|
|
|
|
func addSliceIdColumn(originalSql string) string {
|
|
sql, err := sqlparser.NewSQLParser().AddSelectFieldsToSelectSQL(originalSql, []string{pkID})
|
|
if err != nil {
|
|
logs.Errorf("add slice id column failed: %v", err)
|
|
return originalSql
|
|
}
|
|
return sql
|
|
}
|
|
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
|
|
res := &document.TableSchema{}
|
|
if doc.TableInfo == nil {
|
|
return res
|
|
}
|
|
res.Name = doc.TableInfo.VirtualTableName
|
|
res.Comment = doc.TableInfo.TableDesc
|
|
res.Columns = []*document.Column{}
|
|
for _, column := range doc.TableInfo.Columns {
|
|
if column.Name == consts.RDBFieldID {
|
|
continue
|
|
}
|
|
res.Columns = append(res.Columns, &document.Column{
|
|
Name: column.Name,
|
|
Type: column.Type,
|
|
Description: column.Description,
|
|
Nullable: !column.Indexing,
|
|
IsPrimary: false,
|
|
})
|
|
}
|
|
return res
|
|
}
|
|
|
|
// passRequestContext forwards the RetrieveContext unchanged. It runs as a
// branch of the parallel recall step so that reRankNode can read the context
// back under the "passRequestContext" key. It never returns an error.
func (k *knowledgeSVC) passRequestContext(ctx context.Context, req *RetrieveContext) (*RetrieveContext, error) {
	return req, nil
}
|
|
|
|
// reRankNode merges the outputs of the parallel recall branches and reranks them.
//
// resultMap is keyed by the branch names registered in Retrieve. The
// "passRequestContext" entry must hold the *RetrieveContext; otherwise an
// ErrKnowledgeSystemCode error is returned. A missing recall entry is logged and
// treated as empty.
//
// The channels sent to the reranker depend on the strategy:
//   - NL2SQL results, when EnableNL2SQL is set.
//   - Plus full-text results (full-text search), vector and full-text results
//     (hybrid), or vector results (semantic and any other search type).
//
// The rewritten query is used when rewriting is enabled and produced one.
// Reranked documents scoring below ptr.From(Strategy.MinScore) are dropped; the
// rest are returned in rerank order with their rerank score set.
func (k *knowledgeSVC) reRankNode(ctx context.Context, resultMap map[string]any) (retrieveResult []*schema.Document, err error) {
	retrieveCtx, ok := resultMap["passRequestContext"].(*RetrieveContext)
	if !ok {
		logs.CtxErrorf(ctx, "retrieve context is not found")
		return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "retrieve context is not found"))
	}

	// recalled reads one branch's output, logging missingMsg when it is absent.
	recalled := func(key, missingMsg string) []*schema.Document {
		if docs, found := resultMap[key].([]*schema.Document); found {
			return docs
		}
		logs.CtxErrorf(ctx, "%s", missingMsg)
		return []*schema.Document{}
	}
	vectorDocs := recalled("vectorRetrieveNode", "vector retrieve result is not found")
	esDocs := recalled("esRetrieveNode", "es retrieve result is not found")
	nl2SqlDocs := recalled("nl2SqlRetrieveNode", "nl2sql retrieve result is not found")

	toRerankData := func(docs []*schema.Document) []*rerank.Data {
		data := make([]*rerank.Data, 0, len(docs))
		for _, d := range docs {
			data = append(data, &rerank.Data{Document: d, Score: d.Score()})
		}
		return data
	}

	var channels [][]*rerank.Data
	if retrieveCtx.Strategy.EnableNL2SQL {
		channels = append(channels, toRerankData(nl2SqlDocs))
	}
	switch retrieveCtx.Strategy.SearchType {
	case knowledgeModel.SearchTypeFullText:
		channels = append(channels, toRerankData(esDocs))
	case knowledgeModel.SearchTypeHybrid:
		channels = append(channels, toRerankData(vectorDocs), toRerankData(esDocs))
	default: // semantic search and unknown types use vector recall only
		channels = append(channels, toRerankData(vectorDocs))
	}

	query := retrieveCtx.OriginQuery
	if retrieveCtx.Strategy.EnableQueryRewrite && retrieveCtx.RewrittenQuery != nil {
		query = *retrieveCtx.RewrittenQuery
	}

	resp, err := k.reranker.Rerank(ctx, &rerank.Request{
		Query: query,
		Data:  channels,
		TopN:  retrieveCtx.Strategy.TopK,
	})
	if err != nil {
		logs.CtxErrorf(ctx, "rerank failed: %v", err)
		return nil, err
	}

	minScore := ptr.From(retrieveCtx.Strategy.MinScore)
	retrieveResult = make([]*schema.Document, 0, len(resp.SortedData))
	for _, item := range resp.SortedData {
		if item.Score < minScore {
			continue
		}
		item.Document.WithScore(item.Score)
		retrieveResult = append(retrieveResult, item.Document)
	}
	return retrieveResult, nil
}
|
|
|
|
// packResults converts the reranked recall documents into RetrieveSlices.
//
// Documents without an ID are NL2SQL summaries (see nl2SqlExec). They become
// text slices built directly from their metadata, with score 1.
//
// Every other document ID is parsed as a slice ID. The slices, their documents
// and their knowledge bases (as returned by FilterEnableKnowledge) are loaded:
//   - Table slices are filled with row data via selectTableData, for every
//     recalled table document.
//   - Text and image slices get their image URLs resolved.
//   - Each slice carries the recall score of its document.
//
// Results follow the order of retrieveResult (the rerank order). A slice ID
// recalled more than once is emitted once, at its first position, with the
// score of its last occurrence. Slices whose document or knowledge base cannot
// be loaded are skipped. Hit counts of all recalled slice IDs are incremented
// on a best-effort basis.
//
// Errors: an unparsable ID (ErrKnowledgeSystemCode), repository failures
// (ErrKnowledgeDBCode), table data loading failures, and object URL failures
// (ErrKnowledgeGetObjectURLFailCode).
func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema.Document) (results []*knowledgeModel.RetrieveSlice, err error) {
	if len(retrieveResult) == 0 {
		return nil, nil
	}

	// outputItem is one position of the final result list: either a slice built
	// directly from an ID-less document, or a slice ID resolved after loading.
	type outputItem struct {
		direct  *knowledgeModel.RetrieveSlice
		sliceID int64
	}
	items := make([]outputItem, 0, len(retrieveResult))
	sliceIDs := make(sets.Set[int64])
	sliceScoreMap := map[int64]float64{}
	for _, doc := range retrieveResult {
		if len(doc.ID) == 0 {
			// Checked assertions: missing metadata yields zero values, not a panic.
			knowledgeID, _ := doc.MetaData["knowledge_id"].(int64)
			documentID, _ := doc.MetaData["document_id"].(int64)
			documentName, _ := doc.MetaData["document_name"].(string)
			knowledgeName, _ := doc.MetaData["knowledge_name"].(string)
			items = append(items, outputItem{direct: &knowledgeModel.RetrieveSlice{
				Slice: &knowledgeModel.Slice{
					KnowledgeID:  knowledgeID,
					DocumentID:   documentID,
					DocumentName: documentName,
					RawContent: []*knowledgeModel.SliceContent{
						{
							Type: knowledgeModel.SliceContentTypeText,
							Text: ptr.Of(doc.Content),
						},
					},
					Extra: map[string]string{
						consts.KnowledgeName: knowledgeName,
						consts.DocumentURL:   "",
					},
				},
				Score: 1,
			}})
			continue
		}

		id, parseErr := strconv.ParseInt(doc.ID, 10, 64)
		if parseErr != nil {
			logs.CtxErrorf(ctx, "convert id failed: %v", parseErr)
			return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
		}
		if _, seen := sliceIDs[id]; !seen {
			items = append(items, outputItem{sliceID: id})
		}
		sliceIDs[id] = struct{}{}
		sliceScoreMap[id] = doc.Score()
	}

	// assemble lays the results out in recall order.
	assemble := func(built map[int64]*knowledgeModel.RetrieveSlice) []*knowledgeModel.RetrieveSlice {
		out := make([]*knowledgeModel.RetrieveSlice, 0, len(items))
		for _, it := range items {
			if it.direct != nil {
				out = append(out, it.direct)
			} else if rs, ok := built[it.sliceID]; ok {
				out = append(out, rs)
			}
		}
		return out
	}
	if len(sliceIDs) == 0 {
		// Only NL2SQL summaries: nothing to load from the slice store.
		return assemble(nil), nil
	}

	sliceModels, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "mget slices failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	docIDs := make(sets.Set[int64])
	knowledgeIDs := make(sets.Set[int64])
	for _, s := range sliceModels {
		if s == nil {
			continue
		}
		docIDs[s.DocumentID] = struct{}{}
		knowledgeIDs[s.KnowledgeID] = struct{}{}
	}

	enabledKnowledge, err := k.knowledgeRepo.FilterEnableKnowledge(ctx, knowledgeIDs.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "filter enable knowledge failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	knowledgeMap := make(map[int64]*model.Knowledge, len(enabledKnowledge))
	for _, kn := range enabledKnowledge {
		knowledgeMap[kn.ID] = kn
	}

	documents, err := k.documentRepo.MGetByID(ctx, docIDs.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "mget documents failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	documentMap := make(map[int64]*model.KnowledgeDocument, len(documents))
	for _, d := range documents {
		documentMap[d.ID] = d
	}

	// Group the slices of table knowledge bases by document.
	slicesInTable := map[int64][]*model.KnowledgeDocumentSlice{}
	for _, s := range sliceModels {
		if s == nil {
			continue
		}
		kn := knowledgeMap[s.KnowledgeID]
		if kn == nil || kn.FormatType != int32(knowledgeModel.DocumentTypeTable) {
			continue
		}
		slicesInTable[s.DocumentID] = append(slicesInTable[s.DocumentID], s)
	}
	// Merge the row data of every table document. Assigning the per-document
	// result directly would keep only the last document's rows.
	tableSlices := map[int64]*entity.Slice{}
	for docID, docSlices := range slicesInTable {
		d := documentMap[docID]
		if d == nil {
			continue
		}
		rows, selectErr := k.selectTableData(ctx, d.TableInfo, docSlices)
		if selectErr != nil {
			logs.CtxErrorf(ctx, "select table data failed: %v", selectErr)
			return nil, selectErr
		}
		for id, s := range rows {
			tableSlices[id] = s
		}
	}

	built := make(map[int64]*knowledgeModel.RetrieveSlice, len(sliceModels))
	for _, s := range sliceModels {
		if s == nil {
			continue
		}
		doc, kn := documentMap[s.DocumentID], knowledgeMap[s.KnowledgeID]
		if doc == nil || kn == nil {
			// The document is missing or its knowledge base was not returned by
			// FilterEnableKnowledge; skip instead of dereferencing nil.
			continue
		}

		sliceEntity := entity.Slice{
			Info: knowledgeModel.Info{
				ID:          s.ID,
				CreatorID:   s.CreatorID,
				SpaceID:     doc.SpaceID,
				AppID:       kn.AppID,
				CreatedAtMs: s.CreatedAt,
				UpdatedAtMs: s.UpdatedAt,
			},
			KnowledgeID:  s.KnowledgeID,
			DocumentID:   s.DocumentID,
			DocumentName: doc.Name,
			Sequence:     int64(s.Sequence),
			ByteCount:    int64(len(s.Content)),
			SliceStatus:  knowledgeModel.SliceStatus(s.Status),
			CharCount:    int64(utf8.RuneCountInString(s.Content)),
		}

		var docURL string
		if len(doc.URI) != 0 {
			var urlErr error
			docURL, urlErr = k.storage.GetObjectUrl(ctx, doc.URI)
			if urlErr != nil {
				logs.CtxErrorf(ctx, "get object url failed: %v", urlErr)
				return nil, errorx.New(errno.ErrKnowledgeGetObjectURLFailCode, errorx.KV("msg", urlErr.Error()))
			}
		}
		sliceEntity.Extra = map[string]string{
			consts.KnowledgeName: kn.Name,
			consts.DocumentURL:   docURL,
		}

		switch knowledgeModel.DocumentType(doc.DocumentType) {
		case knowledgeModel.DocumentTypeText:
			sliceEntity.RawContent = []*knowledgeModel.SliceContent{
				{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, s.Content))},
			}
		case knowledgeModel.DocumentTypeTable:
			if v, ok := tableSlices[s.ID]; ok && v != nil {
				sliceEntity.RawContent = v.RawContent
			}
		case knowledgeModel.DocumentTypeImage:
			img := fmt.Sprintf(`<img src="" data-tos-key="%s">`, doc.URI)
			sliceEntity.RawContent = []*knowledgeModel.SliceContent{
				{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, img+s.Content))},
			}
		default:
		}

		built[s.ID] = &knowledgeModel.RetrieveSlice{
			Slice: &sliceEntity,
			Score: sliceScoreMap[s.ID],
		}
	}

	results = assemble(built)

	if err = k.sliceRepo.IncrementHitCount(ctx, sliceIDs.ToSlice()); err != nil {
		logs.CtxWarnf(ctx, "increment hit count failed: %v", err)
	}
	return results, nil
}
|
|
|
|
// formatSliceContent re-renders each <img> tag found by ParseFrontEndImageContent
// in the normalized form produced by ImageContent.Format.
//
// When a tag has a non-empty data-tos-key, its src is set to the object URL
// resolved from that key. A URL lookup failure is logged and leaves src as it was.
func (k *knowledgeSVC) formatSliceContent(ctx context.Context, sliceContent string) string {
	out := sliceContent
	// Tags arrive last-to-first, so splicing each replacement by its byte
	// offsets never shifts the offsets of the tags still to be processed.
	for _, tag := range k.ParseFrontEndImageContent(ctx, sliceContent) {
		if key := tag.TagsKV[DATATOSKEY]; key != "" {
			objectURL, err := k.storage.GetObjectUrl(ctx, key)
			if err == nil {
				tag.SetKV(SRC, objectURL)
			} else {
				logs.CtxErrorf(ctx, "get object url failed: %v", err)
			}
		}
		out = out[:tag.StartOffset] + tag.Format() + out[tag.EndOffset:]
	}
	return out
}
|
|
|
|
// ImageContent is one <img> tag parsed from slice content by
// ParseFrontEndImageContent.
type ImageContent struct {
	// TagsKV maps attribute names to their unquoted values.
	TagsKV map[string]string
	// TagsKList lists attribute names in the order Format renders them.
	TagsKList []string
	// StartOffset and EndOffset delimit the tag as byte offsets into the parsed
	// string (EndOffset exclusive).
	StartOffset, EndOffset int64
}
|
|
|
|
// SRC is the <img> attribute that formatSliceContent fills with a resolved object URL.
const SRC = "src"

// DATATOSKEY is the <img> attribute holding the storage key that
// formatSliceContent resolves into an object URL.
const DATATOSKEY = "data-tos-key"
|
|
|
|
func (i *ImageContent) Format() string {
|
|
res := "<img "
|
|
for _, v := range i.TagsKList {
|
|
res = res + v + "=\"" + i.TagsKV[v] + "\" "
|
|
}
|
|
return res + ">"
|
|
}
|
|
|
|
func (i *ImageContent) SetKV(k string, v string) {
|
|
if _, ok := i.TagsKV[k]; !ok {
|
|
i.TagsKList = append(i.TagsKList, k)
|
|
}
|
|
if i.TagsKV == nil {
|
|
i.TagsKV = make(map[string]string)
|
|
}
|
|
i.TagsKV[k] = v
|
|
}
|
|
|
|
// frontEndImgTagRe matches an <img> tag and captures its attribute text up to
// the closing '>'. It requires at least two characters after "<img", the first
// being whitespace, so a bare "<img >" is not matched.
var frontEndImgTagRe = regexp.MustCompile(`<img\s+([^>]+)>`)

// frontEndImgAttrRe matches one key="value" or key='value' attribute.
var frontEndImgAttrRe = regexp.MustCompile(`(\S+)=(?:"([^"]*)"|'([^']*)')`)

// ParseFrontEndImageContent extracts the <img> tags of s together with their
// quoted attributes and byte offsets.
//
// Attribute names keep their source order in TagsKList. When an attribute is
// repeated, TagsKV keeps the last value and the name appears more than once in
// TagsKList.
//
// The tags are returned last-to-first, so callers can splice replacements into
// s by offset without invalidating the offsets of tags not yet processed. The
// regular expressions are compiled once at package level instead of on every
// call and match.
func (k *knowledgeSVC) ParseFrontEndImageContent(ctx context.Context, s string) []*ImageContent {
	matches := frontEndImgTagRe.FindAllStringSubmatchIndex(s, -1)
	res := make([]*ImageContent, 0, len(matches))
	for _, m := range matches {
		// m[0]:m[1] is the whole tag; m[2]:m[3] is the captured attribute text.
		attrMatches := frontEndImgAttrRe.FindAllStringSubmatch(s[m[2]:m[3]], -1)
		kv := make(map[string]string, len(attrMatches))
		keys := make([]string, 0, len(attrMatches))
		for _, am := range attrMatches {
			key, value := am[1], am[2]
			if value == "" {
				// Either a single-quoted value or an empty double-quoted one.
				value = am[3]
			}
			kv[key] = value
			keys = append(keys, key)
		}
		res = append(res, &ImageContent{
			TagsKV:      kv,
			TagsKList:   keys,
			StartOffset: int64(m[0]),
			EndOffset:   int64(m[1]),
		})
	}
	slices.Reverse(res)
	return res
}
|