coze-studio/backend/crossdomain/workflow/knowledge/knowledge.go

218 lines
6.4 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
domainknowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
// Knowledge wraps the knowledge domain service so that it can be called with
// the workflow cross-domain request and response types used by its methods.
type Knowledge struct {
// client is the knowledge domain service every method delegates to.
client domainknowledge.Knowledge
// idGen is stored by the constructor; no method visible in this file reads it.
// NOTE(review): presumably kept for future ID allocation — confirm before removing.
idGen idgen.IDGenerator
}
// NewKnowledgeRepository returns a Knowledge that delegates to client and
// keeps idGen alongside it. Neither argument is validated here.
func NewKnowledgeRepository(client domainknowledge.Knowledge, idGen idgen.IDGenerator) *Knowledge {
	repo := new(Knowledge)
	repo.client = client
	repo.idGen = idGen
	return repo
}
// Store creates one document in the knowledge domain service from a workflow
// create-document request and returns the ID and name the service reports.
//
// The document is always submitted with type DocumentTypeText and source
// DocumentSourceLocal. When the context carries a user ID it is recorded as
// the creator of the document.
//
// An error is returned when the request or either of its strategies is nil,
// when the chunk type is not recognised, when the domain service fails, or
// when the domain service returns no document at all.
func (k *Knowledge) Store(ctx context.Context, document *crossknowledge.CreateDocumentRequest) (*crossknowledge.CreateDocumentResponse, error) {
	if document == nil {
		return nil, errors.New("create document request is required")
	}
	if document.ParsingStrategy == nil {
		return nil, errors.New("document parsing strategy is required")
	}
	if document.ChunkingStrategy == nil {
		return nil, errors.New("document chunking strategy is required")
	}

	// Only the accurate parse mode forwards extraction options; for every
	// other mode the parsing strategy sent to the service stays nil.
	var ps *entity.ParsingStrategy
	if document.ParsingStrategy.ParseMode == crossknowledge.AccurateParseMode {
		ps = &entity.ParsingStrategy{
			ExtractImage: document.ParsingStrategy.ExtractImage,
			ExtractTable: document.ParsingStrategy.ExtractTable,
			ImageOCR:     document.ParsingStrategy.ImageOCR,
		}
	}

	chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
	if err != nil {
		return nil, err
	}
	cs := &entity.ChunkingStrategy{
		ChunkType: chunkType,
		Separator: document.ChunkingStrategy.Separator,
		ChunkSize: document.ChunkingStrategy.ChunkSize,
		Overlap:   document.ChunkingStrategy.Overlap,
	}

	req := &entity.Document{
		Info: knowledge.Info{
			Name: document.FileName,
		},
		KnowledgeID:      document.KnowledgeID,
		Type:             knowledge.DocumentTypeText,
		URL:              document.FileURL,
		Source:           entity.DocumentSourceLocal,
		ParsingStrategy:  ps,
		ChunkingStrategy: cs,
		FileExtension:    document.FileExtension,
	}
	if uid := ctxutil.GetUIDFromCtx(ctx); uid != nil {
		req.Info.CreatorID = *uid
	}

	response, err := k.client.CreateDocument(ctx, &domainknowledge.CreateDocumentRequest{
		Documents: []*entity.Document{req},
	})
	if err != nil {
		return nil, err
	}
	// Exactly one document was submitted, so the first result is the one we
	// want; guard the index so an empty result is an error, not a panic.
	if response == nil || len(response.Documents) == 0 {
		return nil, errors.New("knowledge service returned no document")
	}

	created := response.Documents[0]
	return &crossknowledge.CreateDocumentResponse{
		FileURL:    document.FileURL,
		DocumentID: created.Info.ID,
		FileName:   created.Info.Name,
	}, nil
}
// Retrieve queries the knowledge domain service with the request's query and
// knowledge IDs and converts every returned slice into the workflow slice type.
//
// When the request has no retrieval strategy, a zero-valued strategy is sent.
// Results whose Slice is nil are dropped. An error is returned for an
// unrecognised search type or when the domain service fails.
func (k *Knowledge) Retrieve(ctx context.Context, r *crossknowledge.RetrieveRequest) (*crossknowledge.RetrieveResponse, error) {
	strategy := new(entity.RetrievalStrategy)
	if in := r.RetrievalStrategy; in != nil {
		searchType, err := toSearchType(in.SearchType)
		if err != nil {
			return nil, err
		}
		strategy.TopK = in.TopK
		strategy.MinScore = in.MinScore
		strategy.SearchType = searchType
		strategy.EnableQueryRewrite = in.EnableQueryRewrite
		strategy.EnableRerank = in.EnableRerank
		strategy.EnableNL2SQL = in.EnableNL2SQL
	}

	result, err := k.client.Retrieve(ctx, &domainknowledge.RetrieveRequest{
		Query:        r.Query,
		KnowledgeIDs: r.KnowledgeIDs,
		Strategy:     strategy,
	})
	if err != nil {
		return nil, err
	}

	out := &crossknowledge.RetrieveResponse{
		Slices: make([]*crossknowledge.Slice, 0, len(result.RetrieveSlices)),
	}
	for i := range result.RetrieveSlices {
		hit := result.RetrieveSlices[i].Slice
		if hit == nil {
			continue
		}
		// The workflow side carries document IDs as decimal strings.
		out.Slices = append(out.Slices, &crossknowledge.Slice{
			DocumentID: strconv.FormatInt(hit.DocumentID, 10),
			Output:     hit.GetSliceContent(),
		})
	}
	return out, nil
}
// Delete removes the document identified by r.DocumentID through the
// knowledge domain service.
//
// r.DocumentID must be a base-10 signed 64-bit integer. If it is not, a nil
// response is returned with an error that wraps the strconv failure, so
// callers can inspect it with errors.Is or errors.As. If the domain service
// fails, its error is returned together with a response whose IsSuccess is
// false; that non-nil response is kept for callers that read the flag.
func (k *Knowledge) Delete(ctx context.Context, r *crossknowledge.DeleteDocumentRequest) (*crossknowledge.DeleteDocumentResponse, error) {
	docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid document id %q: %w", r.DocumentID, err)
	}

	err = k.client.DeleteDocument(ctx, &domainknowledge.DeleteDocumentRequest{
		DocumentID: docID,
	})
	if err != nil {
		return &crossknowledge.DeleteDocumentResponse{IsSuccess: false}, err
	}
	return &crossknowledge.DeleteDocumentResponse{IsSuccess: true}, nil
}
// ListKnowledgeDetail fetches the knowledge entries for req.KnowledgeIDs from
// the domain service and maps each one to a workflow KnowledgeDetail holding
// its ID, name, description, icon URL and type (as an int64 FormatType).
// Errors from the domain service are returned unchanged.
func (k *Knowledge) ListKnowledgeDetail(ctx context.Context, req *crossknowledge.ListKnowledgeDetailRequest) (*crossknowledge.ListKnowledgeDetailResponse, error) {
	toDetail := func(kn *knowledge.Knowledge) *crossknowledge.KnowledgeDetail {
		detail := new(crossknowledge.KnowledgeDetail)
		detail.ID = kn.ID
		detail.Name = kn.Name
		detail.Description = kn.Description
		detail.IconURL = kn.IconURL
		detail.FormatType = int64(kn.Type)
		return detail
	}

	found, err := k.client.MGetKnowledgeByID(ctx, &domainknowledge.MGetKnowledgeByIDRequest{
		KnowledgeIDs: req.KnowledgeIDs,
	})
	if err != nil {
		return nil, err
	}

	return &crossknowledge.ListKnowledgeDetailResponse{
		KnowledgeDetails: slices.Transform(found.Knowledge, toDetail),
	}, nil
}
// toSearchType maps a workflow search type onto the matching knowledge-model
// search type. Any value other than semantic, full-text or hybrid yields the
// zero search type and an error.
func toSearchType(typ crossknowledge.SearchType) (knowledge.SearchType, error) {
	if typ == crossknowledge.SearchTypeSemantic {
		return knowledge.SearchTypeSemantic, nil
	}
	if typ == crossknowledge.SearchTypeFullText {
		return knowledge.SearchTypeFullText, nil
	}
	if typ == crossknowledge.SearchTypeHybrid {
		return knowledge.SearchTypeHybrid, nil
	}
	return 0, fmt.Errorf("unknown search type: %v", typ)
}
// toChunkType maps a workflow chunk type onto the matching parser chunk type.
// Any value other than default, custom or leveled yields the zero chunk type
// and an error.
func toChunkType(typ crossknowledge.ChunkType) (parser.ChunkType, error) {
	var mapped parser.ChunkType
	switch typ {
	case crossknowledge.ChunkTypeDefault:
		mapped = parser.ChunkTypeDefault
	case crossknowledge.ChunkTypeCustom:
		mapped = parser.ChunkTypeCustom
	case crossknowledge.ChunkTypeLeveled:
		mapped = parser.ChunkTypeLeveled
	default:
		// mapped is still the zero value here.
		return mapped, fmt.Errorf("unknown chunk type: %v", typ)
	}
	return mapped, nil
}