refactor(workflow): Move the knowledge component in the Workflow package into the common crossdomain package (#708)
This commit is contained in:
@@ -22,11 +22,16 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
|
||||
type Knowledge interface {
|
||||
ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (response *knowledge.ListKnowledgeResponse, err error)
|
||||
GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (response *knowledge.GetKnowledgeByIDResponse, err error)
|
||||
Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error)
|
||||
DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error
|
||||
MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (response *knowledge.MGetKnowledgeByIDResponse, err error)
|
||||
Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error)
|
||||
Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error)
|
||||
ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error)
|
||||
}
|
||||
|
||||
var defaultSVC Knowledge
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: knowledge.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination knowledgemock/knowledge_mock.go --package knowledgemock -source knowledge.go
|
||||
//
|
||||
|
||||
// Package knowledgemock is a generated GoMock package.
|
||||
package knowledgemock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
knowledge "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockKnowledge is a mock of Knowledge interface.
|
||||
type MockKnowledge struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockKnowledgeMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockKnowledgeMockRecorder is the mock recorder for MockKnowledge.
|
||||
type MockKnowledgeMockRecorder struct {
|
||||
mock *MockKnowledge
|
||||
}
|
||||
|
||||
// NewMockKnowledge creates a new mock instance.
|
||||
func NewMockKnowledge(ctrl *gomock.Controller) *MockKnowledge {
|
||||
mock := &MockKnowledge{ctrl: ctrl}
|
||||
mock.recorder = &MockKnowledgeMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockKnowledge) EXPECT() *MockKnowledgeMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockKnowledge) Delete(ctx context.Context, r *knowledge.DeleteDocumentRequest) (*knowledge.DeleteDocumentResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, r)
|
||||
ret0, _ := ret[0].(*knowledge.DeleteDocumentResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockKnowledgeMockRecorder) Delete(ctx, r any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockKnowledge)(nil).Delete), ctx, r)
|
||||
}
|
||||
|
||||
// DeleteKnowledge mocks base method.
|
||||
func (m *MockKnowledge) DeleteKnowledge(ctx context.Context, request *knowledge.DeleteKnowledgeRequest) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteKnowledge", ctx, request)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteKnowledge indicates an expected call of DeleteKnowledge.
|
||||
func (mr *MockKnowledgeMockRecorder) DeleteKnowledge(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKnowledge", reflect.TypeOf((*MockKnowledge)(nil).DeleteKnowledge), ctx, request)
|
||||
}
|
||||
|
||||
// GetKnowledgeByID mocks base method.
|
||||
func (m *MockKnowledge) GetKnowledgeByID(ctx context.Context, request *knowledge.GetKnowledgeByIDRequest) (*knowledge.GetKnowledgeByIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKnowledgeByID", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.GetKnowledgeByIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKnowledgeByID indicates an expected call of GetKnowledgeByID.
|
||||
func (mr *MockKnowledgeMockRecorder) GetKnowledgeByID(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).GetKnowledgeByID), ctx, request)
|
||||
}
|
||||
|
||||
// ListKnowledge mocks base method.
|
||||
func (m *MockKnowledge) ListKnowledge(ctx context.Context, request *knowledge.ListKnowledgeRequest) (*knowledge.ListKnowledgeResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListKnowledge", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.ListKnowledgeResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListKnowledge indicates an expected call of ListKnowledge.
|
||||
func (mr *MockKnowledgeMockRecorder) ListKnowledge(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledge", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledge), ctx, request)
|
||||
}
|
||||
|
||||
// ListKnowledgeDetail mocks base method.
|
||||
func (m *MockKnowledge) ListKnowledgeDetail(ctx context.Context, req *knowledge.ListKnowledgeDetailRequest) (*knowledge.ListKnowledgeDetailResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListKnowledgeDetail", ctx, req)
|
||||
ret0, _ := ret[0].(*knowledge.ListKnowledgeDetailResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListKnowledgeDetail indicates an expected call of ListKnowledgeDetail.
|
||||
func (mr *MockKnowledgeMockRecorder) ListKnowledgeDetail(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKnowledgeDetail", reflect.TypeOf((*MockKnowledge)(nil).ListKnowledgeDetail), ctx, req)
|
||||
}
|
||||
|
||||
// MGetKnowledgeByID mocks base method.
|
||||
func (m *MockKnowledge) MGetKnowledgeByID(ctx context.Context, request *knowledge.MGetKnowledgeByIDRequest) (*knowledge.MGetKnowledgeByIDResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MGetKnowledgeByID", ctx, request)
|
||||
ret0, _ := ret[0].(*knowledge.MGetKnowledgeByIDResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MGetKnowledgeByID indicates an expected call of MGetKnowledgeByID.
|
||||
func (mr *MockKnowledgeMockRecorder) MGetKnowledgeByID(ctx, request any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetKnowledgeByID", reflect.TypeOf((*MockKnowledge)(nil).MGetKnowledgeByID), ctx, request)
|
||||
}
|
||||
|
||||
// Retrieve mocks base method.
|
||||
func (m *MockKnowledge) Retrieve(ctx context.Context, req *knowledge.RetrieveRequest) (*knowledge.RetrieveResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Retrieve", ctx, req)
|
||||
ret0, _ := ret[0].(*knowledge.RetrieveResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Retrieve indicates an expected call of Retrieve.
|
||||
func (mr *MockKnowledgeMockRecorder) Retrieve(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockKnowledge)(nil).Retrieve), ctx, req)
|
||||
}
|
||||
|
||||
// Store mocks base method.
|
||||
func (m *MockKnowledge) Store(ctx context.Context, document *knowledge.CreateDocumentRequest) (*knowledge.CreateDocumentResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Store", ctx, document)
|
||||
ret0, _ := ret[0].(*knowledge.CreateDocumentResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Store indicates an expected call of Store.
|
||||
func (mr *MockKnowledgeMockRecorder) Store(ctx, document any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockKnowledge)(nil).Store), ctx, document)
|
||||
}
|
||||
40
backend/crossdomain/contract/modelmgr/modelmgr.go
Normal file
40
backend/crossdomain/contract/modelmgr/modelmgr.go
Normal file
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
eino "github.com/cloudwego/eino/components/model"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
|
||||
type Manager interface {
|
||||
GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error)
|
||||
}
|
||||
|
||||
var defaultSVC Manager
|
||||
|
||||
func DefaultSVC() Manager {
|
||||
return defaultSVC
|
||||
}
|
||||
|
||||
func SetDefaultSVC(svc Manager) {
|
||||
defaultSVC = svc
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: modelmgr.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination modelmock/model_mock.go --package mockmodel -source modelmgr.go
|
||||
//
|
||||
|
||||
// Package mockmodel is a generated GoMock package.
|
||||
package mockmodel
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
model "github.com/cloudwego/eino/components/model"
|
||||
model0 "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
modelmgr "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
type MockManagerMockRecorder struct {
|
||||
mock *MockManager
|
||||
}
|
||||
|
||||
// NewMockManager creates a new mock instance.
|
||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||
mock := &MockManager{ctrl: ctrl}
|
||||
mock.recorder = &MockManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetModel mocks base method.
|
||||
func (m *MockManager) GetModel(ctx context.Context, params *model0.LLMParams) (model.BaseChatModel, *modelmgr.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetModel", ctx, params)
|
||||
ret0, _ := ret[0].(model.BaseChatModel)
|
||||
ret1, _ := ret[1].(*modelmgr.Model)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetModel indicates an expected call of GetModel.
|
||||
func (mr *MockManagerMockRecorder) GetModel(ctx, params any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModel", reflect.TypeOf((*MockManager)(nil).GetModel), ctx, params)
|
||||
}
|
||||
@@ -18,10 +18,18 @@ package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
var defaultSVC crossknowledge.Knowledge
|
||||
@@ -57,3 +65,120 @@ func (i *impl) GetKnowledgeByID(ctx context.Context, request *model.GetKnowledge
|
||||
func (i *impl) MGetKnowledgeByID(ctx context.Context, request *model.MGetKnowledgeByIDRequest) (response *model.MGetKnowledgeByIDResponse, err error) {
|
||||
return i.DomainSVC.MGetKnowledgeByID(ctx, request)
|
||||
}
|
||||
|
||||
func (i *impl) Store(ctx context.Context, document *model.CreateDocumentRequest) (*model.CreateDocumentResponse, error) {
|
||||
var (
|
||||
ps *entity.ParsingStrategy
|
||||
cs = &entity.ChunkingStrategy{}
|
||||
)
|
||||
|
||||
if document.ParsingStrategy == nil {
|
||||
return nil, errors.New("document parsing strategy is required")
|
||||
}
|
||||
|
||||
if document.ChunkingStrategy == nil {
|
||||
return nil, errors.New("document chunking strategy is required")
|
||||
}
|
||||
|
||||
if document.ParsingStrategy.ParseMode == model.AccurateParseMode {
|
||||
ps = &entity.ParsingStrategy{}
|
||||
ps.ExtractImage = document.ParsingStrategy.ExtractImage
|
||||
ps.ExtractTable = document.ParsingStrategy.ExtractTable
|
||||
ps.ImageOCR = document.ParsingStrategy.ImageOCR
|
||||
}
|
||||
|
||||
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cs.ChunkType = chunkType
|
||||
cs.Separator = document.ChunkingStrategy.Separator
|
||||
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
|
||||
cs.Overlap = document.ChunkingStrategy.Overlap
|
||||
|
||||
req := &entity.Document{
|
||||
Info: knowledge.Info{
|
||||
Name: document.FileName,
|
||||
},
|
||||
KnowledgeID: document.KnowledgeID,
|
||||
Type: knowledge.DocumentTypeText,
|
||||
URL: document.FileURL,
|
||||
Source: entity.DocumentSourceLocal,
|
||||
ParsingStrategy: ps,
|
||||
ChunkingStrategy: cs,
|
||||
FileExtension: document.FileExtension,
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.Info.CreatorID = *uid
|
||||
}
|
||||
|
||||
response, err := i.DomainSVC.CreateDocument(ctx, &service.CreateDocumentRequest{
|
||||
Documents: []*entity.Document{req},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kCResponse := &model.CreateDocumentResponse{
|
||||
FileURL: document.FileURL,
|
||||
DocumentID: response.Documents[0].Info.ID,
|
||||
FileName: response.Documents[0].Info.Name,
|
||||
}
|
||||
|
||||
return kCResponse, nil
|
||||
}
|
||||
|
||||
func (i *impl) Delete(ctx context.Context, r *model.DeleteDocumentRequest) (*model.DeleteDocumentResponse, error) {
|
||||
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
|
||||
}
|
||||
|
||||
err = i.DomainSVC.DeleteDocument(ctx, &service.DeleteDocumentRequest{
|
||||
DocumentID: docID,
|
||||
})
|
||||
if err != nil {
|
||||
return &model.DeleteDocumentResponse{IsSuccess: false}, err
|
||||
}
|
||||
|
||||
return &model.DeleteDocumentResponse{IsSuccess: true}, nil
|
||||
}
|
||||
|
||||
func (i *impl) ListKnowledgeDetail(ctx context.Context, req *model.ListKnowledgeDetailRequest) (*model.ListKnowledgeDetailResponse, error) {
|
||||
response, err := i.DomainSVC.MGetKnowledgeByID(ctx, &service.MGetKnowledgeByIDRequest{
|
||||
KnowledgeIDs: req.KnowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &model.ListKnowledgeDetailResponse{
|
||||
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *model.KnowledgeDetail {
|
||||
return &model.KnowledgeDetail{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Description: a.Description,
|
||||
IconURL: a.IconURL,
|
||||
FormatType: int64(a.Type),
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func toChunkType(typ model.ChunkType) (parser.ChunkType, error) {
|
||||
switch typ {
|
||||
case model.ChunkTypeDefault:
|
||||
return parser.ChunkTypeDefault, nil
|
||||
case model.ChunkTypeCustom:
|
||||
return parser.ChunkTypeCustom, nil
|
||||
case model.ChunkTypeLeveled:
|
||||
return parser.ChunkTypeLeveled, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown chunk type: %v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,37 +14,37 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package model
|
||||
package modelmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
eino "github.com/cloudwego/eino/components/model"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
chatmodel2 "github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type ModelManager struct {
|
||||
type modelManager struct {
|
||||
modelMgr modelmgr.Manager
|
||||
factory chatmodel.Factory
|
||||
}
|
||||
|
||||
func NewModelManager(m modelmgr.Manager, f chatmodel.Factory) *ModelManager {
|
||||
func InitDomainService(m modelmgr.Manager, f chatmodel.Factory) crossmodelmgr.Manager {
|
||||
if f == nil {
|
||||
f = chatmodel2.NewDefaultFactory()
|
||||
}
|
||||
return &ModelManager{
|
||||
return &modelManager{
|
||||
modelMgr: m,
|
||||
factory: f,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelManager) GetModel(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
func (m *modelManager) GetModel(ctx context.Context, params *model.LLMParams) (eino.BaseChatModel, *modelmgr.Model, error) {
|
||||
modelID := params.ModelType
|
||||
models, err := m.modelMgr.MGetModelByID(ctx, &modelmgr.MGetModelRequest{
|
||||
IDs: []int64{modelID},
|
||||
@@ -1,217 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/knowledge/entity"
|
||||
domainknowledge "github.com/coze-dev/coze-studio/backend/domain/knowledge/service"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type Knowledge struct {
|
||||
client domainknowledge.Knowledge
|
||||
idGen idgen.IDGenerator
|
||||
}
|
||||
|
||||
func NewKnowledgeRepository(client domainknowledge.Knowledge, idGen idgen.IDGenerator) *Knowledge {
|
||||
return &Knowledge{
|
||||
client: client,
|
||||
idGen: idGen,
|
||||
}
|
||||
}
|
||||
|
||||
func (k *Knowledge) Store(ctx context.Context, document *crossknowledge.CreateDocumentRequest) (*crossknowledge.CreateDocumentResponse, error) {
|
||||
var (
|
||||
ps *entity.ParsingStrategy
|
||||
cs = &entity.ChunkingStrategy{}
|
||||
)
|
||||
|
||||
if document.ParsingStrategy == nil {
|
||||
return nil, errors.New("document parsing strategy is required")
|
||||
}
|
||||
|
||||
if document.ChunkingStrategy == nil {
|
||||
return nil, errors.New("document chunking strategy is required")
|
||||
}
|
||||
|
||||
if document.ParsingStrategy.ParseMode == crossknowledge.AccurateParseMode {
|
||||
ps = &entity.ParsingStrategy{}
|
||||
ps.ExtractImage = document.ParsingStrategy.ExtractImage
|
||||
ps.ExtractTable = document.ParsingStrategy.ExtractTable
|
||||
ps.ImageOCR = document.ParsingStrategy.ImageOCR
|
||||
}
|
||||
|
||||
chunkType, err := toChunkType(document.ChunkingStrategy.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.ChunkType = chunkType
|
||||
cs.Separator = document.ChunkingStrategy.Separator
|
||||
cs.ChunkSize = document.ChunkingStrategy.ChunkSize
|
||||
cs.Overlap = document.ChunkingStrategy.Overlap
|
||||
|
||||
req := &entity.Document{
|
||||
Info: knowledge.Info{
|
||||
Name: document.FileName,
|
||||
},
|
||||
KnowledgeID: document.KnowledgeID,
|
||||
Type: knowledge.DocumentTypeText,
|
||||
URL: document.FileURL,
|
||||
Source: entity.DocumentSourceLocal,
|
||||
ParsingStrategy: ps,
|
||||
ChunkingStrategy: cs,
|
||||
FileExtension: document.FileExtension,
|
||||
}
|
||||
|
||||
uid := ctxutil.GetUIDFromCtx(ctx)
|
||||
if uid != nil {
|
||||
req.Info.CreatorID = *uid
|
||||
}
|
||||
|
||||
response, err := k.client.CreateDocument(ctx, &domainknowledge.CreateDocumentRequest{
|
||||
Documents: []*entity.Document{req},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kCResponse := &crossknowledge.CreateDocumentResponse{
|
||||
FileURL: document.FileURL,
|
||||
DocumentID: response.Documents[0].Info.ID,
|
||||
FileName: response.Documents[0].Info.Name,
|
||||
}
|
||||
|
||||
return kCResponse, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) Retrieve(ctx context.Context, r *crossknowledge.RetrieveRequest) (*crossknowledge.RetrieveResponse, error) {
|
||||
rs := &entity.RetrievalStrategy{}
|
||||
if r.RetrievalStrategy != nil {
|
||||
rs.TopK = r.RetrievalStrategy.TopK
|
||||
rs.MinScore = r.RetrievalStrategy.MinScore
|
||||
searchType, err := toSearchType(r.RetrievalStrategy.SearchType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs.SearchType = searchType
|
||||
rs.EnableQueryRewrite = r.RetrievalStrategy.EnableQueryRewrite
|
||||
rs.EnableRerank = r.RetrievalStrategy.EnableRerank
|
||||
rs.EnableNL2SQL = r.RetrievalStrategy.EnableNL2SQL
|
||||
}
|
||||
|
||||
req := &domainknowledge.RetrieveRequest{
|
||||
Query: r.Query,
|
||||
KnowledgeIDs: r.KnowledgeIDs,
|
||||
Strategy: rs,
|
||||
}
|
||||
|
||||
response, err := k.client.Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss := make([]*crossknowledge.Slice, 0, len(response.RetrieveSlices))
|
||||
for _, s := range response.RetrieveSlices {
|
||||
if s.Slice == nil {
|
||||
continue
|
||||
}
|
||||
ss = append(ss, &crossknowledge.Slice{
|
||||
DocumentID: strconv.FormatInt(s.Slice.DocumentID, 10),
|
||||
Output: s.Slice.GetSliceContent(),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return &crossknowledge.RetrieveResponse{
|
||||
Slices: ss,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) Delete(ctx context.Context, r *crossknowledge.DeleteDocumentRequest) (*crossknowledge.DeleteDocumentResponse, error) {
|
||||
docID, err := strconv.ParseInt(r.DocumentID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid document id: %s", r.DocumentID)
|
||||
}
|
||||
|
||||
err = k.client.DeleteDocument(ctx, &domainknowledge.DeleteDocumentRequest{
|
||||
DocumentID: docID,
|
||||
})
|
||||
if err != nil {
|
||||
return &crossknowledge.DeleteDocumentResponse{IsSuccess: false}, err
|
||||
}
|
||||
|
||||
return &crossknowledge.DeleteDocumentResponse{IsSuccess: true}, nil
|
||||
}
|
||||
|
||||
func (k *Knowledge) ListKnowledgeDetail(ctx context.Context, req *crossknowledge.ListKnowledgeDetailRequest) (*crossknowledge.ListKnowledgeDetailResponse, error) {
|
||||
response, err := k.client.MGetKnowledgeByID(ctx, &domainknowledge.MGetKnowledgeByIDRequest{
|
||||
KnowledgeIDs: req.KnowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &crossknowledge.ListKnowledgeDetailResponse{
|
||||
KnowledgeDetails: slices.Transform(response.Knowledge, func(a *knowledge.Knowledge) *crossknowledge.KnowledgeDetail {
|
||||
return &crossknowledge.KnowledgeDetail{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Description: a.Description,
|
||||
IconURL: a.IconURL,
|
||||
FormatType: int64(a.Type),
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func toSearchType(typ crossknowledge.SearchType) (knowledge.SearchType, error) {
|
||||
switch typ {
|
||||
case crossknowledge.SearchTypeSemantic:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case crossknowledge.SearchTypeFullText:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
case crossknowledge.SearchTypeHybrid:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown search type: %v", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func toChunkType(typ crossknowledge.ChunkType) (parser.ChunkType, error) {
|
||||
switch typ {
|
||||
case crossknowledge.ChunkTypeDefault:
|
||||
return parser.ChunkTypeDefault, nil
|
||||
case crossknowledge.ChunkTypeCustom:
|
||||
return parser.ChunkTypeCustom, nil
|
||||
case crossknowledge.ChunkTypeLeveled:
|
||||
return parser.ChunkTypeLeveled, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown chunk type: %v", typ)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user