refactor(workflow): Move the knowledge component in the Workflow package into the common crossdomain package (#708)
This commit is contained in:
@@ -28,7 +28,8 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@@ -116,7 +117,7 @@ func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
|
||||
m, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
|
||||
m, _, err := crossmodelmgr.DefaultSVC().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ package knowledge
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
func convertParsingType(p string) (knowledge.ParseMode, error) {
|
||||
@@ -52,6 +52,6 @@ func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
|
||||
return 0, fmt.Errorf("invalid RetrievalSearchType %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@@ -50,14 +51,10 @@ func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOpt
|
||||
}
|
||||
|
||||
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Deleter{
|
||||
knowledgeDeleter: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
return &Deleter{}, nil
|
||||
}
|
||||
|
||||
type Deleter struct {
|
||||
knowledgeDeleter knowledge.KnowledgeOperator
|
||||
}
|
||||
type Deleter struct{}
|
||||
|
||||
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
documentID, ok := input["documentID"].(string)
|
||||
@@ -69,7 +66,7 @@ func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]
|
||||
DocumentID: documentID,
|
||||
}
|
||||
|
||||
response, err := k.knowledgeDeleter.Delete(ctx, req)
|
||||
response, err := crossknowledge.DefaultSVC().Delete(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -26,7 +26,8 @@ import (
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@@ -109,7 +110,6 @@ func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...sche
|
||||
knowledgeID: i.KnowledgeID,
|
||||
parsingStrategy: i.ParsingStrategy,
|
||||
chunkingStrategy: i.ChunkingStrategy,
|
||||
knowledgeIndexer: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -117,7 +117,6 @@ type Indexer struct {
|
||||
knowledgeID int64
|
||||
parsingStrategy *knowledge.ParsingStrategy
|
||||
chunkingStrategy *knowledge.ChunkingStrategy
|
||||
knowledgeIndexer knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
@@ -141,7 +140,7 @@ func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]
|
||||
FileExtension: ext,
|
||||
}
|
||||
|
||||
response, err := k.knowledgeIndexer.Store(ctx, req)
|
||||
response, err := crossknowledge.DefaultSVC().Store(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ import (
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@@ -155,14 +156,12 @@ func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...sch
|
||||
return &Retrieve{
|
||||
knowledgeIDs: r.KnowledgeIDs,
|
||||
retrievalStrategy: r.RetrievalStrategy,
|
||||
retriever: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Retrieve struct {
|
||||
knowledgeIDs []int64
|
||||
retrievalStrategy *knowledge.RetrievalStrategy
|
||||
retriever knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
@@ -172,20 +171,20 @@ func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[strin
|
||||
}
|
||||
|
||||
req := &knowledge.RetrieveRequest{
|
||||
Query: query,
|
||||
KnowledgeIDs: kr.knowledgeIDs,
|
||||
RetrievalStrategy: kr.retrievalStrategy,
|
||||
Query: query,
|
||||
KnowledgeIDs: kr.knowledgeIDs,
|
||||
Strategy: kr.retrievalStrategy,
|
||||
}
|
||||
|
||||
response, err := kr.retriever.Retrieve(ctx, req)
|
||||
response, err := crossknowledge.DefaultSVC().Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
result[outputList] = slices.Transform(response.Slices, func(m *knowledge.Slice) any {
|
||||
result[outputList] = slices.Transform(response.RetrieveSlices, func(m *knowledge.RetrieveSlice) any {
|
||||
return map[string]any{
|
||||
"documentId": m.DocumentID,
|
||||
"output": m.Output,
|
||||
"documentId": m.Slice.DocumentID,
|
||||
"output": m.Slice.GetSliceContent(),
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -34,10 +34,12 @@ import (
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
@@ -157,7 +159,6 @@ type RetrievalStrategy struct {
|
||||
|
||||
type KnowledgeRecallConfig struct {
|
||||
ChatModel model.BaseChatModel
|
||||
Retriever knowledge.KnowledgeOperator
|
||||
RetrievalStrategy *RetrievalStrategy
|
||||
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
|
||||
}
|
||||
@@ -360,7 +361,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
knowledgeRecallConfig *KnowledgeRecallConfig
|
||||
)
|
||||
|
||||
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
chatModel, info, err = crossmodelmgr.DefaultSVC().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -369,7 +370,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
|
||||
backupModelParams := c.BackupLLMParams
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
|
||||
fallbackM, fallbackI, err = crossmodelmgr.DefaultSVC().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -491,11 +492,9 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
knowledgeOperator := knowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
knowledgeRecallConfig = &KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := toRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
@@ -523,7 +522,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
|
||||
detailResp, err := crossknowledge.DefaultSVC().ListKnowledgeDetail(ctx,
|
||||
&knowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
@@ -811,7 +810,7 @@ func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
return 0, fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1156,28 +1155,28 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
|
||||
Query: userPrompt,
|
||||
KnowledgeIDs: recallKnowledgeIDs,
|
||||
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
|
||||
docs, err := crossknowledge.DefaultSVC().Retrieve(ctx, &knowledge.RetrieveRequest{
|
||||
Query: userPrompt,
|
||||
KnowledgeIDs: recallKnowledgeIDs,
|
||||
Strategy: cfg.RetrievalStrategy.RetrievalStrategy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
|
||||
if len(docs.RetrieveSlices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
|
||||
if len(docs.RetrieveSlices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
|
||||
sb.WriteString("recall slice 1: \n")
|
||||
sb.WriteString(cfg.RetrievalStrategy.NoReCallReplyCustomizePrompt + "\n")
|
||||
}
|
||||
|
||||
for idx, msg := range docs.Slices {
|
||||
for idx, msg := range docs.RetrieveSlices {
|
||||
sb.WriteString(fmt.Sprintf("recall slice %d:\n", idx+1))
|
||||
sb.WriteString(fmt.Sprintf("%s\n", msg.Output))
|
||||
sb.WriteString(fmt.Sprintf("%s\n", msg.Slice.GetSliceContent()))
|
||||
}
|
||||
|
||||
output = map[string]any{
|
||||
|
||||
@@ -28,8 +28,9 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
@@ -232,7 +233,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
|
||||
err error
|
||||
)
|
||||
if c.LLMParams != nil {
|
||||
m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
m, _, err = crossmodelmgr.DefaultSVC().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user