218 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			218 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package knowledge
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"strconv"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
 | |
| 	"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
 | |
| 	domainknowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
 | |
| 	crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
 | |
| )
 | |
| 
 | |
| type Knowledge struct {
 | |
| 	client domainknowledge.Knowledge
 | |
| 	idGen  idgen.IDGenerator
 | |
| }
 | |
| 
 | |
| func NewKnowledgeRepository(client domainknowledge.Knowledge, idGen idgen.IDGenerator) *Knowledge {
 | |
| 	return &Knowledge{
 | |
| 		client: client,
 | |
| 		idGen:  idGen,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (k *Knowledge) Store(ctx context.Context, document *crossknowledge.CreateDocumentRequest) (*crossknowledge.CreateDocumentResponse, error) {
 | |
| 	var (
 | |
| 		ps *entity.ParsingStrategy
 | |
| 		cs = &entity.ChunkingStrategy{}
 | |
| 	)
 | |
| 
 | |
| 	if document.ParsingStrategy == nil {
 | |
| 		return nil, errors.New("document parsing strategy is required")
 | |
| 	}
 | |
| 
 | |
| 	if document.ChunkingStrategy == nil {
 | |
| 		return nil, errors.New("document chunking strategy is required")
 | |
| 	}
 | |
| 
 | |
| 	if document.ParsingStrategy.ParseMode == crossknowledge.AccurateParseMode {
 | |
| 		ps = &entity.ParsingStrategy{}
 | |
| 		ps.ExtractImage = document.ParsingStrategy.ExtractImage
 | |
| 		ps.ExtractTable = document.ParsingStrategy.ExtractTable
 | |
| 		ps.ImageOCR = document.ParsingStrategy.ImageOCR
 | |
| 	}
 | |
| 
 | |
| 	chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	cs.ChunkType = chunkType
 | |
| 	cs.Separator = document.ChunkingStrategy.Separator
 | |
| 	cs.ChunkSize = document.ChunkingStrategy.ChunkSize
 | |
| 	cs.Overlap = document.ChunkingStrategy.Overlap
 | |
| 
 | |
| 	req := &entity.Document{
 | |
| 		Info: knowledge.Info{
 | |
| 			Name: document.FileName,
 | |
| 		},
 | |
| 		KnowledgeID:      document.KnowledgeID,
 | |
| 		Type:             knowledge.DocumentTypeText,
 | |
| 		URL:              document.FileURL,
 | |
| 		Source:           entity.DocumentSourceLocal,
 | |
| 		ParsingStrategy:  ps,
 | |
| 		ChunkingStrategy: cs,
 | |
| 		FileExtension:    document.FileExtension,
 | |
| 	}
 | |
| 
 | |
| 	uid := ctxutil.GetUIDFromCtx(ctx)
 | |
| 	if uid != nil {
 | |
| 		req.Info.CreatorID = *uid
 | |
| 	}
 | |
| 
 | |
| 	response, err := k.client.CreateDocument(ctx, &domainknowledge.CreateDocumentRequest{
 | |
| 		Documents: []*entity.Document{req},
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	kCResponse := &crossknowledge.CreateDocumentResponse{
 | |
| 		FileURL:    document.FileURL,
 | |
| 		DocumentID: response.Documents[0].Info.ID,
 | |
| 		FileName:   response.Documents[0].Info.Name,
 | |
| 	}
 | |
| 
 | |
| 	return kCResponse, nil
 | |
| }
 | |
| 
 | |
| func (k *Knowledge) Retrieve(ctx context.Context, r *crossknowledge.RetrieveRequest) (*crossknowledge.RetrieveResponse, error) {
 | |
| 	rs := &entity.RetrievalStrategy{}
 | |
| 	if r.RetrievalStrategy != nil {
 | |
| 		rs.TopK = r.RetrievalStrategy.TopK
 | |
| 		rs.MinScore = r.RetrievalStrategy.MinScore
 | |
| 		searchType, err := toSearchType(r.RetrievalStrategy.SearchType)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		rs.SearchType = searchType
 | |
| 		rs.EnableQueryRewrite = r.RetrievalStrategy.EnableQueryRewrite
 | |
| 		rs.EnableRerank = r.RetrievalStrategy.EnableRerank
 | |
| 		rs.EnableNL2SQL = r.RetrievalStrategy.EnableNL2SQL
 | |
| 	}
 | |
| 
 | |
| 	req := &domainknowledge.RetrieveRequest{
 | |
| 		Query:        r.Query,
 | |
| 		KnowledgeIDs: r.KnowledgeIDs,
 | |
| 		Strategy:     rs,
 | |
| 	}
 | |
| 
 | |
| 	response, err := k.client.Retrieve(ctx, req)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ss := make([]*crossknowledge.Slice, 0, len(response.RetrieveSlices))
 | |
| 	for _, s := range response.RetrieveSlices {
 | |
| 		if s.Slice == nil {
 | |
| 			continue
 | |
| 		}
 | |
| 		ss = append(ss, &crossknowledge.Slice{
 | |
| 			DocumentID: strconv.FormatInt(s.Slice.DocumentID, 10),
 | |
| 			Output:     s.Slice.GetSliceContent(),
 | |
| 		})
 | |
| 
 | |
| 	}
 | |
| 
 | |
| 	return &crossknowledge.RetrieveResponse{
 | |
| 		Slices: ss,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (k *Knowledge) Delete(ctx context.Context, r *crossknowledge.DeleteDocumentRequest) (*crossknowledge.DeleteDocumentResponse, error) {
 | |
| 	docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
 | |
| 	}
 | |
| 
 | |
| 	err = k.client.DeleteDocument(ctx, &domainknowledge.DeleteDocumentRequest{
 | |
| 		DocumentID: docID,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return &crossknowledge.DeleteDocumentResponse{IsSuccess: false}, err
 | |
| 	}
 | |
| 
 | |
| 	return &crossknowledge.DeleteDocumentResponse{IsSuccess: true}, nil
 | |
| }
 | |
| 
 | |
| func (k *Knowledge) ListKnowledgeDetail(ctx context.Context, req *crossknowledge.ListKnowledgeDetailRequest) (*crossknowledge.ListKnowledgeDetailResponse, error) {
 | |
| 	response, err := k.client.MGetKnowledgeByID(ctx, &domainknowledge.MGetKnowledgeByIDRequest{
 | |
| 		KnowledgeIDs: req.KnowledgeIDs,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	resp := &crossknowledge.ListKnowledgeDetailResponse{
 | |
| 		KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *crossknowledge.KnowledgeDetail {
 | |
| 			return &crossknowledge.KnowledgeDetail{
 | |
| 				ID:          a.ID,
 | |
| 				Name:        a.Name,
 | |
| 				Description: a.Description,
 | |
| 				IconURL:     a.IconURL,
 | |
| 				FormatType:  int64(a.Type),
 | |
| 			}
 | |
| 		}),
 | |
| 	}
 | |
| 
 | |
| 	return resp, nil
 | |
| }
 | |
| 
 | |
| func toSearchType(typ crossknowledge.SearchType) (knowledge.SearchType, error) {
 | |
| 	switch typ {
 | |
| 	case crossknowledge.SearchTypeSemantic:
 | |
| 		return knowledge.SearchTypeSemantic, nil
 | |
| 	case crossknowledge.SearchTypeFullText:
 | |
| 		return knowledge.SearchTypeFullText, nil
 | |
| 	case crossknowledge.SearchTypeHybrid:
 | |
| 		return knowledge.SearchTypeHybrid, nil
 | |
| 	default:
 | |
| 		return 0, fmt.Errorf("unknown search type: %v", typ)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func toChunkType(typ crossknowledge.ChunkType) (parser.ChunkType, error) {
 | |
| 	switch typ {
 | |
| 	case crossknowledge.ChunkTypeDefault:
 | |
| 		return parser.ChunkTypeDefault, nil
 | |
| 	case crossknowledge.ChunkTypeCustom:
 | |
| 		return parser.ChunkTypeCustom, nil
 | |
| 	case crossknowledge.ChunkTypeLeveled:
 | |
| 		return parser.ChunkTypeLeveled, nil
 | |
| 	default:
 | |
| 		return 0, fmt.Errorf("unknown chunk type: %v", typ)
 | |
| 	}
 | |
| }
 |