From bb6ff0026b06ffa42ac00c1a3605bd8a87e4af08 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Tue, 5 Aug 2025 14:02:33 +0800 Subject: [PATCH] refactor: how to add a node type in workflow (#558) --- .../api/handler/coze/workflow_service_test.go | 177 +- backend/application/workflow/init.go | 4 + backend/application/workflow/workflow.go | 209 +- .../domain/plugin/service/exec_tool_test.go | 49 - backend/domain/workflow/entity/node_meta.go | 793 ++++++- .../workflow/entity/node_type_literal.go | 867 ------- backend/domain/workflow/entity/vo/canvas.go | 459 ++-- backend/domain/workflow/entity/vo/node.go | 6 - backend/domain/workflow/entity/workflow.go | 7 - .../internal/canvas/adaptor/canvas_test.go | 5 + .../internal/canvas/adaptor/from_node.go | 41 +- .../internal/canvas/adaptor/to_schema.go | 2056 ++--------------- .../internal/canvas/adaptor/type_convert.go | 1230 ---------- .../internal/canvas/convert/type_convert.go | 663 ++++++ .../canvas/validate/canvas_validate.go | 117 +- .../workflow/internal/compose/branch.go | 181 -- .../workflow/internal/compose/callbacks.go | 194 -- .../internal/compose/designate_option.go | 21 +- .../workflow/internal/compose/field_fill.go | 7 +- .../internal/compose/field_fill_test.go | 5 +- .../workflow/internal/compose/node_builder.go | 118 + .../workflow/internal/compose/node_runner.go | 217 +- .../workflow/internal/compose/node_schema.go | 580 ----- .../domain/workflow/internal/compose/state.go | 80 +- .../workflow/internal/compose/stream.go | 155 +- .../internal/compose/test/batch_test.go | 48 +- .../internal/compose/test/llm_test.go | 206 +- .../internal/compose/test/loop_test.go | 160 +- .../compose/test/question_answer_test.go | 206 +- .../internal/compose/test/workflow_test.go | 192 +- .../workflow/internal/compose/to_node.go | 652 ------ .../domain/workflow/internal/compose/utils.go | 107 - .../workflow/internal/compose/workflow.go | 132 +- .../internal/compose/workflow_from_node.go | 5 +- .../workflow/internal/compose/workflow_run.go | 5 +- .../internal/compose/workflow_tool.go | 13 +- .../workflow/internal/nodes/batch/batch.go | 136 +- .../workflow/internal/nodes/code/code.go | 99 +- .../workflow/internal/nodes/code/code_test.go | 133 +- .../workflow/internal/nodes/database/adapt.go | 236 ++ .../internal/nodes/database/common.go | 6 +- .../internal/nodes/database/customsql.go | 81 +- .../internal/nodes/database/customsql_test.go | 29 +- .../internal/nodes/database/delete.go | 87 +- .../internal/nodes/database/insert.go | 70 +- .../workflow/internal/nodes/database/query.go | 137 +- .../internal/nodes/database/query_test.go | 192 +- .../internal/nodes/database/update.go | 110 +- .../internal/nodes/emitter/emitter.go | 88 +- .../workflow/internal/nodes/entry/entry.go | 59 +- .../workflow/internal/nodes/exit/exit.go | 113 + .../internal/nodes/httprequester/adapt.go | 340 +++ .../nodes/httprequester/http_requester.go | 232 +- .../httprequester/http_requester_test.go | 21 +- .../nodes/intentdetector/intent_detector.go | 219 +- .../intentdetector/intent_detector_test.go | 88 - .../nodes/json/json_deserialization.go | 47 +- .../nodes/json/json_deserialization_test.go | 133 +- .../internal/nodes/json/json_serialization.go | 63 +- .../nodes/json/json_serialization_test.go | 28 +- .../internal/nodes/knowledge/adaptor.go | 57 + .../nodes/knowledge/knowledge_deleter.go | 48 +- .../nodes/knowledge/knowledge_indexer.go | 97 +- .../nodes/knowledge/knowledge_retrieve.go | 148 +- .../domain/workflow/internal/nodes/llm/llm.go | 635 ++++- .../workflow/internal/nodes/llm/prompt.go | 5 +- .../internal/nodes/loop/{ => break}/break.go | 23 +- .../internal/nodes/loop/continue/continue.go | 47 + .../workflow/internal/nodes/loop/loop.go | 207 +- .../domain/workflow/internal/nodes/nested.go | 90 - .../domain/workflow/internal/nodes/node.go | 194 ++ .../domain/workflow/internal/nodes/option.go | 170 ++ .../workflow/internal/nodes/plugin/plugin.go | 98 +- .../internal/nodes/qa/question_answer.go | 353 ++- .../internal/nodes/receiver/input_receiver.go | 43 +- .../internal/nodes/selector/callbacks.go | 190 ++ .../internal/nodes/selector/operator.go | 45 + .../internal/nodes/selector/schema.go | 148 +- .../internal/nodes/selector/selector.go | 83 +- .../domain/workflow/internal/nodes/stream.go | 86 +- .../nodes/subworkflow/sub_workflow.go | 65 +- .../workflow/internal/nodes/template.go | 19 +- .../nodes/textprocessor/text_processor.go | 94 +- .../textprocessor/text_processor_test.go | 10 +- .../variableaggregator/variable_aggregator.go | 218 +- .../nodes/variableassigner/variable_assign.go | 76 +- .../variable_assign_in_loop.go | 82 +- .../variableassigner/variable_assign_test.go | 52 +- .../workflow/internal/schema/branch_schema.go | 196 ++ .../workflow/internal/schema/node_builder.go | 73 + .../workflow/internal/schema/node_schema.go | 131 ++ .../domain/workflow/internal/schema/stream.go | 77 + .../{compose => schema}/workflow_schema.go | 92 +- .../domain/workflow/service/service_impl.go | 197 +- backend/domain/workflow/service/utils.go | 157 -- backend/types/errno/workflow.go | 2 +- 96 files changed, 8305 insertions(+), 8717 deletions(-) delete mode 100644 backend/domain/plugin/service/exec_tool_test.go delete mode 100644 backend/domain/workflow/entity/node_type_literal.go delete mode 100644 backend/domain/workflow/internal/canvas/adaptor/type_convert.go create mode 100644 backend/domain/workflow/internal/canvas/convert/type_convert.go delete mode 100644 backend/domain/workflow/internal/compose/branch.go delete mode 100644 backend/domain/workflow/internal/compose/callbacks.go create mode 100644 backend/domain/workflow/internal/compose/node_builder.go delete mode 100644 backend/domain/workflow/internal/compose/node_schema.go delete mode 100644 backend/domain/workflow/internal/compose/to_node.go delete mode 100644 backend/domain/workflow/internal/compose/utils.go create mode 100644 backend/domain/workflow/internal/nodes/database/adapt.go create mode 100644 backend/domain/workflow/internal/nodes/exit/exit.go create mode 100644 backend/domain/workflow/internal/nodes/httprequester/adapt.go delete mode 100644 backend/domain/workflow/internal/nodes/intentdetector/intent_detector_test.go create mode 100644 backend/domain/workflow/internal/nodes/knowledge/adaptor.go rename backend/domain/workflow/internal/nodes/loop/{ => break}/break.go (54%) create mode 100644 backend/domain/workflow/internal/nodes/loop/continue/continue.go delete mode 100644 backend/domain/workflow/internal/nodes/nested.go create mode 100644 backend/domain/workflow/internal/nodes/node.go create mode 100644 backend/domain/workflow/internal/nodes/option.go create mode 100644 backend/domain/workflow/internal/nodes/selector/callbacks.go create mode 100644 backend/domain/workflow/internal/schema/branch_schema.go create mode 100644 backend/domain/workflow/internal/schema/node_builder.go create mode 100644 backend/domain/workflow/internal/schema/node_schema.go create mode 100644 backend/domain/workflow/internal/schema/stream.go rename backend/domain/workflow/internal/{compose => schema}/workflow_schema.go (81%) diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index 34b3301d..e71fff9f 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -105,26 +105,28 @@ import ( func TestMain(m *testing.M) { callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler()) + service.RegisterAllNodeAdaptors() os.Exit(m.Run()) } type wfTestRunner struct { - t *testing.T - h *server.Hertz - ctrl *gomock.Controller - idGen *mock.MockIDGenerator - search *searchmock.MockNotifier - appVarS *mockvar.MockStore - userVarS *mockvar.MockStore - varGetter *mockvar.MockVariablesMetaGetter - modelManage *mockmodel.MockManager - plugin *mockPlugin.MockPluginService - tos *storageMock.MockStorage - knowledge *knowledgemock.MockKnowledgeOperator - database *databasemock.MockDatabaseOperator - pluginSrv *pluginmock.MockService - ctx context.Context - closeFn func() + t *testing.T + h *server.Hertz + ctrl *gomock.Controller + idGen *mock.MockIDGenerator + search *searchmock.MockNotifier + appVarS *mockvar.MockStore + userVarS *mockvar.MockStore + varGetter *mockvar.MockVariablesMetaGetter + modelManage *mockmodel.MockManager + plugin *mockPlugin.MockPluginService + tos *storageMock.MockStorage + knowledge *knowledgemock.MockKnowledgeOperator + database *databasemock.MockDatabaseOperator + pluginSrv *pluginmock.MockService + internalModel *testutil.UTChatModel + ctx context.Context + closeFn func() } var req2URL = map[reflect.Type]string{ @@ -243,9 +245,11 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { cpStore := checkpoint.NewRedisStore(redisClient) + utChatModel := &testutil.UTChatModel{} + mockTos := storageMock.NewMockStorage(ctrl) mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, nil) + workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel) mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build() mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build() @@ -312,22 +316,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { } return &wfTestRunner{ - t: t, - h: h, - ctrl: ctrl, - idGen: mockIDGen, - search: mockSearchNotify, - appVarS: mockGlobalAppVarStore, - userVarS: mockGlobalUserVarStore, - varGetter: mockVarGetter, - modelManage: mockModelManage, - plugin: mPlugin, - tos: mockTos, - knowledge: mockKwOperator, - database: mockDatabaseOperator, - ctx: context.Background(), - closeFn: f, - pluginSrv: mockPluginSrv, + t: t, + h: h, + ctrl: ctrl, + idGen: mockIDGen, + search: mockSearchNotify, + appVarS: mockGlobalAppVarStore, + userVarS: mockGlobalUserVarStore, + varGetter: mockVarGetter, + modelManage: mockModelManage, + plugin: mPlugin, + tos: mockTos, + knowledge: mockKwOperator, + database: mockDatabaseOperator, + internalModel: utChatModel, + ctx: context.Background(), + closeFn: f, + pluginSrv: mockPluginSrv, } } @@ -1110,7 +1115,8 @@ func TestValidateTree(t *testing.T) { assert.Equal(t, i.Message, `node "代码_1" not connected`) } if i.NodeError.NodeID == "160892" { - assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`) + assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`) + assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`) } } @@ -1157,7 +1163,8 @@ func TestValidateTree(t *testing.T) { assert.Equal(t, i.Message, `node "代码_1" not connected`) } if i.NodeError.NodeID == "160892" { - assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`) + assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`) + assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`) } } } @@ -2950,41 +2957,41 @@ func TestLLMWithSkills(t *testing.T) { r := newWfTestRunner(t) defer r.closeFn() - utChatModel := &testutil.UTChatModel{ - InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { - if index == 0 { - assert.Equal(t, 1, len(in)) - assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点") - return &schema.Message{ - Role: schema.Assistant, - Content: "7512369185624686592", - ResponseMeta: &schema.ResponseMeta{ - Usage: &schema.TokenUsage{ - PromptTokens: 10, - CompletionTokens: 11, - TotalTokens: 21, - }, + utChatModel := r.internalModel + utChatModel.InvokeResultProvider = func(index int, in []*schema.Message) (*schema.Message, error) { + if index == 0 { + assert.Equal(t, 1, len(in)) + assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点") + return &schema.Message{ + Role: schema.Assistant, + Content: "7512369185624686592", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 10, + CompletionTokens: 11, + TotalTokens: 21, }, - }, nil + }, + }, nil - } else if index == 1 { - assert.Equal(t, 2, len(in)) - for _, message := range in { - if message.Role == schema.System { - assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content) - } - if message.Role == schema.User { - assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉") - } + } else if index == 1 { + assert.Equal(t, 2, len(in)) + for _, message := range in { + if message.Role == schema.System { + assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content) + } + if message.Role == schema.User { + assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉") } - return &schema.Message{ - Role: schema.Assistant, - Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`, - }, nil } - return nil, fmt.Errorf("unexpected index: %d", index) - }, + return &schema.Message{ + Role: schema.Assistant, + Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`, + }, nil + } + return nil, fmt.Errorf("unexpected index: %d", index) } + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes() r.knowledge.EXPECT().ListKnowledgeDetail(gomock.Any(), gomock.Any()).Return(&knowledge.ListKnowledgeDetailResponse{ @@ -3767,10 +3774,10 @@ func TestReleaseApplicationWorkflows(t *testing.T) { var validateCv func(ns []*vo.Node) validateCv = func(ns []*vo.Node) { for _, n := range ns { - if n.Type == vo.BlockTypeBotSubWorkflow { + if n.Type == entity.NodeTypeSubWorkflow.IDStr() { assert.Equal(t, n.Data.Inputs.WorkflowVersion, version) } - if n.Type == vo.BlockTypeBotAPI { + if n.Type == entity.NodeTypePlugin.IDStr() { for _, apiParam := range n.Data.Inputs.APIParams { // In the application, the workflow plugin node When the plugin version is equal to 0, the plugin is a plugin created in the application if apiParam.Name == "pluginVersion" { @@ -3779,7 +3786,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) { } } - if n.Type == vo.BlockTypeBotLLM { + if n.Type == entity.NodeTypeLLM.IDStr() { if n.Data.Inputs.FCParam != nil && n.Data.Inputs.FCParam.PluginFCParam != nil { // In the application, the workflow llm node When the plugin version is equal to 0, the plugin is a plugin created in the application for _, p := range n.Data.Inputs.FCParam.PluginFCParam.PluginList { @@ -4063,8 +4070,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { var validateSubWorkflowIDs func(nodes []*vo.Node) validateSubWorkflowIDs = func(nodes []*vo.Node) { for _, node := range nodes { - switch node.Type { - case vo.BlockTypeBotAPI: + switch entity.IDStrToNodeType(node.Type) { + case entity.NodeTypePlugin: apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { return e.Name, e }) @@ -4082,7 +4089,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { assert.Equal(t, "100100", pID) } - case vo.BlockTypeBotSubWorkflow: + case entity.NodeTypeSubWorkflow: assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) wfId, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) assert.NoError(t, err) @@ -4096,7 +4103,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { err = sonic.UnmarshalString(subWf.Canvas, subworkflowCanvas) assert.NoError(t, err) validateSubWorkflowIDs(subworkflowCanvas.Nodes) - case vo.BlockTypeBotLLM: + case entity.NodeTypeLLM: if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { assert.True(t, copiedIDMap[w.WorkflowID]) @@ -4116,13 +4123,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { assert.Equal(t, "100100", k.ID) } } - case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: + case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever: datasetListInfoParam := node.Data.Inputs.DatasetParam[0] knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any) for idx := range knowledgeIDs { assert.Equal(t, "100100", knowledgeIDs[idx].(string)) } - case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: + case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate: for _, d := range node.Data.Inputs.DatabaseInfoList { assert.Equal(t, "100100", d.DatabaseInfoID) } @@ -4208,10 +4215,10 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { var validateSubWorkflowIDs func(nodes []*vo.Node) validateSubWorkflowIDs = func(nodes []*vo.Node) { for _, node := range nodes { - switch node.Type { - case vo.BlockTypeBotSubWorkflow: + switch entity.IDStrToNodeType(node.Type) { + case entity.NodeTypeSubWorkflow: assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) - case vo.BlockTypeBotLLM: + case entity.NodeTypeLLM: if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { assert.True(t, copiedIDMap[w.WorkflowID]) @@ -4229,13 +4236,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) { assert.Equal(t, "100100", k.ID) } } - case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: + case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever: datasetListInfoParam := node.Data.Inputs.DatasetParam[0] knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any) for idx := range knowledgeIDs { assert.Equal(t, "100100", knowledgeIDs[idx].(string)) } - case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: + case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate: for _, d := range node.Data.Inputs.DatabaseInfoList { assert.Equal(t, "100100", d.DatabaseInfoID) } @@ -4356,7 +4363,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { err = sonic.Unmarshal(data, mainCanvas) assert.NoError(t, err) for _, node := range mainCanvas.Nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { if node.Data.Inputs.WorkflowID == "7516826260387921920" { node.Data.Inputs.WorkflowID = c1IdStr } @@ -4372,7 +4379,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { err = sonic.Unmarshal(cc1Data, cc1Canvas) assert.NoError(t, err) for _, node := range cc1Canvas.Nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { if node.Data.Inputs.WorkflowID == "7516826283318181888" { node.Data.Inputs.WorkflowID = c2IdStr } @@ -4423,7 +4430,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { assert.NoError(t, err) for _, node := range newMainCanvas.Nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID]) assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion) } @@ -4437,7 +4444,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) { assert.NoError(t, err) for _, node := range cc1Canvas.Nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID]) assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion) } @@ -4508,10 +4515,10 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) { var validateSubWorkflowIDs func(nodes []*vo.Node) validateSubWorkflowIDs = func(nodes []*vo.Node) { for _, node := range nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) } - if node.Type == vo.BlockTypeBotLLM { + if node.Type == entity.NodeTypeLLM.IDStr() { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { assert.True(t, copiedIDMap[w.WorkflowID]) diff --git a/backend/application/workflow/init.go b/backend/application/workflow/init.go index 1e8e6de8..2b14b587 100644 --- a/backend/application/workflow/init.go +++ b/backend/application/workflow/init.go @@ -24,6 +24,7 @@ import ( "github.com/cloudwego/eino/callbacks" "github.com/coze-dev/coze-studio/backend/application/internal" + wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database" wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge" wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model" @@ -78,6 +79,9 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati if !ok { logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured") } + + service.RegisterAllNodeAdaptors() + workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache, components.Tos, components.CPStore, bcm) workflow.SetRepository(workflowRepo) diff --git a/backend/application/workflow/workflow.go b/backend/application/workflow/workflow.go index 9b268ad0..d006bf35 100644 --- a/backend/application/workflow/workflow.go +++ b/backend/application/workflow/workflow.go @@ -93,8 +93,8 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf toQueryTypes := make(map[entity.NodeType]bool) for _, t := range req.NodeTypes { - entityType, err := nodeType2EntityNodeType(t) - if err != nil { + entityType := entity.IDStrToNodeType(t) + if len(entityType) == 0 { logs.Warnf("get node type %v failed, err:=%v", t, err) continue } @@ -116,23 +116,19 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf Name: category, } for _, nodeMeta := range nodeMetaList { - tplType, err := entityNodeTypeToAPINodeTemplateType(nodeMeta.Type) - if err != nil { - return nil, err - } tpl := &workflow.NodeTemplate{ ID: fmt.Sprintf("%d", nodeMeta.ID), - Type: tplType, + Type: workflow.NodeTemplateType(nodeMeta.ID), Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name), Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc), IconURL: nodeMeta.IconURL, SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT), - NodeType: fmt.Sprintf("%d", tplType), + NodeType: fmt.Sprintf("%d", nodeMeta.ID), Color: nodeMeta.Color, } resp.Data.TemplateList = append(resp.Data.TemplateList, tpl) - categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", tplType)) + categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", nodeMeta.ID)) } } @@ -178,7 +174,7 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C IconURI: req.IconURI, AppID: parseInt64(req.ProjectID), Mode: ternary.IFElse(req.IsSetFlowMode(), req.GetFlowMode(), workflow.WorkflowMode_Workflow), - InitCanvasSchema: entity.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)), + InitCanvasSchema: vo.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)), } id, err := GetWorkflowDomainSVC().Create(ctx, wf) @@ -1041,7 +1037,8 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w return wf.ID, nil } -func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (_ int64, _ []*vo.ValidateIssue, err error) { +func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/ + appID int64) (_ int64, _ []*vo.ValidateIssue, err error) { defer func() { if panicErr := recover(); panicErr != nil { err = safego.NewPanicErr(panicErr, debug.Stack()) @@ -1127,15 +1124,10 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w } func convertNodeExecution(nodeExe *entity.NodeExecution) (*workflow.NodeResult, error) { - nType, err := entityNodeTypeToAPINodeTemplateType(nodeExe.NodeType) - if err != nil { - return nil, err - } - nr := &workflow.NodeResult{ NodeId: nodeExe.NodeID, NodeName: nodeExe.NodeName, - NodeType: nType.String(), + NodeType: entity.NodeMetaByNodeType(nodeExe.NodeType).GetDisplayKey(), NodeStatus: workflow.NodeExeStatus(nodeExe.Status), ErrorInfo: ptr.FromOrDefault(nodeExe.ErrorInfo, ""), Input: ptr.FromOrDefault(nodeExe.Input, ""), @@ -1316,13 +1308,6 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor return nil, schema.ErrNoValue } - var nodeType workflow.NodeTemplateType - nodeType, err = entityNodeTypeToAPINodeTemplateType(msg.NodeType) - if err != nil { - logs.Errorf("convert node type %v failed, err:=%v", msg.NodeType, err) - nodeType = workflow.NodeTemplateType(0) - } - res = &workflow.OpenAPIStreamRunFlowResponse{ ID: strconv.Itoa(messageID), Event: string(MessageEvent), @@ -1330,7 +1315,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor Content: ptr.Of(msg.Content), ContentType: ptr.Of("text"), NodeIsFinish: ptr.Of(msg.Last), - NodeType: ptr.Of(nodeType.String()), + NodeType: ptr.Of(entity.NodeMetaByNodeType(msg.NodeType).GetDisplayKey()), NodeID: ptr.Of(msg.NodeID), } @@ -3344,178 +3329,6 @@ func toWorkflowParameter(nType *vo.NamedTypeInfo) (*workflow.Parameter, error) { return wp, nil } -func nodeType2EntityNodeType(t string) (entity.NodeType, error) { - i, err := strconv.Atoi(t) - if err != nil { - return "", fmt.Errorf("invalid node type string '%s': %w", t, err) - } - - switch i { - case 1: - return entity.NodeTypeEntry, nil - case 2: - return entity.NodeTypeExit, nil - case 3: - return entity.NodeTypeLLM, nil - case 4: - return entity.NodeTypePlugin, nil - case 5: - return entity.NodeTypeCodeRunner, nil - case 6: - return entity.NodeTypeKnowledgeRetriever, nil - case 8: - return entity.NodeTypeSelector, nil - case 9: - return entity.NodeTypeSubWorkflow, nil - case 12: - return entity.NodeTypeDatabaseCustomSQL, nil - case 13: - return entity.NodeTypeOutputEmitter, nil - case 15: - return entity.NodeTypeTextProcessor, nil - case 18: - return entity.NodeTypeQuestionAnswer, nil - case 19: - return entity.NodeTypeBreak, nil - case 20: - return entity.NodeTypeVariableAssignerWithinLoop, nil - case 21: - return entity.NodeTypeLoop, nil - case 22: - return entity.NodeTypeIntentDetector, nil - case 27: - return entity.NodeTypeKnowledgeIndexer, nil - case 28: - return entity.NodeTypeBatch, nil - case 29: - return entity.NodeTypeContinue, nil - case 30: - return entity.NodeTypeInputReceiver, nil - case 32: - return entity.NodeTypeVariableAggregator, nil - case 37: - return entity.NodeTypeMessageList, nil - case 38: - return entity.NodeTypeClearMessage, nil - case 39: - return entity.NodeTypeCreateConversation, nil - case 40: - return entity.NodeTypeVariableAssigner, nil - case 42: - return entity.NodeTypeDatabaseUpdate, nil - case 43: - return entity.NodeTypeDatabaseQuery, nil - case 44: - return entity.NodeTypeDatabaseDelete, nil - case 45: - return entity.NodeTypeHTTPRequester, nil - case 46: - return entity.NodeTypeDatabaseInsert, nil - case 58: - return entity.NodeTypeJsonSerialization, nil - case 59: - return entity.NodeTypeJsonDeserialization, nil - case 60: - return entity.NodeTypeKnowledgeDeleter, nil - default: - // Handle all unknown or unsupported types here - return "", fmt.Errorf("unsupported or unknown node type ID: %d", i) - } -} - -// entityNodeTypeToAPINodeTemplateType converts an entity.NodeType to the corresponding workflow.NodeTemplateType. -func entityNodeTypeToAPINodeTemplateType(nodeType entity.NodeType) (workflow.NodeTemplateType, error) { - switch nodeType { - case entity.NodeTypeEntry: - return workflow.NodeTemplateType_Start, nil - case entity.NodeTypeExit: - return workflow.NodeTemplateType_End, nil - case entity.NodeTypeLLM: - return workflow.NodeTemplateType_LLM, nil - case entity.NodeTypePlugin: - // Maps to Api type in the API model - return workflow.NodeTemplateType_Api, nil - case entity.NodeTypeCodeRunner: - return workflow.NodeTemplateType_Code, nil - case entity.NodeTypeKnowledgeRetriever: - // Maps to Dataset type in the API model - return workflow.NodeTemplateType_Dataset, nil - case entity.NodeTypeSelector: - // Maps to If type in the API model - return workflow.NodeTemplateType_If, nil - case entity.NodeTypeSubWorkflow: - return workflow.NodeTemplateType_SubWorkflow, nil - case entity.NodeTypeDatabaseCustomSQL: - // Maps to the generic Database type in the API model - return workflow.NodeTemplateType_Database, nil - case entity.NodeTypeOutputEmitter: - // Maps to Message type in the API model - return workflow.NodeTemplateType_Message, nil - case entity.NodeTypeTextProcessor: - return workflow.NodeTemplateType_Text, nil - case entity.NodeTypeQuestionAnswer: - return workflow.NodeTemplateType_Question, nil - case entity.NodeTypeBreak: - return workflow.NodeTemplateType_Break, nil - case entity.NodeTypeVariableAssigner: - return workflow.NodeTemplateType_AssignVariable, nil - case entity.NodeTypeVariableAssignerWithinLoop: - return workflow.NodeTemplateType_LoopSetVariable, nil - case entity.NodeTypeLoop: - return workflow.NodeTemplateType_Loop, nil - case entity.NodeTypeIntentDetector: - return workflow.NodeTemplateType_Intent, nil - case entity.NodeTypeKnowledgeIndexer: - // Maps to DatasetWrite type in the API model - return workflow.NodeTemplateType_DatasetWrite, nil - case entity.NodeTypeBatch: - return workflow.NodeTemplateType_Batch, nil - case entity.NodeTypeContinue: - return workflow.NodeTemplateType_Continue, nil - case entity.NodeTypeInputReceiver: - return workflow.NodeTemplateType_Input, nil - case entity.NodeTypeMessageList: - return workflow.NodeTemplateType(37), nil - case entity.NodeTypeVariableAggregator: - return workflow.NodeTemplateType(32), nil - case entity.NodeTypeClearMessage: - return workflow.NodeTemplateType(38), nil - case entity.NodeTypeCreateConversation: - return workflow.NodeTemplateType(39), nil - // Note: entity.NodeTypeVariableAggregator (ID 32) has no direct mapping in NodeTemplateType - // Note: entity.NodeTypeMessageList (ID 37) has no direct mapping in NodeTemplateType - // Note: entity.NodeTypeClearMessage (ID 38) has no direct mapping in NodeTemplateType - // Note: entity.NodeTypeCreateConversation (ID 39) has no direct mapping in NodeTemplateType - case entity.NodeTypeDatabaseUpdate: - return workflow.NodeTemplateType_DatabaseUpdate, nil - case entity.NodeTypeDatabaseQuery: - // Maps to DatabasesELECT (ID 43) in the API model (note potential typo) - return workflow.NodeTemplateType_DatabasesELECT, nil - case entity.NodeTypeDatabaseDelete: - return workflow.NodeTemplateType_DatabaseDelete, nil - - // Note: entity.NodeTypeHTTPRequester (ID 45) has no direct mapping in NodeTemplateType - case entity.NodeTypeHTTPRequester: - return workflow.NodeTemplateType(45), nil - - case entity.NodeTypeDatabaseInsert: - // Maps to DatabaseInsert (ID 41) in the API model, despite entity ID being 46. - // return workflow.NodeTemplateType_DatabaseInsert, nil - return workflow.NodeTemplateType(46), nil - case entity.NodeTypeJsonSerialization: - return workflow.NodeTemplateType(58), nil - case entity.NodeTypeJsonDeserialization: - return workflow.NodeTemplateType_JsonDeserialization, nil - case entity.NodeTypeKnowledgeDeleter: - return workflow.NodeTemplateType_DatasetDelete, nil - case entity.NodeTypeLambda: - return 0, nil - default: - // Handle entity types that don't have a corresponding NodeTemplateType - return workflow.NodeTemplateType(0), fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType) - } -} - func i64PtrToStringPtr(i *int64) *string { if i == nil { return nil @@ -3761,7 +3574,7 @@ func mergeWorkflowAPIParameters(latestAPIParameters []*workflow.APIParameter, ex func parseWorkflowTerminatePlanType(c *vo.Canvas) (int32, error) { var endNode *vo.Node for _, n := range c.Nodes { - if n.Type == vo.BlockTypeBotEnd { + if n.Type == entity.NodeTypeExit.IDStr() { endNode = n break } diff --git a/backend/domain/plugin/service/exec_tool_test.go b/backend/domain/plugin/service/exec_tool_test.go deleted file mode 100644 index 1b1d28b7..00000000 --- a/backend/domain/plugin/service/exec_tool_test.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package service - -import ( - "net/http" - "net/url" - "testing" - - . "github.com/bytedance/mockey" - "github.com/stretchr/testify/assert" -) - -func TestGenRequestString(t *testing.T) { - PatchConvey("", t, func() { - requestStr, err := genRequestString(&http.Request{ - Header: http.Header{ - "Content-Type": []string{"application/json"}, - }, - Method: http.MethodPost, - URL: &url.URL{Path: "/test"}, - }, []byte(`{"a": 1}`)) - assert.NoError(t, err) - assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr) - }) - - PatchConvey("", t, func() { - var body []byte - requestStr, err := genRequestString(&http.Request{ - URL: &url.URL{Path: "/test"}, - }, body) - assert.NoError(t, err) - assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr) - }) -} diff --git a/backend/domain/workflow/entity/node_meta.go b/backend/domain/workflow/entity/node_meta.go index 6b11c941..184d9734 100644 --- a/backend/domain/workflow/entity/node_meta.go +++ b/backend/domain/workflow/entity/node_meta.go @@ -1,4 +1,5 @@ /* + * Copyright 2025 coze-dev Authors * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,58 +17,85 @@ package entity +import ( + "fmt" + "strconv" +) + type NodeType string +func (nt NodeType) IDStr() string { + m := NodeMetaByNodeType(nt) + if m == nil { + return "" + } + return fmt.Sprintf("%d", m.ID) +} + +func IDStrToNodeType(s string) NodeType { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return "" + } + for _, m := range NodeTypeMetas { + if m.ID == id { + return m.Key + } + } + return "" +} + type NodeTypeMeta struct { - ID int64 `json:"id"` - Name string `json:"name"` - Type NodeType `json:"type"` - Category string `json:"category"` - Color string `json:"color"` - Desc string `json:"desc"` - IconURL string `json:"icon_url"` - SupportBatch bool `json:"support_batch"` - Disabled bool `json:"disabled,omitempty"` - EnUSName string `json:"en_us_name,omitempty"` - EnUSDescription string `json:"en_us_description,omitempty"` + ID int64 + Key NodeType + DisplayKey string + Name string `json:"name"` + Category string `json:"category"` + Color string `json:"color"` + Desc string `json:"desc"` + IconURL string `json:"icon_url"` + SupportBatch bool `json:"support_batch"` + Disabled bool `json:"disabled,omitempty"` + EnUSName string `json:"en_us_name,omitempty"` + EnUSDescription string `json:"en_us_description,omitempty"` ExecutableMeta } +func (ntm *NodeTypeMeta) GetDisplayKey() string { + if len(ntm.DisplayKey) > 0 { + return ntm.DisplayKey + } + + return string(ntm.Key) +} + type Category struct { Key string `json:"key"` Name string `json:"name"` EnUSName string `json:"en_us_name"` } -type StreamingParadigm string - -const ( - Invoke StreamingParadigm = "invoke" - Stream StreamingParadigm = "stream" - Collect StreamingParadigm = "collect" - Transform StreamingParadigm = "transform" -) - type ExecutableMeta struct { - IsComposite bool `json:"is_composite,omitempty"` - DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout - PreFillZero bool `json:"pre_fill_zero,omitempty"` - PostFillNil bool `json:"post_fill_nil,omitempty"` - CallbackEnabled bool `json:"callback_enabled,omitempty"` // is false, Eino framework will inject callbacks for this node - MayUseChatModel bool `json:"may_use_chat_model,omitempty"` - InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources - StreamingParadigms map[StreamingParadigm]bool `json:"streaming_paradigms,omitempty"` - StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error - /* - IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming. - This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output') - from those that are merely capable of streaming internally (defined by StreamingParadigms), - whose output is consumed by other nodes. - In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms. - When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process. - */ + IsComposite bool `json:"is_composite,omitempty"` + DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout + PreFillZero bool `json:"pre_fill_zero,omitempty"` + PostFillNil bool `json:"post_fill_nil,omitempty"` + MayUseChatModel bool `json:"may_use_chat_model,omitempty"` + InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources + StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error + + // IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming. + // This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output') + //from those that are merely capable of streaming internally (defined by StreamingParadigms), + // In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms. + // When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process. IncrementalOutput bool `json:"incremental_output,omitempty"` + + // UseCtxCache indicates that the node would require a newly initialized ctx cache for each invocation. + // example use cases: + // - write warnings to the ctx cache during Invoke, and read from the ctx within Callback output converter + UseCtxCache bool `json:"use_ctx_cache"` } type PluginNodeMeta struct { @@ -125,9 +153,700 @@ const ( NodeTypeSubWorkflow NodeType = "SubWorkflow" NodeTypeJsonSerialization NodeType = "JsonSerialization" NodeTypeJsonDeserialization NodeType = "JsonDeserialization" + NodeTypeComment NodeType = "Comment" ) const ( EntryNodeKey = "100001" ExitNodeKey = "900001" ) + +var Categories = []Category{ + { + Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow + Name: "", + EnUSName: "", + }, + { + Key: "logic", + Name: "业务逻辑", + EnUSName: "Logic", + }, + { + Key: "input&output", + Name: "输入&输出", + EnUSName: "Input&Output", + }, + { + Key: "database", + Name: "数据库", + EnUSName: "Database", + }, + { + Key: "data", + Name: "知识库&数据", + EnUSName: "Data", + }, + { + Key: "image", + Name: "图像处理", + EnUSName: "Image", + }, + { + Key: "audio&video", + Name: "音视频处理", + EnUSName: "Audio&Video", + }, + { + Key: "utilities", + Name: "组件", + EnUSName: "Utilities", + }, + { + Key: "conversation_management", + Name: "会话管理", + EnUSName: "Conversation management", + }, + { + Key: "conversation_history", + Name: "会话历史", + EnUSName: "Conversation history", + }, + { + Key: "message", + Name: "消息", + EnUSName: "Message", + }, +} + +// NodeTypeMetas holds the metadata for all available node types. +// It is initialized with built-in node types and potentially extended by loading from external sources. +var NodeTypeMetas = map[NodeType]*NodeTypeMeta{ + NodeTypeEntry: { + ID: 1, + Key: NodeTypeEntry, + DisplayKey: "Start", + Name: "开始", + Category: "input&output", + Desc: "工作流的起始节点,用于设定启动工作流需要的信息", + Color: "#5C62FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PostFillNil: true, + }, + EnUSName: "Start", + EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.", + }, + NodeTypeExit: { + ID: 2, + Key: NodeTypeExit, + DisplayKey: "End", + Name: "结束", + Category: "input&output", + Desc: "工作流的最终节点,用于返回工作流运行后的结果信息", + Color: "#5C62FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + InputSourceAware: true, + StreamSourceEOFAware: true, + IncrementalOutput: true, + }, + EnUSName: "End", + EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.", + }, + NodeTypeLLM: { + ID: 3, + Key: NodeTypeLLM, + DisplayKey: "LLM", + Name: "大模型", + Category: "", + Desc: "调用大语言模型,使用变量和提示词生成回复", + Color: "#5C62FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg", + SupportBatch: true, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes + PreFillZero: true, + PostFillNil: true, + InputSourceAware: true, + MayUseChatModel: true, + }, + EnUSName: "LLM", + EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.", + }, + NodeTypePlugin: { + ID: 4, + Key: NodeTypePlugin, + DisplayKey: "Api", + Name: "插件", + Category: "", + Desc: "通过添加工具访问实时数据和执行外部操作", + Color: "#CA61FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg", + SupportBatch: true, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Plugin", + EnUSDescription: "Used to access external real-time data and perform operations", + }, + NodeTypeCodeRunner: { + ID: 5, + Key: NodeTypeCodeRunner, + DisplayKey: "Code", + Name: "代码", + Category: "logic", + Desc: "编写代码,处理输入变量来生成返回值", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + UseCtxCache: true, + }, + EnUSName: "Code", + EnUSDescription: "Write code to process input variables to generate return values.", + }, + NodeTypeKnowledgeRetriever: { + ID: 6, + Key: NodeTypeKnowledgeRetriever, + DisplayKey: "Dataset", + Name: "知识库检索", + Category: "data", + Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回", + Color: "#FF811A", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Knowledge retrieval", + EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.", + }, + NodeTypeSelector: { + ID: 8, + Key: NodeTypeSelector, + DisplayKey: "If", + Name: "选择器", + Category: "logic", + Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Condition", + EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.", + }, + NodeTypeSubWorkflow: { + ID: 9, + Key: NodeTypeSubWorkflow, + DisplayKey: "SubWorkflow", + Name: "工作流", + Category: "", + Desc: "集成已发布工作流,可以执行嵌套子任务", + Color: "#00B83E", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg", + SupportBatch: true, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Workflow", + EnUSDescription: "Add published workflows to execute subtasks", + }, + NodeTypeDatabaseCustomSQL: { + ID: 12, + Key: NodeTypeDatabaseCustomSQL, + DisplayKey: "End", + Name: "SQL自定义", + Category: "database", + Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作", + Color: "#FF811A", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "SQL Customization", + EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL", + }, + NodeTypeOutputEmitter: { + ID: 13, + Key: NodeTypeOutputEmitter, + DisplayKey: "Message", + Name: "输出", + Category: "input&output", + Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式", + Color: "#5C62FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + InputSourceAware: true, + StreamSourceEOFAware: true, + IncrementalOutput: true, + }, + EnUSName: "Output", + EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods", + }, + NodeTypeTextProcessor: { + ID: 15, + Key: NodeTypeTextProcessor, + DisplayKey: "Text", + Name: "文本处理", + Category: "utilities", + Desc: "用于处理多个字符串类型变量的格式", + Color: "#3071F2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + InputSourceAware: true, + }, + EnUSName: "Text Processing", + EnUSDescription: "The format used for handling multiple string-type variables.", + }, + NodeTypeQuestionAnswer: { + ID: 18, + Key: NodeTypeQuestionAnswer, + DisplayKey: "Question", + Name: "问答", + Category: "utilities", + Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式", + Color: "#3071F2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + MayUseChatModel: true, + }, + EnUSName: "Question", + EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions", + }, + NodeTypeBreak: { + ID: 19, + Key: NodeTypeBreak, + DisplayKey: "Break", + Name: "终止循环", + Category: "logic", + Desc: "用于立即终止当前所在的循环,跳出循环体", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Break", + EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop", + }, + NodeTypeVariableAssignerWithinLoop: { + ID: 20, + Key: NodeTypeVariableAssignerWithinLoop, + DisplayKey: "LoopSetVariable", + Name: "设置变量", + Category: "logic", + Desc: "用于重置循环变量的值,使其下次循环使用重置后的值", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Set Variable", + EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration", + }, + NodeTypeLoop: { + ID: 21, + Key: NodeTypeLoop, + DisplayKey: "Loop", + Name: "循环", + Category: "logic", + Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + IsComposite: true, + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Loop", + EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic", + }, + NodeTypeIntentDetector: { + ID: 22, + Key: NodeTypeIntentDetector, + DisplayKey: "Intent", + Name: "意图识别", + Category: "logic", + Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + MayUseChatModel: true, + }, + EnUSName: "Intent recognition", + EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.", + }, + NodeTypeKnowledgeIndexer: { + ID: 27, + Key: NodeTypeKnowledgeIndexer, + DisplayKey: "DatasetWrite", + Name: "知识库写入", + Category: "data", + Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库", + Color: "#FF811A", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Knowledge writing", + EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.", + }, + NodeTypeBatch: { + ID: 28, + Key: NodeTypeBatch, + DisplayKey: "Batch", + Name: "批处理", + Category: "logic", + Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + IsComposite: true, + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Batch", + EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.", + }, + NodeTypeContinue: { + ID: 29, + Key: NodeTypeContinue, + DisplayKey: "Continue", + Name: "继续循环", + Category: "logic", + Desc: "用于终止当前循环,执行下次循环", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Continue", + EnUSDescription: "Used to immediately terminate the current loop and execute next loop", + }, + NodeTypeInputReceiver: { + ID: 30, + Key: NodeTypeInputReceiver, + DisplayKey: "Input", + Name: "输入", + Category: "input&output", + Desc: "支持中间过程的信息输入", + Color: "#5C62FF", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PostFillNil: true, + }, + EnUSName: "Input", + EnUSDescription: "Support intermediate information input", + }, + NodeTypeComment: { + ID: 31, + Key: "", + Name: "注释", + Category: "", // Not found in cate_list + Desc: "comment_desc", // Placeholder from JSON + Color: "", + IconURL: "comment_icon", // Placeholder from JSON + SupportBatch: false, // supportBatch: 1 + EnUSName: "Comment", + }, + NodeTypeVariableAggregator: { + ID: 32, + Key: NodeTypeVariableAggregator, + Name: "变量聚合", + Category: "logic", + Desc: "对多个分支的输出进行聚合处理", + Color: "#00B2B2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + PostFillNil: true, + InputSourceAware: true, + UseCtxCache: true, + }, + EnUSName: "Variable Merge", + EnUSDescription: "Aggregate the outputs of multiple branches.", + }, + NodeTypeMessageList: { + ID: 37, + Key: NodeTypeMessageList, + Name: "查询消息列表", + Category: "message", + Desc: "用于查询消息列表", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg", + SupportBatch: false, + Disabled: true, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Query message list", + EnUSDescription: "Used to query the message list", + }, + NodeTypeClearMessage: { + ID: 38, + Key: NodeTypeClearMessage, + Name: "清除上下文", + Category: "conversation_history", + Desc: "用于清空会话历史,清空后LLM看到的会话历史为空", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg", + SupportBatch: false, + Disabled: true, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Clear conversation history", + EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.", + }, + NodeTypeCreateConversation: { + ID: 39, + Key: NodeTypeCreateConversation, + Name: "创建会话", + Category: "conversation_management", + Desc: "用于创建会话", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg", + SupportBatch: false, + Disabled: true, + ExecutableMeta: ExecutableMeta{ + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Create conversation", + EnUSDescription: "This node is used to create a conversation.", + }, + NodeTypeVariableAssigner: { + ID: 40, + Key: NodeTypeVariableAssigner, + DisplayKey: "AssignVariable", + Name: "变量赋值", + Category: "data", + Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量", + Color: "#FF811A", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{}, + EnUSName: "Variable assign", + EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.", + }, + NodeTypeDatabaseUpdate: { + ID: 42, + Key: NodeTypeDatabaseUpdate, + DisplayKey: "DatabaseUpdate", + Name: "更新数据", + Category: "database", + Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + }, + EnUSName: "Update Data", + EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data", + }, + NodeTypeDatabaseQuery: { + ID: 43, + Key: NodeTypeDatabaseQuery, + DisplayKey: "DatabaseSelect", + Name: "查询数据", + Category: "database", + Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + }, + EnUSName: "Query Data", + EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions", + }, + NodeTypeDatabaseDelete: { + ID: 44, + Key: NodeTypeDatabaseDelete, + DisplayKey: "DatabaseDelete", + Name: "删除数据", + Category: "database", + Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + }, + EnUSName: "Delete Data", + EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions", + }, + NodeTypeHTTPRequester: { + ID: 45, + Key: NodeTypeHTTPRequester, + DisplayKey: "Http", + Name: "HTTP 请求", + Category: "utilities", + Desc: "用于发送API请求,从接口返回数据", + Color: "#3071F2", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "HTTP request", + EnUSDescription: "It is used to send API requests and return data from the interface.", + }, + NodeTypeDatabaseInsert: { + ID: 46, + Key: NodeTypeDatabaseInsert, + DisplayKey: "DatabaseInsert", + Name: "新增数据", + Category: "database", + Desc: "向表添加新数据记录,用户输入数据内容后插入数据库", + Color: "#F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + }, + EnUSName: "Add Data", + EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content", + }, + NodeTypeJsonSerialization: { + // ID is the unique identifier of this node type. Used in various front-end APIs. + ID: 58, + + // Key is the unique NodeType of this node. Used in backend code as well as saved in DB. + Key: NodeTypeJsonSerialization, + + // DisplayKey is the string used in frontend to identify this node. + // Example use cases: + // - used during querying test-run results for nodes + // - used in returned messages from streaming openAPI Runs. + // If empty, will use Key as DisplayKey. + DisplayKey: "ToJSON", + + // Name is the node in ZH_CN, will be displayed on Canvas. + Name: "JSON 序列化", + + // Category is the category of this node, determines which category this node will be displayed in. + Category: "utilities", + + // Desc is the desc in ZH_CN, will be displayed as tooltip on Canvas. + Desc: "用于把变量转化为JSON字符串", + + // Color is the color of the upper edge of the node displayed on Canvas. + Color: "F2B600", + + // IconURL is the URL of the icon displayed on Canvas. + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png", + + // SupportBatch indicates whether this node can set batch mode. + // NOTE: ultimately it's frontend that decides which node can enable batch mode. + SupportBatch: false, + + // ExecutableMeta configures certain common aspects of request-time behaviors for this node. + ExecutableMeta: ExecutableMeta{ + // DefaultTimeoutMS configures the default timeout for this node, in milliseconds. 0 means no timeout. + DefaultTimeoutMS: 60 * 1000, // 1 minute + // PreFillZero decides whether to pre-fill zero value for any missing fields in input. + PreFillZero: true, + // PostFillNil decides whether to post-fill nil value for any missing fields in output. + PostFillNil: true, + }, + // EnUSName is the name in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS. + EnUSName: "JSON serialization", + // EnUSDescription is the description in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS. + EnUSDescription: "Convert variable to JSON string", + }, + NodeTypeJsonDeserialization: { + ID: 59, + Key: NodeTypeJsonDeserialization, + DisplayKey: "FromJSON", + Name: "JSON 反序列化", + Category: "utilities", + Desc: "用于将JSON字符串解析为变量", + Color: "F2B600", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + UseCtxCache: true, + }, + EnUSName: "JSON deserialization", + EnUSDescription: "Parse JSON string to variable", + }, + NodeTypeKnowledgeDeleter: { + ID: 60, + Key: NodeTypeKnowledgeDeleter, + DisplayKey: "KnowledgeDelete", + Name: "知识库删除", + Category: "data", + Desc: "用于删除知识库中的文档", + Color: "#FF811A", + IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png", + SupportBatch: false, + ExecutableMeta: ExecutableMeta{ + DefaultTimeoutMS: 60 * 1000, // 1 minute + PreFillZero: true, + PostFillNil: true, + }, + EnUSName: "Knowledge delete", + EnUSDescription: "The delete node can delete a document in knowledge base.", + }, + NodeTypeLambda: { + ID: 1000, + Key: NodeTypeLambda, + Name: "Lambda", + EnUSName: "Comment", + }, +} + +// PluginNodeMetas holds metadata for specific plugin API entity. +var PluginNodeMetas []*PluginNodeMeta + +// PluginCategoryMetas holds metadata for plugin category entity. +var PluginCategoryMetas []*PluginCategoryMeta + +func NodeMetaByNodeType(t NodeType) *NodeTypeMeta { + if m, ok := NodeTypeMetas[t]; ok { + return m + } + + return nil +} diff --git a/backend/domain/workflow/entity/node_type_literal.go b/backend/domain/workflow/entity/node_type_literal.go deleted file mode 100644 index cc2d3b98..00000000 --- a/backend/domain/workflow/entity/node_type_literal.go +++ /dev/null @@ -1,867 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package entity - -import ( - "github.com/coze-dev/coze-studio/backend/pkg/i18n" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" -) - -var Categories = []Category{ - { - Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow - Name: "", - EnUSName: "", - }, - { - Key: "logic", - Name: "业务逻辑", - EnUSName: "Logic", - }, - { - Key: "input&output", - Name: "输入&输出", - EnUSName: "Input&Output", - }, - { - Key: "database", - Name: "数据库", - EnUSName: "Database", - }, - { - Key: "data", - Name: "知识库&数据", - EnUSName: "Data", - }, - { - Key: "image", - Name: "图像处理", - EnUSName: "Image", - }, - { - Key: "audio&video", - Name: "音视频处理", - EnUSName: "Audio&Video", - }, - { - Key: "utilities", - Name: "组件", - EnUSName: "Utilities", - }, - { - Key: "conversation_management", - Name: "会话管理", - EnUSName: "Conversation management", - }, - { - Key: "conversation_history", - Name: "会话历史", - EnUSName: "Conversation history", - }, - { - Key: "message", - Name: "消息", - EnUSName: "Message", - }, -} - -// NodeTypeMetas holds the metadata for all available node types. -// It is initialized with built-in types and potentially extended by loading from external sources. -var NodeTypeMetas = []*NodeTypeMeta{ - { - ID: 1, - Name: "开始", - Type: NodeTypeEntry, - Category: "input&output", // Mapped from cate_list - Desc: "工作流的起始节点,用于设定启动工作流需要的信息", - Color: "#5C62FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Start", - EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.", - }, - { - ID: 2, - Name: "结束", - Type: NodeTypeExit, - Category: "input&output", // Mapped from cate_list - Desc: "工作流的最终节点,用于返回工作流运行后的结果信息", - Color: "#5C62FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - CallbackEnabled: true, - InputSourceAware: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true}, - StreamSourceEOFAware: true, - IncrementalOutput: true, - }, - EnUSName: "End", - EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.", - }, - { - ID: 3, - Name: "大模型", - Type: NodeTypeLLM, - Category: "", // Mapped from cate_list - Desc: "调用大语言模型,使用变量和提示词生成回复", - Color: "#5C62FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg", - SupportBatch: true, // supportBatch: 2 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - InputSourceAware: true, - MayUseChatModel: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true}, - }, - EnUSName: "LLM", - EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.", - }, - - { - ID: 4, - Name: "插件", - Type: NodeTypePlugin, - Category: "", // Mapped from cate_list - Desc: "通过添加工具访问实时数据和执行外部操作", - Color: "#CA61FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg", - SupportBatch: true, // supportBatch: 2 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Plugin", - EnUSDescription: "Used to access external real-time data and perform operations", - }, - { - ID: 5, - Name: "代码", - Type: NodeTypeCodeRunner, - Category: "logic", // Mapped from cate_list - Desc: "编写代码,处理输入变量来生成返回值", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Code", - EnUSDescription: "Write code to process input variables to generate return values.", - }, - { - ID: 6, - Name: "知识库检索", - Type: NodeTypeKnowledgeRetriever, - Category: "data", // Mapped from cate_list - Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回", - Color: "#FF811A", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Knowledge retrieval", - EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.", - }, - { - ID: 8, - Name: "选择器", - Type: NodeTypeSelector, - Category: "logic", // Mapped from cate_list - Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Condition", - EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.", - }, - { - ID: 9, - Name: "工作流", - Type: NodeTypeSubWorkflow, - Category: "", // Mapped from cate_list - Desc: "集成已发布工作流,可以执行嵌套子任务", - Color: "#00B83E", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg", - SupportBatch: true, // supportBatch: 2 - ExecutableMeta: ExecutableMeta{ - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Workflow", - EnUSDescription: "Add published workflows to execute subtasks", - }, - { - ID: 12, - Name: "SQL自定义", - Type: NodeTypeDatabaseCustomSQL, - Category: "database", // Mapped from cate_list - Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作", - Color: "#FF811A", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg", - SupportBatch: false, // supportBatch: 2 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "SQL Customization", - EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL", - }, - { - ID: 13, - Name: "输出", - Type: NodeTypeOutputEmitter, - Category: "input&output", // Mapped from cate_list - Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式", - Color: "#5C62FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg", - SupportBatch: false, - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - CallbackEnabled: true, - InputSourceAware: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true}, - StreamSourceEOFAware: true, - IncrementalOutput: true, - }, - EnUSName: "Output", - EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods", - }, - { - ID: 15, - Name: "文本处理", - Type: NodeTypeTextProcessor, - Category: "utilities", // Mapped from cate_list - Desc: "用于处理多个字符串类型变量的格式", - Color: "#3071F2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg", - SupportBatch: false, // supportBatch: 2 - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - CallbackEnabled: true, - InputSourceAware: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Text Processing", - EnUSDescription: "The format used for handling multiple string-type variables.", - }, - { - ID: 18, - Name: "问答", - Type: NodeTypeQuestionAnswer, - Category: "utilities", // Mapped from cate_list - Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式", - Color: "#3071F2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - MayUseChatModel: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Question", - EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions", - }, - { - ID: 19, - Name: "终止循环", - Type: NodeTypeBreak, - Category: "logic", // Mapped from cate_list - Desc: "用于立即终止当前所在的循环,跳出循环体", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Break", - EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop", - }, - { - ID: 20, - Name: "设置变量", - Type: NodeTypeVariableAssignerWithinLoop, - Category: "logic", // Mapped from cate_list - Desc: "用于重置循环变量的值,使其下次循环使用重置后的值", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Set Variable", - EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration", - }, - { - ID: 21, - Name: "循环", - Type: NodeTypeLoop, - Category: "logic", // Mapped from cate_list - Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg", - SupportBatch: false, - ExecutableMeta: ExecutableMeta{ - IsComposite: true, - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Loop", - EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic", - }, - { - ID: 22, - Name: "意图识别", - Type: NodeTypeIntentDetector, - Category: "logic", // Mapped from cate_list - Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - MayUseChatModel: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Intent recognition", - EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.", - }, - { - ID: 27, - Name: "知识库写入", - Type: NodeTypeKnowledgeIndexer, - Category: "data", // Mapped from cate_list - Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库", - Color: "#FF811A", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Knowledge writing", - EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.", - }, - { - ID: 28, - Name: "批处理", - Type: NodeTypeBatch, - Category: "logic", // Mapped from cate_list - Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg", - SupportBatch: false, // supportBatch: 1 (Corrected from previous assumption) - ExecutableMeta: ExecutableMeta{ - IsComposite: true, - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Batch", - EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.", - }, - { - ID: 29, - Name: "继续循环", - Type: NodeTypeContinue, - Category: "logic", // Mapped from cate_list - Desc: "用于终止当前循环,执行下次循环", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Continue", - EnUSDescription: "Used to immediately terminate the current loop and execute next loop", - }, - { - ID: 30, - Name: "输入", - Type: NodeTypeInputReceiver, - Category: "input&output", // Mapped from cate_list - Desc: "支持中间过程的信息输入", - Color: "#5C62FF", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Input", - EnUSDescription: "Support intermediate information input", - }, - { - ID: 31, - Name: "注释", - Type: "", - Category: "", // Not found in cate_list - Desc: "comment_desc", // Placeholder from JSON - Color: "", - IconURL: "comment_icon", // Placeholder from JSON - SupportBatch: false, // supportBatch: 1 - EnUSName: "Comment", - }, - { - ID: 32, - Name: "变量聚合", - Type: NodeTypeVariableAggregator, - Category: "logic", // Mapped from cate_list - Desc: "对多个分支的输出进行聚合处理", - Color: "#00B2B2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - PostFillNil: true, - CallbackEnabled: true, - InputSourceAware: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true}, - }, - EnUSName: "Variable Merge", - EnUSDescription: "Aggregate the outputs of multiple branches.", - }, - { - ID: 37, - Name: "查询消息列表", - Type: NodeTypeMessageList, - Category: "message", // Mapped from cate_list - Desc: "用于查询消息列表", - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg", - SupportBatch: false, // supportBatch: 1 - Disabled: true, - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Query message list", - EnUSDescription: "Used to query the message list", - }, - { - ID: 38, - Name: "清除上下文", - Type: NodeTypeClearMessage, - Category: "conversation_history", // Mapped from cate_list - Desc: "用于清空会话历史,清空后LLM看到的会话历史为空", - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg", - SupportBatch: false, // supportBatch: 1 - Disabled: true, - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Clear conversation history", - EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.", - }, - { - ID: 39, - Name: "创建会话", - Type: NodeTypeCreateConversation, - Category: "conversation_management", // Mapped from cate_list - Desc: "用于创建会话", - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg", - SupportBatch: false, // supportBatch: 1 - Disabled: true, - ExecutableMeta: ExecutableMeta{ - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Create conversation", - EnUSDescription: "This node is used to create a conversation.", - }, - { - ID: 40, - Name: "变量赋值", - Type: NodeTypeVariableAssigner, - Category: "data", // Mapped from cate_list - Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量", - Color: "#FF811A", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Variable assign", - EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.", - }, - { - ID: 42, - Name: "更新数据", - Type: NodeTypeDatabaseUpdate, - Category: "database", // Mapped from cate_list - Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据", - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg", // Corrected Icon URL from JSON - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Update Data", - EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data", - }, - { - ID: 43, - Name: "查询数据", // Corrected Name from JSON (was "insert data") - Type: NodeTypeDatabaseQuery, - Category: "database", // Mapped from cate_list - Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据", // Corrected Desc from JSON - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg", // Corrected Icon URL from JSON - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Query Data", - EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions", - }, - { - ID: 44, - Name: "删除数据", - Type: NodeTypeDatabaseDelete, - Category: "database", // Mapped from cate_list - Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录", // Corrected Desc from JSON - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg", // Corrected Icon URL from JSON - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Delete Data", - EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions", - }, - { - ID: 45, - Name: "HTTP 请求", - Type: NodeTypeHTTPRequester, - Category: "utilities", // Mapped from cate_list - Desc: "用于发送API请求,从接口返回数据", // Corrected Desc from JSON - Color: "#3071F2", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png", // Corrected Icon URL from JSON - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "HTTP request", - EnUSDescription: "It is used to send API requests and return data from the interface.", - }, - { - ID: 46, - Name: "新增数据", // Corrected Name from JSON (was "Query Data") - Type: NodeTypeDatabaseInsert, - Category: "database", // Mapped from cate_list - Desc: "向表添加新数据记录,用户输入数据内容后插入数据库", // Corrected Desc from JSON - Color: "#F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg", // Corrected Icon URL from JSON - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Add Data", - EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content", - }, - { - ID: 58, - Name: "JSON 序列化", - Type: NodeTypeJsonSerialization, - Category: "utilities", - Desc: "用于把变量转化为JSON字符串", - Color: "F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png", - SupportBatch: false, - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "JSON serialization", - EnUSDescription: "Convert variable to JSON string", - }, - { - ID: 59, - Name: "JSON 反序列化", - Type: NodeTypeJsonDeserialization, - Category: "utilities", - Desc: "用于将JSON字符串解析为变量", - Color: "F2B600", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png", - SupportBatch: false, - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - CallbackEnabled: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "JSON deserialization", - EnUSDescription: "Parse JSON string to variable", - }, - { - ID: 60, - Name: "知识库删除", - Type: NodeTypeKnowledgeDeleter, - Category: "data", // Mapped from cate_list - Desc: "用于删除知识库中的文档", - Color: "#FF811A", - IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png", - SupportBatch: false, // supportBatch: 1 - ExecutableMeta: ExecutableMeta{ - DefaultTimeoutMS: 60 * 1000, // 1 minute - PreFillZero: true, - PostFillNil: true, - StreamingParadigms: map[StreamingParadigm]bool{Invoke: true}, - }, - EnUSName: "Knowledge delete", - EnUSDescription: "The delete node can delete a document in knowledge base.", - }, - // --- End of nodes parsed from template_list --- -} - -// PluginNodeMetas holds metadata for specific plugin API entity. -var PluginNodeMetas []*PluginNodeMeta - -// PluginCategoryMetas holds metadata for plugin category entity. -var PluginCategoryMetas []*PluginCategoryMeta - -func NodeMetaByNodeType(t NodeType) *NodeTypeMeta { - for _, meta := range NodeTypeMetas { - if meta.Type == t { - return meta - } - } - - return nil -} - -const defaultZhCNInitCanvasJsonSchema = `{ - "nodes": [ - { - "id": "100001", - "type": "1", - "meta": { - "position": { - "x": 0, - "y": 0 - } - }, - "data": { - "nodeMeta": { - "description": "工作流的起始节点,用于设定启动工作流需要的信息", - "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png", - "subTitle": "", - "title": "开始" - }, - "outputs": [ - { - "type": "string", - "name": "input", - "required": false - } - ], - "trigger_parameters": [ - { - "type": "string", - "name": "input", - "required": false - } - ] - } - }, - { - "id": "900001", - "type": "2", - "meta": { - "position": { - "x": 1000, - "y": 0 - } - }, - "data": { - "nodeMeta": { - "description": "工作流的最终节点,用于返回工作流运行后的结果信息", - "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png", - "subTitle": "", - "title": "结束" - }, - "inputs": { - "terminatePlan": "returnVariables", - "inputParameters": [ - { - "name": "output", - "input": { - "type": "string", - "value": { - "type": "ref", - "content": { - "source": "block-output", - "blockID": "", - "name": "" - } - } - } - } - ] - } - } - } - ], - "edges": [], - "versions": { - "loop": "v2" - } -}` - -const defaultEnUSInitCanvasJsonSchema = `{ - "nodes": [ - { - "id": "100001", - "type": "1", - "meta": { - "position": { - "x": 0, - "y": 0 - } - }, - "data": { - "nodeMeta": { - "description": "The starting node of the workflow, used to set the information needed to initiate the workflow.", - "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png", - "subTitle": "", - "title": "Start" - }, - "outputs": [ - { - "type": "string", - "name": "input", - "required": false - } - ], - "trigger_parameters": [ - { - "type": "string", - "name": "input", - "required": false - } - ] - } - }, - { - "id": "900001", - "type": "2", - "meta": { - "position": { - "x": 1000, - "y": 0 - } - }, - "data": { - "nodeMeta": { - "description": "The final node of the workflow, used to return the result information after the workflow runs.", - "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png", - "subTitle": "", - "title": "End" - }, - "inputs": { - "terminatePlan": "returnVariables", - "inputParameters": [ - { - "name": "output", - "input": { - "type": "string", - "value": { - "type": "ref", - "content": { - "source": "block-output", - "blockID": "", - "name": "" - } - } - } - } - ] - } - } - } - ], - "edges": [], - "versions": { - "loop": "v2" - } -}` - -func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string { - return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema) -} diff --git a/backend/domain/workflow/entity/vo/canvas.go b/backend/domain/workflow/entity/vo/canvas.go index 936009e8..5f410332 100644 --- a/backend/domain/workflow/entity/vo/canvas.go +++ b/backend/domain/workflow/entity/vo/canvas.go @@ -19,24 +19,48 @@ package vo import ( "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" + "github.com/coze-dev/coze-studio/backend/pkg/i18n" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" ) +// Canvas is the definition of FRONTEND schema for a workflow. type Canvas struct { Nodes []*Node `json:"nodes"` Edges []*Edge `json:"edges"` Versions any `json:"versions"` } +// Node represents a node within a workflow canvas. type Node struct { - ID string `json:"id"` - Type BlockType `json:"type"` - Meta any `json:"meta"` - Data *Data `json:"data"` - Blocks []*Node `json:"blocks,omitempty"` - Edges []*Edge `json:"edges,omitempty"` - Version string `json:"version,omitempty"` + // ID is the unique node ID within the workflow. + // In normal use cases, this ID is generated by frontend. + // It does NOT need to be unique between parent workflow and sub workflows. + // The Entry node and Exit node of a workflow always have fixed node IDs: 100001 and 900001. + ID string `json:"id"` - parent *Node + // Type is the Node Type of this node instance. + // It corresponds to the string value of 'ID' field of NodeMeta. + Type string `json:"type"` + + // Meta holds meta data used by frontend, such as the node's position within canvas. + Meta any `json:"meta"` + + // Data holds the actual configurations of a node, such as inputs, outputs and exception handling. + // It also holds exclusive configurations for different node types, such as LLM configurations. + Data *Data `json:"data"` + + // Blocks holds the sub nodes of this node. + // It is only used when the node type is Composite, such as NodeTypeBatch and NodeTypeLoop. + Blocks []*Node `json:"blocks,omitempty"` + + // Edges are the connections between nodes. + // Strictly corresponds to connections drawn on canvas. + Edges []*Edge `json:"edges,omitempty"` + + // Version is the version of this node type's schema. + Version string `json:"version,omitempty"` + + parent *Node // if this node is within a composite node, coze will set this. No need to set manually } func (n *Node) SetParent(parent *Node) { @@ -47,7 +71,7 @@ func (n *Node) Parent() *Node { return n.parent } -type NodeMeta struct { +type NodeMetaFE struct { Title string `json:"title,omitempty"` Description string `json:"description,omitempty"` Icon string `json:"icon,omitempty"` @@ -62,52 +86,88 @@ type Edge struct { TargetPortID string `json:"targetPortID,omitempty"` } +// Data holds the actual configuration of a Node. type Data struct { - Meta *NodeMeta `json:"nodeMeta,omitempty"` - Outputs []any `json:"outputs,omitempty"` // either []*Variable or []*Param - Inputs *Inputs `json:"inputs,omitempty"` - Size any `json:"size,omitempty"` + // Meta is the meta data of this node. Only used by frontend. + Meta *NodeMetaFE `json:"nodeMeta,omitempty"` + + // Outputs configures the output fields and their types. + // Outputs can be either []*Variable (most of the cases, just fields and types), + // or []*Param (used by composite nodes as they need to refer to outputs of sub nodes) + Outputs []any `json:"outputs,omitempty"` + + // Inputs configures ALL input information of a node, + // including fixed input fields and dynamic input fields defined by user. + Inputs *Inputs `json:"inputs,omitempty"` + + // Size configures the size of this node in frontend. + // Only used by NodeTypeComment. + Size any `json:"size,omitempty"` } type Inputs struct { - InputParameters []*Param `json:"inputParameters"` - Content *BlockInput `json:"content"` - TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"` - StreamingOutput bool `json:"streamingOutput,omitempty"` - CallTransferVoice bool `json:"callTransferVoice,omitempty"` - ChatHistoryWriting string `json:"chatHistoryWriting,omitempty"` - LLMParam any `json:"llmParam,omitempty"` // The LLMParam type may be one of the LLMParam or IntentDetectorLLMParam type or QALLMParam type - FCParam *FCParam `json:"fcParam,omitempty"` - SettingOnError *SettingOnError `json:"settingOnError,omitempty"` + // InputParameters are the fields defined by user for this particular node. + InputParameters []*Param `json:"inputParameters"` + // SettingOnError configures common error handling strategy for nodes. + // NOTE: enable in frontend node's form first. + SettingOnError *SettingOnError `json:"settingOnError,omitempty"` + + // NodeBatchInfo configures batch mode for nodes. + // NOTE: not to be confused with NodeTypeBatch. + NodeBatchInfo *NodeBatch `json:"batch,omitempty"` + + // LLMParam may be one of the LLMParam or IntentDetectorLLMParam or SimpleLLMParam. + // Shared between most nodes requiring an ChatModel to function. + LLMParam any `json:"llmParam,omitempty"` + + *OutputEmitter // exclusive configurations for NodeTypeEmitter and NodeTypeExit in Answer mode + *Exit // exclusive configurations for NodeTypeExit + *LLM // exclusive configurations for NodeTypeLLM + *Loop // exclusive configurations for NodeTypeLoop + *Selector // exclusive configurations for NodeTypeSelector + *TextProcessor // exclusive configurations for NodeTypeTextProcessor + *SubWorkflow // exclusive configurations for NodeTypeSubWorkflow + *IntentDetector // exclusive configurations for NodeTypeIntentDetector + *DatabaseNode // exclusive configurations for various Database nodes + *HttpRequestNode // exclusive configurations for NodeTypeHTTPRequester + *Knowledge // exclusive configurations for various Knowledge nodes + *CodeRunner // exclusive configurations for NodeTypeCodeRunner + *PluginAPIParam // exclusive configurations for NodeTypePlugin + *VariableAggregator // exclusive configurations for NodeTypeVariableAggregator + *VariableAssigner // exclusive configurations for NodeTypeVariableAssigner + *QA // exclusive configurations for NodeTypeQuestionAnswer + *Batch // exclusive configurations for NodeTypeBatch + *Comment // exclusive configurations for NodeTypeComment + *InputReceiver // exclusive configurations for NodeTypeInputReceiver +} + +type OutputEmitter struct { + Content *BlockInput `json:"content"` + StreamingOutput bool `json:"streamingOutput,omitempty"` +} + +type Exit struct { + TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"` +} + +type LLM struct { + FCParam *FCParam `json:"fcParam,omitempty"` +} + +type Loop struct { LoopType LoopType `json:"loopType,omitempty"` LoopCount *BlockInput `json:"loopCount,omitempty"` VariableParameters []*Param `json:"variableParameters,omitempty"` +} +type Selector struct { Branches []*struct { Condition struct { Logic LogicType `json:"logic"` Conditions []*Condition `json:"conditions"` } `json:"condition"` } `json:"branches,omitempty"` - - NodeBatchInfo *NodeBatch `json:"batch,omitempty"` // node in batch mode - - *TextProcessor - *SubWorkflow - *IntentDetector - *DatabaseNode - *HttpRequestNode - *KnowledgeIndexer - *CodeRunner - *PluginAPIParam - *VariableAggregator - *VariableAssigner - *QA - *Batch - *Comment - - OutputSchema string `json:"outputSchema,omitempty"` } type Comment struct { @@ -127,7 +187,7 @@ type VariableAssigner struct { type LLMParam = []*Param type IntentDetectorLLMParam = map[string]any -type QALLMParam struct { +type SimpleLLMParam struct { GenerationDiversity string `json:"generationDiversity"` MaxTokens int `json:"maxTokens"` ModelName string `json:"modelName"` @@ -248,7 +308,7 @@ type CodeRunner struct { Language int64 `json:"language"` } -type KnowledgeIndexer struct { +type Knowledge struct { DatasetParam []*Param `json:"datasetParam,omitempty"` StrategyParam StrategyParam `json:"strategyParam,omitempty"` } @@ -384,23 +444,54 @@ type ChatHistorySetting struct { type Intent struct { Name string `json:"name"` } + +// Param is a node's field with type and source info. type Param struct { - Name string `json:"name,omitempty"` - Input *BlockInput `json:"input,omitempty"` - Left *BlockInput `json:"left,omitempty"` - Right *BlockInput `json:"right,omitempty"` + // Name is the field's name. + Name string `json:"name,omitempty"` + + // Input is the configurations for normal, singular field. + Input *BlockInput `json:"input,omitempty"` + + // Left is the configurations for the left half of an expression, + // such as an assignment in NodeTypeVariableAssigner. + Left *BlockInput `json:"left,omitempty"` + + // Right is the configuration for the right half of an expression. + Right *BlockInput `json:"right,omitempty"` + + // Variables are configurations for a group of fields. + // Only used in NodeTypeVariableAggregator. Variables []*BlockInput `json:"variables,omitempty"` } +// Variable is the configuration of a node's field, either input or output. type Variable struct { - Name string `json:"name"` - Type VariableType `json:"type"` - Required bool `json:"required,omitempty"` - AssistType AssistType `json:"assistType,omitempty"` - Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list) - Description string `json:"description,omitempty"` - ReadOnly bool `json:"readOnly,omitempty"` - DefaultValue any `json:"defaultValue,omitempty"` + // Name is the field's name as defined on canvas. + Name string `json:"name"` + + // Type is the field's data type, such as string, integer, number, object, array, etc. + Type VariableType `json:"type"` + + // Required is set to true if you checked the 'required box' on this field + Required bool `json:"required,omitempty"` + + // AssistType is the 'secondary' type of string fields, such as different types of file and image, or time. + AssistType AssistType `json:"assistType,omitempty"` + + // Schema contains detailed info for sub-fields of an object field, or element type of an array. + Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list) + + // Description describes the field's intended use. Used on Entry node. Useful for workflow tools. + Description string `json:"description,omitempty"` + + // ReadOnly indicates a field is not to be set by Node's business logic. + // e.g. the ErrorBody field when exception strategy is configured. + ReadOnly bool `json:"readOnly,omitempty"` + + // DefaultValue configures the 'default value' if this field is missing in input. + // Effective only in Entry node. + DefaultValue any `json:"defaultValue,omitempty"` } type BlockInput struct { @@ -436,48 +527,6 @@ type SubWorkflow struct { SpaceID string `json:"spaceId,omitempty"` } -// BlockType is the enumeration of node types for front-end canvas schema. -// To add a new BlockType, start from a really big number such as 1000, to avoid conflict with future extensions. -type BlockType string - -func (b BlockType) String() string { - return string(b) -} - -const ( - BlockTypeBotStart BlockType = "1" - BlockTypeBotEnd BlockType = "2" - BlockTypeBotLLM BlockType = "3" - BlockTypeBotAPI BlockType = "4" - BlockTypeBotCode BlockType = "5" - BlockTypeBotDataset BlockType = "6" - BlockTypeCondition BlockType = "8" - BlockTypeBotSubWorkflow BlockType = "9" - BlockTypeDatabase BlockType = "12" - BlockTypeBotMessage BlockType = "13" - BlockTypeBotText BlockType = "15" - BlockTypeQuestion BlockType = "18" - BlockTypeBotBreak BlockType = "19" - BlockTypeBotLoopSetVariable BlockType = "20" - BlockTypeBotLoop BlockType = "21" - BlockTypeBotIntent BlockType = "22" - BlockTypeBotDatasetWrite BlockType = "27" - BlockTypeBotInput BlockType = "30" - BlockTypeBotBatch BlockType = "28" - BlockTypeBotContinue BlockType = "29" - BlockTypeBotComment BlockType = "31" - BlockTypeBotVariableMerge BlockType = "32" - BlockTypeBotAssignVariable BlockType = "40" - BlockTypeDatabaseUpdate BlockType = "42" - BlockTypeDatabaseSelect BlockType = "43" - BlockTypeDatabaseDelete BlockType = "44" - BlockTypeBotHttp BlockType = "45" - BlockTypeDatabaseInsert BlockType = "46" - BlockTypeJsonSerialization BlockType = "58" - BlockTypeJsonDeserialization BlockType = "59" - BlockTypeBotDatasetDelete BlockType = "60" -) - type VariableType string const ( @@ -536,19 +585,31 @@ const ( type ErrorProcessType int const ( - ErrorProcessTypeThrow ErrorProcessType = 1 - ErrorProcessTypeDefault ErrorProcessType = 2 - ErrorProcessTypeExceptionBranch ErrorProcessType = 3 + ErrorProcessTypeThrow ErrorProcessType = 1 // throws the error as usual + ErrorProcessTypeReturnDefaultData ErrorProcessType = 2 // return DataOnErr configured in SettingOnError + ErrorProcessTypeExceptionBranch ErrorProcessType = 3 // executes the exception branch on error ) +// SettingOnError contains common error handling strategy. type SettingOnError struct { - DataOnErr string `json:"dataOnErr,omitempty"` - Switch bool `json:"switch,omitempty"` + // DataOnErr defines the JSON result to be returned on error. + DataOnErr string `json:"dataOnErr,omitempty"` + // Switch defines whether ANY error handling strategy is active. + // If set to false, it's equivalent to set ProcessType = ErrorProcessTypeThrow + Switch bool `json:"switch,omitempty"` + // ProcessType determines the error handling strategy for this node. ProcessType *ErrorProcessType `json:"processType,omitempty"` - RetryTimes int64 `json:"retryTimes,omitempty"` - TimeoutMs int64 `json:"timeoutMs,omitempty"` - Ext *struct { - BackupLLMParam string `json:"backupLLMParam,omitempty"` // only for LLM Node, marshaled from QALLMParam + // RetryTimes determines how many times to retry. 0 means no retry. + // If positive, any retries will be executed immediately after error. + RetryTimes int64 `json:"retryTimes,omitempty"` + // TimeoutMs sets the timeout duration in millisecond. + // If any retry happens, ALL retry attempts accumulates to the same timeout threshold. + TimeoutMs int64 `json:"timeoutMs,omitempty"` + // Ext sets any extra settings specific to NodeType + Ext *struct { + // BackupLLMParam is only for LLM Node, marshaled from SimpleLLMParam. + // If retry happens, the backup LLM will be used instead of the main LLM. + BackupLLMParam string `json:"backupLLMParam,omitempty"` } `json:"ext,omitempty"` } @@ -597,32 +658,8 @@ const ( LoopTypeInfinite LoopType = "infinite" ) -type WorkflowIdentity struct { - ID string `json:"id"` - Version string `json:"version"` -} - -func (c *Canvas) GetAllSubWorkflowIdentities() []*WorkflowIdentity { - workflowEntities := make([]*WorkflowIdentity, 0) - - var collectSubWorkFlowEntities func(nodes []*Node) - collectSubWorkFlowEntities = func(nodes []*Node) { - for _, n := range nodes { - if n.Type == BlockTypeBotSubWorkflow { - workflowEntities = append(workflowEntities, &WorkflowIdentity{ - ID: n.Data.Inputs.WorkflowID, - Version: n.Data.Inputs.WorkflowVersion, - }) - } - if len(n.Blocks) > 0 { - collectSubWorkFlowEntities(n.Blocks) - } - } - } - - collectSubWorkFlowEntities(c.Nodes) - - return workflowEntities +type InputReceiver struct { + OutputSchema string `json:"outputSchema,omitempty"` } func GenerateNodeIDForBatchMode(key string) string { @@ -632,3 +669,163 @@ func GenerateNodeIDForBatchMode(key string) string { func IsGeneratedNodeForBatchMode(key string, parentKey string) bool { return key == GenerateNodeIDForBatchMode(parentKey) } + +const defaultZhCNInitCanvasJsonSchema = `{ + "nodes": [ + { + "id": "100001", + "type": "1", + "meta": { + "position": { + "x": 0, + "y": 0 + } + }, + "data": { + "nodeMeta": { + "description": "工作流的起始节点,用于设定启动工作流需要的信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png", + "subTitle": "", + "title": "开始" + }, + "outputs": [ + { + "type": "string", + "name": "input", + "required": false + } + ], + "trigger_parameters": [ + { + "type": "string", + "name": "input", + "required": false + } + ] + } + }, + { + "id": "900001", + "type": "2", + "meta": { + "position": { + "x": 1000, + "y": 0 + } + }, + "data": { + "nodeMeta": { + "description": "工作流的最终节点,用于返回工作流运行后的结果信息", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png", + "subTitle": "", + "title": "结束" + }, + "inputs": { + "terminatePlan": "returnVariables", + "inputParameters": [ + { + "name": "output", + "input": { + "type": "string", + "value": { + "type": "ref", + "content": { + "source": "block-output", + "blockID": "", + "name": "" + } + } + } + } + ] + } + } + } + ], + "edges": [], + "versions": { + "loop": "v2" + } +}` + +const defaultEnUSInitCanvasJsonSchema = `{ + "nodes": [ + { + "id": "100001", + "type": "1", + "meta": { + "position": { + "x": 0, + "y": 0 + } + }, + "data": { + "nodeMeta": { + "description": "The starting node of the workflow, used to set the information needed to initiate the workflow.", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png", + "subTitle": "", + "title": "Start" + }, + "outputs": [ + { + "type": "string", + "name": "input", + "required": false + } + ], + "trigger_parameters": [ + { + "type": "string", + "name": "input", + "required": false + } + ] + } + }, + { + "id": "900001", + "type": "2", + "meta": { + "position": { + "x": 1000, + "y": 0 + } + }, + "data": { + "nodeMeta": { + "description": "The final node of the workflow, used to return the result information after the workflow runs.", + "icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png", + "subTitle": "", + "title": "End" + }, + "inputs": { + "terminatePlan": "returnVariables", + "inputParameters": [ + { + "name": "output", + "input": { + "type": "string", + "value": { + "type": "ref", + "content": { + "source": "block-output", + "blockID": "", + "name": "" + } + } + } + } + ] + } + } + } + ], + "edges": [], + "versions": { + "loop": "v2" + } +}` + +func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string { + return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema) +} diff --git a/backend/domain/workflow/entity/vo/node.go b/backend/domain/workflow/entity/vo/node.go index b85a3181..45eaa914 100644 --- a/backend/domain/workflow/entity/vo/node.go +++ b/backend/domain/workflow/entity/vo/node.go @@ -47,12 +47,6 @@ type FieldSource struct { Val any `json:"val,omitempty"` } -type ImplicitNodeDependency struct { - NodeID string - FieldPath compose.FieldPath - TypeInfo *TypeInfo -} - type TypeInfo struct { Type DataType `json:"type"` ElemTypeInfo *TypeInfo `json:"elem_type_info,omitempty"` diff --git a/backend/domain/workflow/entity/workflow.go b/backend/domain/workflow/entity/workflow.go index 669de893..384f7041 100644 --- a/backend/domain/workflow/entity/workflow.go +++ b/backend/domain/workflow/entity/workflow.go @@ -69,13 +69,6 @@ type IDVersionPair struct { Version string } -type Stage uint8 - -const ( - StageDraft Stage = 1 - StagePublished Stage = 2 -) - type WorkflowBasic struct { ID int64 Version string diff --git a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go index bf8311f8..c4fdc283 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go +++ b/backend/domain/workflow/internal/canvas/adaptor/canvas_test.go @@ -58,6 +58,11 @@ import ( "github.com/coze-dev/coze-studio/backend/types/consts" ) +func TestMain(m *testing.M) { + RegisterAllNodeAdaptors() + m.Run() +} + func TestIntentDetectorAndDatabase(t *testing.T) { mockey.PatchConvey("intent detector & database custom sql", t, func() { data, err := os.ReadFile("../examples/intent_detector_database_custom_sql.json") diff --git a/backend/domain/workflow/internal/canvas/adaptor/from_node.go b/backend/domain/workflow/internal/canvas/adaptor/from_node.go index f11ac764..f93b45ed 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/from_node.go +++ b/backend/domain/workflow/internal/canvas/adaptor/from_node.go @@ -26,10 +26,11 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( - *compose.WorkflowSchema, error) { + *schema.WorkflowSchema, error) { var ( n *vo.Node nodeFinder func(nodes []*vo.Node) *vo.Node @@ -62,35 +63,27 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( n = batchN } - implicitDependencies, err := extractImplicitDependency(n, c.Nodes) - if err != nil { - return nil, err - } - opts := make([]OptionFn, 0, 1) - if len(implicitDependencies) > 0 { - opts = append(opts, WithImplicitNodeDependencies(implicitDependencies)) - } - nsList, hierarchy, err := NodeToNodeSchema(ctx, n, opts...) + nsList, hierarchy, err := NodeToNodeSchema(ctx, n, c) if err != nil { return nil, err } var ( - ns *compose.NodeSchema - innerNodes map[vo.NodeKey]*compose.NodeSchema // inner nodes of the composite node if nodeKey is composite - connections []*compose.Connection + ns *schema.NodeSchema + innerNodes map[vo.NodeKey]*schema.NodeSchema // inner nodes of the composite node if nodeKey is composite + connections []*schema.Connection ) if len(nsList) == 1 { ns = nsList[0] } else { - innerNodes = make(map[vo.NodeKey]*compose.NodeSchema) + innerNodes = make(map[vo.NodeKey]*schema.NodeSchema) for i := range nsList { one := nsList[i] if _, ok := hierarchy[one.Key]; ok { innerNodes[one.Key] = one if one.Type == entity.NodeTypeContinue || one.Type == entity.NodeTypeBreak { - connections = append(connections, &compose.Connection{ + connections = append(connections, &schema.Connection{ FromNode: one.Key, ToNode: vo.NodeKey(nodeID), }) @@ -106,13 +99,13 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( } const inputFillerKey = "input_filler" - connections = append(connections, &compose.Connection{ + connections = append(connections, &schema.Connection{ FromNode: einoCompose.START, ToNode: inputFillerKey, - }, &compose.Connection{ + }, &schema.Connection{ FromNode: inputFillerKey, ToNode: ns.Key, - }, &compose.Connection{ + }, &schema.Connection{ FromNode: ns.Key, ToNode: einoCompose.END, }) @@ -209,7 +202,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( return newOutput, nil } - inputFiller := &compose.NodeSchema{ + inputFiller := &schema.NodeSchema{ Key: inputFillerKey, Type: entity.NodeTypeLambda, Lambda: einoCompose.InvokableLambda(i), @@ -227,10 +220,16 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( OutputTypes: startOutputTypes, } - trimmedSC := &compose.WorkflowSchema{ - Nodes: append([]*compose.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...), + branches, err := schema.BuildBranches(connections) + if err != nil { + return nil, err + } + + trimmedSC := &schema.WorkflowSchema{ + Nodes: append([]*schema.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...), Connections: connections, Hierarchy: hierarchy, + Branches: branches, } if enabled { diff --git a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go index fed6cb6c..93c05b36 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go +++ b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go @@ -20,32 +20,40 @@ import ( "context" "errors" "fmt" - "regexp" "runtime/debug" "strconv" "strings" - "time" einoCompose "github.com/cloudwego/eino/compose" - "github.com/spf13/cast" "github.com/coze-dev/coze-studio/backend/domain/workflow" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" + _break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break" + _continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" @@ -53,7 +61,7 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) -func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *compose.WorkflowSchema, err error) { +func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *schema.WorkflowSchema, err error) { defer func() { if panicErr := recover(); panicErr != nil { err = safego.NewPanicErr(panicErr, debug.Stack()) @@ -66,7 +74,7 @@ func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *compose.Work Edges: connectedEdges, } - sc = &compose.WorkflowSchema{} + sc = &schema.WorkflowSchema{} nodeMap := make(map[string]*vo.Node) @@ -83,8 +91,8 @@ func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *compose.Work return nil, fmt.Errorf("nodes in inner-workflow should not have edges info") } - if subNode.Type == vo.BlockTypeBotBreak || subNode.Type == vo.BlockTypeBotContinue { - sc.Connections = append(sc.Connections, &compose.Connection{ + if subNode.Type == entity.NodeTypeBreak.IDStr() || subNode.Type == entity.NodeTypeContinue.IDStr() { + sc.Connections = append(sc.Connections, &schema.Connection{ FromNode: vo.NodeKey(subNode.ID), ToNode: vo.NodeKey(subNode.Parent().ID), }) @@ -101,17 +109,7 @@ func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *compose.Work sc.GeneratedNodes = append(sc.GeneratedNodes, vo.NodeKey(node.Blocks[0].ID)) } - implicitDependencies, err := extractImplicitDependency(node, s.Nodes) - if err != nil { - return nil, err - } - - opts := make([]OptionFn, 0, 1) - if len(implicitDependencies) > 0 { - opts = append(opts, WithImplicitNodeDependencies(implicitDependencies)) - } - - nsList, hierarchy, err := NodeToNodeSchema(ctx, node, opts...) + nsList, hierarchy, err := NodeToNodeSchema(ctx, node, s) if err != nil { return nil, err } @@ -142,12 +140,19 @@ func CanvasToWorkflowSchema(ctx context.Context, s *vo.Canvas) (sc *compose.Work } sc.Connections = newConnections + branches, err := schema.BuildBranches(newConnections) + if err != nil { + return nil, err + } + + sc.Branches = branches + sc.Init() return sc, nil } -func normalizePorts(connections []*compose.Connection, nodeMap map[string]*vo.Node) (normalized []*compose.Connection, err error) { +func normalizePorts(connections []*schema.Connection, nodeMap map[string]*vo.Node) (normalized []*schema.Connection, err error) { for i := range connections { conn := connections[i] if conn.FromPort == nil { @@ -175,31 +180,24 @@ func normalizePorts(connections []*compose.Connection, nodeMap map[string]*vo.No var newPort string switch node.Type { - case vo.BlockTypeCondition: + case entity.NodeTypeSelector.IDStr(): if *conn.FromPort == "true" { - newPort = fmt.Sprintf(compose.BranchFmt, 0) + newPort = fmt.Sprintf(schema.PortBranchFormat, 0) } else if *conn.FromPort == "false" { - newPort = compose.DefaultBranch + newPort = schema.PortDefault } else if strings.HasPrefix(*conn.FromPort, "true_") { portN := strings.TrimPrefix(*conn.FromPort, "true_") n, err := strconv.Atoi(portN) if err != nil { return nil, fmt.Errorf("invalid port name: %s", *conn.FromPort) } - newPort = fmt.Sprintf(compose.BranchFmt, n) + newPort = fmt.Sprintf(schema.PortBranchFormat, n) } - case vo.BlockTypeBotIntent: - newPort = *conn.FromPort - case vo.BlockTypeQuestion: - newPort = *conn.FromPort default: - if *conn.FromPort != "default" && *conn.FromPort != "branch_error" { - return nil, fmt.Errorf("invalid port name: %s", *conn.FromPort) - } newPort = *conn.FromPort } - normalized = append(normalized, &compose.Connection{ + normalized = append(normalized, &schema.Connection{ FromNode: conn.FromNode, ToNode: conn.ToNode, FromPort: &newPort, @@ -209,96 +207,74 @@ func normalizePorts(connections []*compose.Connection, nodeMap map[string]*vo.No return normalized, nil } -var blockTypeToNodeSchema = map[vo.BlockType]func(*vo.Node, ...OptionFn) (*compose.NodeSchema, error){ - vo.BlockTypeBotStart: toEntryNodeSchema, - vo.BlockTypeBotEnd: toExitNodeSchema, - vo.BlockTypeBotLLM: toLLMNodeSchema, - vo.BlockTypeBotLoopSetVariable: toLoopSetVariableNodeSchema, - vo.BlockTypeBotBreak: toBreakNodeSchema, - vo.BlockTypeBotContinue: toContinueNodeSchema, - vo.BlockTypeCondition: toSelectorNodeSchema, - vo.BlockTypeBotText: toTextProcessorNodeSchema, - vo.BlockTypeBotIntent: toIntentDetectorSchema, - vo.BlockTypeDatabase: toDatabaseCustomSQLSchema, - vo.BlockTypeDatabaseSelect: toDatabaseQuerySchema, - vo.BlockTypeDatabaseInsert: toDatabaseInsertSchema, - vo.BlockTypeDatabaseDelete: toDatabaseDeleteSchema, - vo.BlockTypeDatabaseUpdate: toDatabaseUpdateSchema, - vo.BlockTypeBotHttp: toHttpRequesterSchema, - vo.BlockTypeBotDatasetWrite: toKnowledgeIndexerSchema, - vo.BlockTypeBotDatasetDelete: toKnowledgeDeleterSchema, - vo.BlockTypeBotDataset: toKnowledgeRetrieverSchema, - vo.BlockTypeBotAssignVariable: toVariableAssignerSchema, - vo.BlockTypeBotCode: toCodeRunnerSchema, - vo.BlockTypeBotAPI: toPluginSchema, - vo.BlockTypeBotVariableMerge: toVariableAggregatorSchema, - vo.BlockTypeBotInput: toInputReceiverSchema, - vo.BlockTypeBotMessage: toOutputEmitterNodeSchema, - vo.BlockTypeQuestion: toQASchema, - vo.BlockTypeJsonSerialization: toJSONSerializeSchema, - vo.BlockTypeJsonDeserialization: toJSONDeserializeSchema, +var blockTypeToSkip = map[entity.NodeType]bool{ + entity.NodeTypeComment: true, } -var blockTypeToSkip = map[vo.BlockType]bool{ - vo.BlockTypeBotComment: true, -} +func NodeToNodeSchema(ctx context.Context, n *vo.Node, c *vo.Canvas) ([]*schema.NodeSchema, map[vo.NodeKey]vo.NodeKey, error) { + et := entity.IDStrToNodeType(n.Type) -type option struct { - implicitNodeDependencies []*vo.ImplicitNodeDependency -} -type OptionFn func(*option) - -func WithImplicitNodeDependencies(implicitNodeDependencies []*vo.ImplicitNodeDependency) OptionFn { - return func(o *option) { - o.implicitNodeDependencies = implicitNodeDependencies - } -} - -func NodeToNodeSchema(ctx context.Context, n *vo.Node, opts ...OptionFn) ([]*compose.NodeSchema, map[vo.NodeKey]vo.NodeKey, error) { - cfg, ok := blockTypeToNodeSchema[n.Type] - if ok { - ns, err := cfg(n, opts...) - if err != nil { - return nil, nil, err - } - - if ns.ExceptionConfigs, err = toMetaConfig(n, ns.Type); err != nil { - return nil, nil, err - } - - return []*compose.NodeSchema{ns}, nil, nil - } - - _, ok = blockTypeToSkip[n.Type] - if ok { - return nil, nil, nil - } - - if n.Type == vo.BlockTypeBotSubWorkflow { + if et == entity.NodeTypeSubWorkflow { ns, err := toSubWorkflowNodeSchema(ctx, n) if err != nil { return nil, nil, err } - if ns.ExceptionConfigs, err = toMetaConfig(n, ns.Type); err != nil { + if ns.ExceptionConfigs, err = toExceptionConfig(n, ns.Type); err != nil { return nil, nil, err } - return []*compose.NodeSchema{ns}, nil, nil - } else if n.Type == vo.BlockTypeBotBatch { - return toBatchNodeSchema(ctx, n, opts...) - } else if n.Type == vo.BlockTypeBotLoop { - return toLoopNodeSchema(ctx, n, opts...) + return []*schema.NodeSchema{ns}, nil, nil + } + + na, ok := nodes.GetNodeAdaptor(et) + if ok { + ns, err := na.Adapt(ctx, n, nodes.WithCanvas(c)) + if err != nil { + return nil, nil, err + } + + if ns.ExceptionConfigs, err = toExceptionConfig(n, ns.Type); err != nil { + return nil, nil, err + } + + if len(n.Blocks) > 0 { + var ( + allNS []*schema.NodeSchema + hierarchy = make(map[vo.NodeKey]vo.NodeKey) + ) + + for _, childN := range n.Blocks { + childN.SetParent(n) + childNS, _, err := NodeToNodeSchema(ctx, childN, c) + if err != nil { + return nil, nil, err + } + + allNS = append(allNS, childNS...) + hierarchy[vo.NodeKey(childN.ID)] = vo.NodeKey(n.ID) + } + + allNS = append(allNS, ns) + return allNS, hierarchy, nil + } + + return []*schema.NodeSchema{ns}, nil, nil + } + + _, ok = blockTypeToSkip[et] + if ok { + return nil, nil, nil } return nil, nil, fmt.Errorf("unsupported block type: %v", n.Type) } -func EdgeToConnection(e *vo.Edge) *compose.Connection { +func EdgeToConnection(e *vo.Edge) *schema.Connection { toNode := vo.NodeKey(e.TargetNodeID) if len(e.SourcePortID) > 0 && (e.TargetPortID == "loop-function-inline-input" || e.TargetPortID == "batch-function-inline-input") { toNode = einoCompose.END } - conn := &compose.Connection{ + conn := &schema.Connection{ FromNode: vo.NodeKey(e.SourceNodeID), ToNode: toNode, } @@ -310,7 +286,7 @@ func EdgeToConnection(e *vo.Edge) *compose.Connection { return conn } -func toMetaConfig(n *vo.Node, nType entity.NodeType) (*compose.ExceptionConfig, error) { +func toExceptionConfig(n *vo.Node, nType entity.NodeType) (*schema.ExceptionConfig, error) { nodeMeta := entity.NodeMetaByNodeType(nType) var settingOnErr *vo.SettingOnError @@ -324,687 +300,33 @@ func toMetaConfig(n *vo.Node, nType entity.NodeType) (*compose.ExceptionConfig, return nil, nil } - metaConf := &compose.ExceptionConfig{ + metaConf := &schema.ExceptionConfig{ TimeoutMS: nodeMeta.DefaultTimeoutMS, } if settingOnErr != nil { - metaConf = &compose.ExceptionConfig{ + metaConf = &schema.ExceptionConfig{ TimeoutMS: settingOnErr.TimeoutMs, MaxRetry: settingOnErr.RetryTimes, DataOnErr: settingOnErr.DataOnErr, ProcessType: settingOnErr.ProcessType, } - if metaConf.ProcessType != nil && *metaConf.ProcessType == vo.ErrorProcessTypeDefault { + if metaConf.ProcessType != nil && *metaConf.ProcessType == vo.ErrorProcessTypeReturnDefaultData { if len(metaConf.DataOnErr) == 0 { return nil, errors.New("error process type is returning default value, but dataOnError is not specified") } } if metaConf.ProcessType == nil && len(metaConf.DataOnErr) > 0 && settingOnErr.Switch { - metaConf.ProcessType = ptr.Of(vo.ErrorProcessTypeDefault) + metaConf.ProcessType = ptr.Of(vo.ErrorProcessTypeReturnDefaultData) } } return metaConf, nil } -func toEntryNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - if n.Parent() != nil { - return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID) - } - - if n.ID != entity.EntryNodeKey { - return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID) - } - - ns := &compose.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Name: n.Data.Meta.Title, - } - - defaultValues := make(map[string]any, len(n.Data.Outputs)) - for _, v := range n.Data.Outputs { - variable, err := vo.ParseVariable(v) - if err != nil { - return nil, err - } - if variable.DefaultValue != nil { - defaultValues[variable.Name] = variable.DefaultValue - } - - } - - ns.SetConfigKV("DefaultValues", defaultValues) - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toExitNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - if n.Parent() != nil { - return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID) - } - - if n.ID != entity.ExitNodeKey { - return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID) - } - - ns := &compose.NodeSchema{ - Key: entity.ExitNodeKey, - Type: entity.NodeTypeExit, - Name: n.Data.Meta.Title, - } - - content := n.Data.Inputs.Content - streamingOutput := n.Data.Inputs.StreamingOutput - - if streamingOutput { - ns.SetConfigKV("Mode", nodes.Streaming) - ns.StreamConfigs = &compose.StreamConfig{ - RequireStreamingInput: true, - } - } else { - ns.SetConfigKV("Mode", nodes.NonStreaming) - ns.StreamConfigs = &compose.StreamConfig{ - RequireStreamingInput: false, - } - } - - if content != nil { - if content.Type != vo.VariableTypeString { - return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type) - } - - if content.Value.Type != vo.BlockInputValueTypeLiteral { - return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type) - } - - ns.SetConfigKV("Template", content.Value.Content.(string)) - } - - ns.SetConfigKV("TerminalPlan", *n.Data.Inputs.TerminatePlan) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toOutputEmitterNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeOutputEmitter, - Name: n.Data.Meta.Title, - } - - content := n.Data.Inputs.Content - streamingOutput := n.Data.Inputs.StreamingOutput - - if streamingOutput { - ns.SetConfigKV("Mode", nodes.Streaming) - ns.StreamConfigs = &compose.StreamConfig{ - RequireStreamingInput: true, - } - } else { - ns.SetConfigKV("Mode", nodes.NonStreaming) - ns.StreamConfigs = &compose.StreamConfig{ - RequireStreamingInput: false, - } - } - - if content != nil { - if content.Type != vo.VariableTypeString { - return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type) - } - - if content.Value.Type != vo.BlockInputValueTypeLiteral { - return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type) - } - - if content.Value.Content == nil { - ns.SetConfigKV("Template", "") - } else { - template, ok := content.Value.Content.(string) - if !ok { - return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content) - } - - ns.SetConfigKV("Template", template) - } - } - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toLLMNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeLLM, - Name: n.Data.Meta.Title, - } - - param := n.Data.Inputs.LLMParam - if param == nil { - return nil, fmt.Errorf("llm node's llmParam is nil") - } - - bs, _ := sonic.Marshal(param) - llmParam := make(vo.LLMParam, 0) - if err := sonic.Unmarshal(bs, &llmParam); err != nil { - return nil, err - } - convertedLLMParam, err := LLMParamsToLLMParam(llmParam) - if err != nil { - return nil, err - } - - ns.SetConfigKV("LLMParams", convertedLLMParam) - ns.SetConfigKV("SystemPrompt", convertedLLMParam.SystemPrompt) - ns.SetConfigKV("UserPrompt", convertedLLMParam.Prompt) - - var resFormat llm.Format - switch convertedLLMParam.ResponseFormat { - case model.ResponseFormatText: - resFormat = llm.FormatText - case model.ResponseFormatMarkdown: - resFormat = llm.FormatMarkdown - case model.ResponseFormatJSON: - resFormat = llm.FormatJSON - default: - return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat) - } - - ns.SetConfigKV("OutputFormat", resFormat) - - if err = SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - if resFormat == llm.FormatJSON { - if len(ns.OutputTypes) == 1 { - for _, v := range ns.OutputTypes { - if v.Type == vo.DataTypeString { - resFormat = llm.FormatText - break - } - } - } else if len(ns.OutputTypes) == 2 { - if _, ok := ns.OutputTypes[llm.ReasoningOutputKey]; ok { - for k, v := range ns.OutputTypes { - if k != llm.ReasoningOutputKey && v.Type == vo.DataTypeString { - resFormat = llm.FormatText - break - } - } - } - } - } - - if resFormat == llm.FormatJSON { - ns.StreamConfigs = &compose.StreamConfig{ - CanGeneratesStream: false, - } - } else { - ns.StreamConfigs = &compose.StreamConfig{ - CanGeneratesStream: true, - } - } - - if n.Data.Inputs.FCParam != nil { - ns.SetConfigKV("FCParam", n.Data.Inputs.FCParam) - } - - if se := n.Data.Inputs.SettingOnError; se != nil { - if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 { - var backupLLMParam vo.QALLMParam - if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil { - return nil, err - } - - backupModel, err := qaLLMParamsToLLMParams(backupLLMParam) - if err != nil { - return nil, err - } - ns.SetConfigKV("BackupLLMParams", backupModel) - } - } - - return ns, nil -} - -func toLoopSetVariableNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - if n.Parent() == nil { - return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID) - } - - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeVariableAssignerWithinLoop, - Name: n.Data.Meta.Title, - } - - var pairs []*variableassigner.Pair - for i, param := range n.Data.Inputs.InputParameters { - if param.Left == nil || param.Right == nil { - return nil, fmt.Errorf("loop set variable node's param left or right is nil") - } - - leftSources, err := CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent()) - if err != nil { - return nil, err - } - - if len(leftSources) != 1 { - return nil, fmt.Errorf("loop set variable node's param left is not a single source") - } - - if leftSources[0].Source.Ref == nil { - return nil, fmt.Errorf("loop set variable node's param left's ref is nil") - } - - if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate { - return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate") - } - - rightSources, err := CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent()) - if err != nil { - return nil, err - } - - ns.AddInputSource(rightSources...) - - if len(rightSources) != 1 { - return nil, fmt.Errorf("loop set variable node's param right is not a single source") - } - - pair := &variableassigner.Pair{ - Left: *leftSources[0].Source.Ref, - Right: rightSources[0].Path, - } - - pairs = append(pairs, pair) - } - - ns.Configs = pairs - - return ns, nil -} - -func toBreakNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - return &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeBreak, - Name: n.Data.Meta.Title, - }, nil -} - -func toContinueNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - return &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeContinue, - Name: n.Data.Meta.Title, - }, nil -} - -func toSelectorNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeSelector, - Name: n.Data.Meta.Title, - } - - clauses := make([]*selector.OneClauseSchema, 0) - for i, branchCond := range n.Data.Inputs.Branches { - inputType := &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{}, - } - - if len(branchCond.Condition.Conditions) == 1 { // single condition - cond := branchCond.Condition.Conditions[0] - - left := cond.Left - if left == nil { - return nil, fmt.Errorf("operator left is nil") - } - - leftType, err := CanvasBlockInputToTypeInfo(left.Input) - if err != nil { - return nil, err - } - - leftSources, err := CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), selector.LeftKey}, n.Parent()) - if err != nil { - return nil, err - } - - inputType.Properties[selector.LeftKey] = leftType - - ns.AddInputSource(leftSources...) - - op, err := ToSelectorOperator(cond.Operator, leftType) - if err != nil { - return nil, err - } - - if cond.Right != nil { - rightType, err := CanvasBlockInputToTypeInfo(cond.Right.Input) - if err != nil { - return nil, err - } - - rightSources, err := CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), selector.RightKey}, n.Parent()) - if err != nil { - return nil, err - } - - inputType.Properties[selector.RightKey] = rightType - ns.AddInputSource(rightSources...) - } - - ns.SetInputType(fmt.Sprintf("%d", i), inputType) - - clauses = append(clauses, &selector.OneClauseSchema{ - Single: &op, - }) - - continue - } - - var relation selector.ClauseRelation - logic := branchCond.Condition.Logic - if logic == vo.OR { - relation = selector.ClauseRelationOR - } else if logic == vo.AND { - relation = selector.ClauseRelationAND - } - - var ops []*selector.Operator - for j, cond := range branchCond.Condition.Conditions { - left := cond.Left - if left == nil { - return nil, fmt.Errorf("operator left is nil") - } - - leftType, err := CanvasBlockInputToTypeInfo(left.Input) - if err != nil { - return nil, err - } - - leftSources, err := CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), selector.LeftKey}, n.Parent()) - if err != nil { - return nil, err - } - - inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{ - selector.LeftKey: leftType, - }, - } - - ns.AddInputSource(leftSources...) - - op, err := ToSelectorOperator(cond.Operator, leftType) - if err != nil { - return nil, err - } - ops = append(ops, &op) - - if cond.Right != nil { - rightType, err := CanvasBlockInputToTypeInfo(cond.Right.Input) - if err != nil { - return nil, err - } - - rightSources, err := CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), selector.RightKey}, n.Parent()) - if err != nil { - return nil, err - } - - inputType.Properties[fmt.Sprintf("%d", j)].Properties[selector.RightKey] = rightType - ns.AddInputSource(rightSources...) - } - } - - ns.SetInputType(fmt.Sprintf("%d", i), inputType) - - clauses = append(clauses, &selector.OneClauseSchema{ - Multi: &selector.MultiClauseSchema{ - Clauses: ops, - Relation: relation, - }, - }) - } - - ns.Configs = map[string]any{"Clauses": clauses} - return ns, nil -} - -func toTextProcessorNodeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeTextProcessor, - Name: n.Data.Meta.Title, - } - - configs := make(map[string]any) - - if n.Data.Inputs.Method == vo.Concat { - configs["Type"] = textprocessor.ConcatText - params := n.Data.Inputs.ConcatParams - for _, param := range params { - if param.Name == "concatResult" { - configs["Tpl"] = param.Input.Value.Content.(string) - } else if param.Name == "arrayItemConcatChar" { - configs["ConcatChar"] = param.Input.Value.Content.(string) - } - } - } else if n.Data.Inputs.Method == vo.Split { - configs["Type"] = textprocessor.SplitText - params := n.Data.Inputs.SplitParams - separators := make([]string, 0, len(params)) - for _, param := range params { - if param.Name == "delimiters" { - delimiters := param.Input.Value.Content.([]any) - for _, d := range delimiters { - separators = append(separators, d.(string)) - } - } - } - configs["Separators"] = separators - - } else { - return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method) - } - - ns.Configs = configs - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toLoopNodeSchema(ctx context.Context, n *vo.Node, opts ...OptionFn) ([]*compose.NodeSchema, map[vo.NodeKey]vo.NodeKey, error) { - if n.Parent() != nil { - return nil, nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID) - } - - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeLoop, - Name: n.Data.Meta.Title, - } - - var ( - allNS []*compose.NodeSchema - hierarchy = make(map[vo.NodeKey]vo.NodeKey) - ) - - for _, childN := range n.Blocks { - childN.SetParent(n) - childNS, _, err := NodeToNodeSchema(ctx, childN, opts...) - if err != nil { - return nil, nil, err - } - - allNS = append(allNS, childNS...) - hierarchy[vo.NodeKey(childN.ID)] = vo.NodeKey(n.ID) - } - - loopType, err := ToLoopType(n.Data.Inputs.LoopType) - if err != nil { - return nil, nil, err - } - ns.SetConfigKV("LoopType", loopType) - - intermediateVars := make(map[string]*vo.TypeInfo) - for _, param := range n.Data.Inputs.VariableParameters { - tInfo, err := CanvasBlockInputToTypeInfo(param.Input) - if err != nil { - return nil, nil, err - } - intermediateVars[param.Name] = tInfo - - ns.SetInputType(param.Name, tInfo) - sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{param.Name}, nil) - if err != nil { - return nil, nil, err - } - ns.AddInputSource(sources...) - } - ns.SetConfigKV("IntermediateVars", intermediateVars) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, nil, err - } - - if err := SetOutputsForNodeSchema(n, ns); err != nil { - return nil, nil, err - } - - for _, fieldInfo := range ns.OutputSources { - if fieldInfo.Source.Ref != nil { - if len(fieldInfo.Source.Ref.FromPath) == 1 { - if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok { - fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate) - } - } - } - } - - loopCount := n.Data.Inputs.LoopCount - if loopCount != nil { - typeInfo, err := CanvasBlockInputToTypeInfo(loopCount) - if err != nil { - return nil, nil, err - } - ns.SetInputType(loop.Count, typeInfo) - - sources, err := CanvasBlockInputToFieldInfo(loopCount, einoCompose.FieldPath{loop.Count}, nil) - if err != nil { - return nil, nil, err - } - ns.AddInputSource(sources...) - } - - if ns.ExceptionConfigs, err = toMetaConfig(n, entity.NodeTypeLoop); err != nil { - return nil, nil, err - } - - allNS = append(allNS, ns) - - return allNS, hierarchy, nil -} - -func toBatchNodeSchema(ctx context.Context, n *vo.Node, opts ...OptionFn) ([]*compose.NodeSchema, map[vo.NodeKey]vo.NodeKey, error) { - if n.Parent() != nil { - return nil, nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID) - } - - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeBatch, - Name: n.Data.Meta.Title, - } - - var ( - allNS []*compose.NodeSchema - hierarchy = make(map[vo.NodeKey]vo.NodeKey) - ) - - for _, childN := range n.Blocks { - childN.SetParent(n) - childNS, _, err := NodeToNodeSchema(ctx, childN, opts...) - if err != nil { - return nil, nil, err - } - - allNS = append(allNS, childNS...) - hierarchy[vo.NodeKey(childN.ID)] = vo.NodeKey(n.ID) - } - - batchSizeField, err := CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize, einoCompose.FieldPath{batch.MaxBatchSizeKey}, nil) - if err != nil { - return nil, nil, err - } - ns.AddInputSource(batchSizeField...) - concurrentSizeField, err := CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize, einoCompose.FieldPath{batch.ConcurrentSizeKey}, nil) - if err != nil { - return nil, nil, err - } - ns.AddInputSource(concurrentSizeField...) - - batchSizeType, err := CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize) - if err != nil { - return nil, nil, err - } - ns.SetInputType(batch.MaxBatchSizeKey, batchSizeType) - concurrentSizeType, err := CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize) - if err != nil { - return nil, nil, err - } - ns.SetInputType(batch.ConcurrentSizeKey, concurrentSizeType) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, nil, err - } - - if err := SetOutputsForNodeSchema(n, ns); err != nil { - return nil, nil, err - } - - if ns.ExceptionConfigs, err = toMetaConfig(n, entity.NodeTypeBatch); err != nil { - return nil, nil, err - } - - allNS = append(allNS, ns) - - return allNS, hierarchy, nil -} - -func toSubWorkflowNodeSchema(ctx context.Context, n *vo.Node) (*compose.NodeSchema, error) { +func toSubWorkflowNodeSchema(ctx context.Context, n *vo.Node) (*schema.NodeSchema, error) { idStr := n.Data.Inputs.WorkflowID id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { @@ -1032,12 +354,15 @@ func toSubWorkflowNodeSchema(ctx context.Context, n *vo.Node) (*compose.NodeSche return nil, err } - ns := &compose.NodeSchema{ + cfg := &subworkflow.Config{} + + ns := &schema.NodeSchema{ Key: vo.NodeKey(n.ID), Type: entity.NodeTypeSubWorkflow, Name: n.Data.Meta.Title, SubWorkflowBasic: subWF.GetBasic(), SubWorkflowSchema: subWorkflowSC, + Configs: cfg, } workflowIDStr := n.Data.Inputs.WorkflowID @@ -1048,933 +373,18 @@ func toSubWorkflowNodeSchema(ctx context.Context, n *vo.Node) (*compose.NodeSche if err != nil { return nil, fmt.Errorf("sub workflow node's workflowID is not a number: %s", workflowIDStr) } - ns.SetConfigKV("WorkflowID", workflowID) - ns.SetConfigKV("WorkflowVersion", n.Data.Inputs.WorkflowVersion) + cfg.WorkflowID = workflowID + cfg.WorkflowVersion = n.Data.Inputs.WorkflowVersion - if err := SetInputsForNodeSchema(n, ns); err != nil { + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { return nil, err } - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { return nil, err } return ns, nil } -func toIntentDetectorSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeIntentDetector, - Name: n.Data.Meta.Title, - } - - param := n.Data.Inputs.LLMParam - if param == nil { - return nil, fmt.Errorf("intent detector node's llmParam is nil") - } - - llmParam, ok := param.(vo.IntentDetectorLLMParam) - if !ok { - return nil, fmt.Errorf("llm node's llmParam must be LLMParam, got %v", llmParam) - } - - paramBytes, err := sonic.Marshal(param) - if err != nil { - return nil, err - } - var intentDetectorConfig = &vo.IntentDetectorLLMConfig{} - - err = sonic.Unmarshal(paramBytes, &intentDetectorConfig) - if err != nil { - return nil, err - } - - modelLLMParams := &model.LLMParams{} - modelLLMParams.ModelType = int64(intentDetectorConfig.ModelType) - modelLLMParams.ModelName = intentDetectorConfig.ModelName - modelLLMParams.TopP = intentDetectorConfig.TopP - modelLLMParams.Temperature = intentDetectorConfig.Temperature - modelLLMParams.MaxTokens = intentDetectorConfig.MaxTokens - modelLLMParams.ResponseFormat = model.ResponseFormat(intentDetectorConfig.ResponseFormat) - modelLLMParams.SystemPrompt = intentDetectorConfig.SystemPrompt.Value.Content.(string) - - ns.SetConfigKV("LLMParams", modelLLMParams) - ns.SetConfigKV("SystemPrompt", modelLLMParams.SystemPrompt) - - var intents = make([]string, 0, len(n.Data.Inputs.Intents)) - for _, it := range n.Data.Inputs.Intents { - intents = append(intents, it.Name) - } - ns.SetConfigKV("Intents", intents) - - if n.Data.Inputs.Mode == "top_speed" { - ns.SetConfigKV("IsFastMode", true) - } - - if err = SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toDatabaseCustomSQLSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeDatabaseCustomSQL, - Name: n.Data.Meta.Title, - } - - dsList := n.Data.Inputs.DatabaseInfoList - if len(dsList) == 0 { - return nil, fmt.Errorf("database info is requird") - } - databaseInfo := dsList[0] - - dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) - if err != nil { - return nil, err - } - ns.SetConfigKV("DatabaseInfoID", dsID) - - sql := n.Data.Inputs.SQL - if len(sql) == 0 { - return nil, fmt.Errorf("sql is requird") - } - - ns.SetConfigKV("SQLTemplate", sql) - - if err = SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toDatabaseQuerySchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeDatabaseQuery, - Name: n.Data.Meta.Title, - } - - dsList := n.Data.Inputs.DatabaseInfoList - if len(dsList) == 0 { - return nil, fmt.Errorf("database info is requird") - } - databaseInfo := dsList[0] - - dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) - if err != nil { - return nil, err - } - ns.SetConfigKV("DatabaseInfoID", dsID) - - selectParam := n.Data.Inputs.SelectParam - ns.SetConfigKV("Limit", selectParam.Limit) - - queryFields := make([]string, 0) - for _, v := range selectParam.FieldList { - queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10)) - } - ns.SetConfigKV("QueryFields", queryFields) - - orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList)) - for _, o := range selectParam.OrderByList { - orderClauses = append(orderClauses, &database.OrderClause{ - FieldID: strconv.FormatInt(o.FieldID, 10), - IsAsc: o.IsAsc, - }) - } - ns.SetConfigKV("OrderClauses", orderClauses) - - clauseGroup := &database.ClauseGroup{} - - if selectParam.Condition != nil { - clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition) - if err != nil { - return nil, err - } - } - - ns.SetConfigKV("ClauseGroup", clauseGroup) - - if err = SetDatabaseInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toDatabaseInsertSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeDatabaseInsert, - Name: n.Data.Meta.Title, - } - - dsList := n.Data.Inputs.DatabaseInfoList - if len(dsList) == 0 { - return nil, fmt.Errorf("database info is requird") - } - databaseInfo := dsList[0] - - dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) - if err != nil { - return nil, err - } - ns.SetConfigKV("DatabaseInfoID", dsID) - - if err = SetDatabaseInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toDatabaseDeleteSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeDatabaseDelete, - Name: n.Data.Meta.Title, - } - - dsList := n.Data.Inputs.DatabaseInfoList - if len(dsList) == 0 { - return nil, fmt.Errorf("database info is requird") - } - databaseInfo := dsList[0] - - dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) - if err != nil { - return nil, err - } - ns.SetConfigKV("DatabaseInfoID", dsID) - - deleteParam := n.Data.Inputs.DeleteParam - - clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition) - if err != nil { - return nil, err - } - ns.SetConfigKV("ClauseGroup", clauseGroup) - - if err = SetDatabaseInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toDatabaseUpdateSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeDatabaseUpdate, - Name: n.Data.Meta.Title, - } - - dsList := n.Data.Inputs.DatabaseInfoList - if len(dsList) == 0 { - return nil, fmt.Errorf("database info is requird") - } - databaseInfo := dsList[0] - - dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) - if err != nil { - return nil, err - } - ns.SetConfigKV("DatabaseInfoID", dsID) - - updateParam := n.Data.Inputs.UpdateParam - if updateParam == nil { - return nil, fmt.Errorf("update param is requird") - } - clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition) - if err != nil { - return nil, err - } - ns.SetConfigKV("ClauseGroup", clauseGroup) - if err = SetDatabaseInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toHttpRequesterSchema(n *vo.Node, opts ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeHTTPRequester, - Name: n.Data.Meta.Title, - } - option := &option{} - for _, opt := range opts { - opt(option) - } - - implicitNodeDependencies := option.implicitNodeDependencies - - inputs := n.Data.Inputs - - md5FieldMapping := &httprequester.MD5FieldMapping{} - - method := inputs.APIInfo.Method - ns.SetConfigKV("Method", method) - url := inputs.APIInfo.URL - ns.SetConfigKV("URLConfig", httprequester.URLConfig{ - Tpl: strings.TrimSpace(url), - }) - - urlVars := extractBracesContent(url) - md5FieldMapping.SetURLFields(urlVars...) - - md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string { - return a.Name - })...) - - md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string { - return a.Name - })...) - - if inputs.Auth != nil && inputs.Auth.AuthOpen { - auth := &httprequester.AuthenticationConfig{} - ty, err := ConvertAuthType(inputs.Auth.AuthType) - if err != nil { - return nil, err - } - auth.Type = ty - location, err := ConvertLocation(inputs.Auth.AuthData.CustomData.AddTo) - if err != nil { - return nil, err - } - auth.Location = location - - ns.SetConfigKV("AuthConfig", auth) - - } - - bodyConfig := httprequester.BodyConfig{} - - bodyConfig.BodyType = httprequester.BodyType(inputs.Body.BodyType) - switch httprequester.BodyType(inputs.Body.BodyType) { - case httprequester.BodyTypeJSON: - jsonTpl := inputs.Body.BodyData.Json - bodyConfig.TextJsonConfig = &httprequester.TextJsonConfig{ - Tpl: jsonTpl, - } - jsonVars := extractBracesContent(jsonTpl) - md5FieldMapping.SetBodyFields(jsonVars...) - case httprequester.BodyTypeFormData: - bodyConfig.FormDataConfig = &httprequester.FormDataConfig{ - FileTypeMapping: map[string]bool{}, - } - formDataVars := make([]string, 0) - for i := range inputs.Body.BodyData.FormData.Data { - p := inputs.Body.BodyData.FormData.Data[i] - formDataVars = append(formDataVars, p.Name) - if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime { - bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true - } - - } - md5FieldMapping.SetBodyFields(formDataVars...) - case httprequester.BodyTypeRawText: - TextTpl := inputs.Body.BodyData.RawText - bodyConfig.TextPlainConfig = &httprequester.TextPlainConfig{ - Tpl: TextTpl, - } - textPlainVars := extractBracesContent(TextTpl) - md5FieldMapping.SetBodyFields(textPlainVars...) - case httprequester.BodyTypeFormURLEncoded: - formURLEncodedVars := make([]string, 0) - for _, p := range inputs.Body.BodyData.FormURLEncoded { - formURLEncodedVars = append(formURLEncodedVars, p.Name) - } - md5FieldMapping.SetBodyFields(formURLEncodedVars...) - } - ns.SetConfigKV("BodyConfig", bodyConfig) - ns.SetConfigKV("MD5FieldMapping", *md5FieldMapping) - - if inputs.Setting != nil { - ns.SetConfigKV("Timeout", time.Duration(inputs.Setting.Timeout)*time.Second) - ns.SetConfigKV("RetryTimes", uint64(inputs.Setting.RetryTimes)) - } - - if err := SetHttpRequesterInputsForNodeSchema(n, ns, implicitNodeDependencies); err != nil { - return nil, err - } - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - return ns, nil -} - -func toKnowledgeIndexerSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeKnowledgeIndexer, - Name: n.Data.Meta.Title, - } - - inputs := n.Data.Inputs - datasetListInfoParam := inputs.DatasetParam[0] - datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) - if len(datasetIDs) == 0 { - return nil, fmt.Errorf("dataset ids is required") - } - knowledgeID, err := cast.ToInt64E(datasetIDs[0]) - if err != nil { - return nil, err - } - - ns.SetConfigKV("KnowledgeID", knowledgeID) - ps := inputs.StrategyParam.ParsingStrategy - parseMode, err := ConvertParsingType(ps.ParsingType) - if err != nil { - return nil, err - } - parsingStrategy := &knowledge.ParsingStrategy{ - ParseMode: parseMode, - ImageOCR: ps.ImageOcr, - ExtractImage: ps.ImageExtraction, - ExtractTable: ps.TableExtraction, - } - - ns.SetConfigKV("ParsingStrategy", parsingStrategy) - cs := inputs.StrategyParam.ChunkStrategy - chunkType, err := ConvertChunkType(cs.ChunkType) - if err != nil { - return nil, err - } - chunkingStrategy := &knowledge.ChunkingStrategy{ - ChunkType: chunkType, - Separator: cs.Separator, - ChunkSize: cs.MaxToken, - Overlap: int64(cs.Overlap * float64(cs.MaxToken)), - } - ns.SetConfigKV("ChunkingStrategy", chunkingStrategy) - - if err = SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toKnowledgeRetrieverSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeKnowledgeRetriever, - Name: n.Data.Meta.Title, - } - - inputs := n.Data.Inputs - datasetListInfoParam := inputs.DatasetParam[0] - datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) - knowledgeIDs := make([]int64, 0, len(datasetIDs)) - for _, id := range datasetIDs { - k, err := cast.ToInt64E(id) - if err != nil { - return nil, err - } - knowledgeIDs = append(knowledgeIDs, k) - } - ns.SetConfigKV("KnowledgeIDs", knowledgeIDs) - - retrievalStrategy := &knowledge.RetrievalStrategy{} - - var getDesignatedParamContent = func(name string) (any, bool) { - for _, param := range inputs.DatasetParam { - if param.Name == name { - return param.Input.Value.Content, true - } - } - return nil, false - - } - - if content, ok := getDesignatedParamContent("topK"); ok { - topK, err := cast.ToInt64E(content) - if err != nil { - return nil, err - } - retrievalStrategy.TopK = &topK - } - - if content, ok := getDesignatedParamContent("useRerank"); ok { - useRerank, err := cast.ToBoolE(content) - if err != nil { - return nil, err - } - retrievalStrategy.EnableRerank = useRerank - } - - if content, ok := getDesignatedParamContent("useRewrite"); ok { - useRewrite, err := cast.ToBoolE(content) - if err != nil { - return nil, err - } - retrievalStrategy.EnableQueryRewrite = useRewrite - } - - if content, ok := getDesignatedParamContent("isPersonalOnly"); ok { - isPersonalOnly, err := cast.ToBoolE(content) - if err != nil { - return nil, err - } - retrievalStrategy.IsPersonalOnly = isPersonalOnly - } - - if content, ok := getDesignatedParamContent("useNl2sql"); ok { - useNl2sql, err := cast.ToBoolE(content) - if err != nil { - return nil, err - } - retrievalStrategy.EnableNL2SQL = useNl2sql - } - - if content, ok := getDesignatedParamContent("minScore"); ok { - minScore, err := cast.ToFloat64E(content) - if err != nil { - return nil, err - } - retrievalStrategy.MinScore = &minScore - } - - if content, ok := getDesignatedParamContent("strategy"); ok { - strategy, err := cast.ToInt64E(content) - if err != nil { - return nil, err - } - searchType, err := ConvertRetrievalSearchType(strategy) - if err != nil { - return nil, err - } - retrievalStrategy.SearchType = searchType - } - - ns.SetConfigKV("RetrievalStrategy", retrievalStrategy) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toKnowledgeDeleterSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeKnowledgeDeleter, - Name: n.Data.Meta.Title, - } - - inputs := n.Data.Inputs - datasetListInfoParam := inputs.DatasetParam[0] - datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) - if len(datasetIDs) == 0 { - return nil, fmt.Errorf("dataset ids is required") - } - knowledgeID, err := cast.ToInt64E(datasetIDs[0]) - if err != nil { - return nil, err - } - - ns.SetConfigKV("KnowledgeID", knowledgeID) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toVariableAssignerSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeVariableAssigner, - Name: n.Data.Meta.Title, - } - - var pairs = make([]*variableassigner.Pair, 0, len(n.Data.Inputs.InputParameters)) - for i, param := range n.Data.Inputs.InputParameters { - if param.Left == nil || param.Input == nil { - return nil, fmt.Errorf("variable assigner node's param left or input is nil") - } - - leftSources, err := CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent()) - if err != nil { - return nil, err - } - - if leftSources[0].Source.Ref == nil { - return nil, fmt.Errorf("variable assigner node's param left source ref is nil") - } - - if leftSources[0].Source.Ref.VariableType == nil { - return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil") - } - - if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem { - return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem") - } - - inputSource, err := CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent()) - if err != nil { - return nil, err - } - ns.AddInputSource(inputSource...) - pair := &variableassigner.Pair{ - Left: *leftSources[0].Source.Ref, - Right: inputSource[0].Path, - } - pairs = append(pairs, pair) - } - ns.Configs = pairs - - return ns, nil -} - -func toCodeRunnerSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeCodeRunner, - Name: n.Data.Meta.Title, - } - inputs := n.Data.Inputs - - code := inputs.Code - ns.SetConfigKV("Code", code) - - language, err := ConvertCodeLanguage(inputs.Language) - if err != nil { - return nil, err - } - ns.SetConfigKV("Language", language) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toPluginSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypePlugin, - Name: n.Data.Meta.Title, - } - inputs := n.Data.Inputs - - apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { - return e.Name, e - }) - - ps, ok := apiParams["pluginID"] - if !ok { - return nil, fmt.Errorf("plugin id param is not found") - } - - pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64) - - ns.SetConfigKV("PluginID", pID) - - ps, ok = apiParams["apiID"] - if !ok { - return nil, fmt.Errorf("plugin id param is not found") - } - - tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64) - if err != nil { - return nil, err - } - - ns.SetConfigKV("ToolID", tID) - - ps, ok = apiParams["pluginVersion"] - if !ok { - return nil, fmt.Errorf("plugin version param is not found") - } - version := ps.Input.Value.Content.(string) - ns.SetConfigKV("PluginVersion", version) - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toVariableAggregatorSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeVariableAggregator, - Name: n.Data.Meta.Title, - } - - ns.SetConfigKV("MergeStrategy", variableaggregator.FirstNotNullValue) - inputs := n.Data.Inputs - - groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups)) - for i := range inputs.VariableAggregator.MergeGroups { - group := inputs.VariableAggregator.MergeGroups[i] - tInfo := &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: make(map[string]*vo.TypeInfo), - } - ns.SetInputType(group.Name, tInfo) - for ii, v := range group.Variables { - name := strconv.Itoa(ii) - valueTypeInfo, err := CanvasBlockInputToTypeInfo(v) - if err != nil { - return nil, err - } - tInfo.Properties[name] = valueTypeInfo - sources, err := CanvasBlockInputToFieldInfo(v, einoCompose.FieldPath{group.Name, name}, n.Parent()) - if err != nil { - return nil, err - } - ns.AddInputSource(sources...) - } - - length := len(group.Variables) - groupToLen[group.Name] = length - } - - groupOrder := make([]string, 0, len(groupToLen)) - for i := range inputs.VariableAggregator.MergeGroups { - group := inputs.VariableAggregator.MergeGroups[i] - groupOrder = append(groupOrder, group.Name) - } - - ns.SetConfigKV("GroupToLen", groupToLen) - ns.SetConfigKV("GroupOrder", groupOrder) - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - return ns, nil -} - -func toInputReceiverSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeInputReceiver, - Name: n.Data.Meta.Title, - } - - ns.SetConfigKV("OutputSchema", n.Data.Inputs.OutputSchema) - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toQASchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeQuestionAnswer, - Name: n.Data.Meta.Title, - } - - qaConf := n.Data.Inputs.QA - if qaConf == nil { - return nil, fmt.Errorf("qa config is nil") - } - ns.SetConfigKV("QuestionTpl", qaConf.Question) - - var llmParams *model.LLMParams - if n.Data.Inputs.LLMParam != nil { - llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam) - if err != nil { - return nil, err - } - var qaLLMParams vo.QALLMParam - err = sonic.Unmarshal(llmParamBytes, &qaLLMParams) - if err != nil { - return nil, err - } - - llmParams, err = qaLLMParamsToLLMParams(qaLLMParams) - if err != nil { - return nil, err - } - - ns.SetConfigKV("LLMParams", llmParams) - } - - answerType, err := qaAnswerTypeToAnswerType(qaConf.AnswerType) - if err != nil { - return nil, err - } - ns.SetConfigKV("AnswerType", answerType) - - var choiceType qa.ChoiceType - if len(qaConf.OptionType) > 0 { - choiceType, err = qaOptionTypeToChoiceType(qaConf.OptionType) - if err != nil { - return nil, err - } - ns.SetConfigKV("ChoiceType", choiceType) - } - - if answerType == qa.AnswerByChoices { - switch choiceType { - case qa.FixedChoices: - var options []string - for _, option := range qaConf.Options { - options = append(options, option.Name) - } - ns.SetConfigKV("FixedChoices", options) - case qa.DynamicChoices: - inputSources, err := CanvasBlockInputToFieldInfo(qaConf.DynamicOption, einoCompose.FieldPath{qa.DynamicChoicesKey}, n.Parent()) - if err != nil { - return nil, err - } - ns.AddInputSource(inputSources...) - - inputTypes, err := CanvasBlockInputToTypeInfo(qaConf.DynamicOption) - if err != nil { - return nil, err - } - ns.SetInputType(qa.DynamicChoicesKey, inputTypes) - default: - return nil, fmt.Errorf("qa node is answer by options, but option type not provided") - } - } else if answerType == qa.AnswerDirectly { - ns.SetConfigKV("ExtractFromAnswer", qaConf.ExtractOutput) - if qaConf.ExtractOutput { - if llmParams == nil { - return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided") - } - ns.SetConfigKV("AdditionalSystemPromptTpl", llmParams.SystemPrompt) - ns.SetConfigKV("MaxAnswerCount", qaConf.Limit) - if err = SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - } - } - - if err = SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toJSONSerializeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeJsonSerialization, - Name: n.Data.Meta.Title, - } - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func toJSONDeserializeSchema(n *vo.Node, _ ...OptionFn) (*compose.NodeSchema, error) { - ns := &compose.NodeSchema{ - Key: vo.NodeKey(n.ID), - Type: entity.NodeTypeJsonDeserialization, - Name: n.Data.Meta.Title, - } - - if err := SetInputsForNodeSchema(n, ns); err != nil { - return nil, err - } - - if err := SetOutputTypesForNodeSchema(n, ns); err != nil { - return nil, err - } - - return ns, nil -} - -func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) { - clauseGroup := &database.ClauseGroup{} - if len(condition.ConditionList) == 1 { - params := condition.ConditionList[0] - clause, err := buildClauseFromParams(params) - if err != nil { - return nil, err - } - clauseGroup.Single = clause - } else { - relation, err := ConvertLogicTypeToRelation(condition.Logic) - if err != nil { - return nil, err - } - clauseGroup.Multi = &database.MultiClause{ - Clauses: make([]*database.Clause, 0, len(condition.ConditionList)), - Relation: relation, - } - for i := range condition.ConditionList { - params := condition.ConditionList[i] - clause, err := buildClauseFromParams(params) - if err != nil { - return nil, err - } - clauseGroup.Multi.Clauses = append(clauseGroup.Multi.Clauses, clause) - } - } - - return clauseGroup, nil -} - func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node) ([]*vo.Node, []*vo.Edge) { nodeDependencyCount := map[string]int{} if parentNode != nil { @@ -1985,7 +395,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node) node.Blocks, node.Edges = PruneIsolatedNodes(node.Blocks, node.Edges, node) } nodeDependencyCount[node.ID] = 0 - if node.Type == vo.BlockTypeBotContinue || node.Type == vo.BlockTypeBotBreak { + if node.Type == entity.NodeTypeContinue.IDStr() || node.Type == entity.NodeTypeBreak.IDStr() { if parentNode != nil { nodeDependencyCount[parentNode.ID]++ } @@ -2026,39 +436,6 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node) return connectedNodes, connectedEdges } -func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) { - var left, operation *vo.Param - for _, p := range params { - if p == nil { - continue - } - if p.Name == "left" { - left = p - continue - } - if p.Name == "operation" { - operation = p - continue - } - } - if left == nil { - return nil, fmt.Errorf("left clause is required") - } - if operation == nil { - return nil, fmt.Errorf("operation clause is required") - } - operator, err := OperationToOperator(operation.Input.Value.Content.(string)) - if err != nil { - return nil, err - } - clause := &database.Clause{ - Left: left.Input.Value.Content.(string), - Operator: operator, - } - - return clause, nil -} - func parseBatchMode(n *vo.Node) ( batchN *vo.Node, // the new batch node enabled bool, // whether the node has enabled batch mode @@ -2135,9 +512,9 @@ func parseBatchMode(n *vo.Node) ( parentN := &vo.Node{ ID: n.ID, - Type: vo.BlockTypeBotBatch, + Type: entity.NodeTypeBatch.IDStr(), Data: &vo.Data{ - Meta: &vo.NodeMeta{ + Meta: &vo.NodeMetaFE{ Title: n.Data.Meta.Title, }, Inputs: &vo.Inputs{ @@ -2169,13 +546,13 @@ func parseBatchMode(n *vo.Node) ( ID: n.ID + "_inner", Type: n.Type, Data: &vo.Data{ - Meta: &vo.NodeMeta{ + Meta: &vo.NodeMetaFE{ Title: n.Data.Meta.Title + "_inner", }, Inputs: &vo.Inputs{ InputParameters: innerInput, LLMParam: n.Data.Inputs.LLMParam, // for llm node - FCParam: n.Data.Inputs.FCParam, // for llm node + LLM: n.Data.Inputs.LLM, // for llm node SettingOnError: n.Data.Inputs.SettingOnError, // for llm, sub-workflow and plugin nodes SubWorkflow: n.Data.Inputs.SubWorkflow, // for sub-workflow node PluginAPIParam: n.Data.Inputs.PluginAPIParam, // for plugin node @@ -2205,135 +582,106 @@ func parseBatchMode(n *vo.Node) ( return parentN, true, nil } -var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)\}\}`) - -func extractBracesContent(s string) []string { - matches := extractBracesRegexp.FindAllStringSubmatch(s, -1) - var result []string - for _, match := range matches { - if len(match) >= 2 { - result = append(result, match[1]) - } - } - return result -} - -func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) { - if len(properties) == 0 { - if tInfo, ok := tInfoMap[root]; ok { - return tInfo, true - } - return nil, false - } - tInfo, ok := tInfoMap[root] - if !ok { - return nil, false - } - return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties) - -} - -func extractImplicitDependency(node *vo.Node, nodes []*vo.Node) ([]*vo.ImplicitNodeDependency, error) { - - if len(node.Blocks) > 0 { - nodes = append(nodes, node.Blocks...) - dependencies := make([]*vo.ImplicitNodeDependency, 0, len(nodes)) - for _, subNode := range node.Blocks { - ds, err := extractImplicitDependency(subNode, nodes) - if err != nil { - return nil, err - } - dependencies = append(dependencies, ds...) - } - return dependencies, nil - - } - - if node.Type != vo.BlockTypeBotHttp { - return nil, nil - } - - dependencies := make([]*vo.ImplicitNodeDependency, 0, len(nodes)) - url := node.Data.Inputs.APIInfo.URL - urlVars := extractBracesContent(url) - hasReferred := make(map[string]bool) - extractDependenciesFromVars := func(vars []string) error { - for _, v := range vars { - if strings.HasPrefix(v, "block_output_") { - paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".") - if len(paths) < 2 { - return fmt.Errorf("invalid block_output_ variable: %s", v) - } - if hasReferred[v] { - continue - } - hasReferred[v] = true - dependencies = append(dependencies, &vo.ImplicitNodeDependency{ - NodeID: paths[0], - FieldPath: paths[1:], - }) - } - } - return nil - } - - err := extractDependenciesFromVars(urlVars) - if err != nil { - return nil, err - } - if node.Data.Inputs.Body.BodyType == string(httprequester.BodyTypeJSON) { - jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json) - err = extractDependenciesFromVars(jsonVars) - if err != nil { - return nil, err - } - } - if node.Data.Inputs.Body.BodyType == string(httprequester.BodyTypeRawText) { - rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json) - err = extractDependenciesFromVars(rawTextVars) - if err != nil { - return nil, err - } - } - - var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node - nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node { - for i := range nodes { - if nodes[i].ID == nodeID { - return nodes[i] - } - if len(nodes[i].Blocks) > 0 { - if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil { - return n - } - } - } - return nil - } - for _, ds := range dependencies { - fNode := nodeFinder(nodes, ds.NodeID) - if fNode == nil { - continue - } - tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs)) - for _, vAny := range fNode.Data.Outputs { - v, err := vo.ParseVariable(vAny) - if err != nil { - return nil, err - } - tInfo, err := CanvasVariableToTypeInfo(v) - if err != nil { - return nil, err - } - tInfoMap[v.Name] = tInfo - } - tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap) - if !ok { - return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath) - } - ds.TypeInfo = tInfo - } - - return dependencies, nil +// RegisterAllNodeAdaptors register all NodeType's NodeAdaptor. +func RegisterAllNodeAdaptors() { + // register a generator function so that each time a NodeAdaptor is needed, + // we can provide a brand new Config instance. + nodes.RegisterNodeAdaptor(entity.NodeTypeEntry, func() nodes.NodeAdaptor { + return &entry.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeSelector, func() nodes.NodeAdaptor { + return &selector.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeBatch, func() nodes.NodeAdaptor { + return &batch.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeBreak, func() nodes.NodeAdaptor { + return &_break.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeContinue, func() nodes.NodeAdaptor { + return &_continue.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeInputReceiver, func() nodes.NodeAdaptor { + return &receiver.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeJsonSerialization, func() nodes.NodeAdaptor { + return &json.SerializationConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeJsonDeserialization, func() nodes.NodeAdaptor { + return &json.DeserializationConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeVariableAssigner, func() nodes.NodeAdaptor { + return &variableassigner.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeVariableAssignerWithinLoop, func() nodes.NodeAdaptor { + return &variableassigner.InLoopConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypePlugin, func() nodes.NodeAdaptor { + return &plugin.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeCodeRunner, func() nodes.NodeAdaptor { + return &code.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeOutputEmitter, func() nodes.NodeAdaptor { + return &emitter.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeExit, func() nodes.NodeAdaptor { + return &exit.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeVariableAggregator, func() nodes.NodeAdaptor { + return &variableaggregator.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeTextProcessor, func() nodes.NodeAdaptor { + return &textprocessor.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeIntentDetector, func() nodes.NodeAdaptor { + return &intentdetector.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeQuestionAnswer, func() nodes.NodeAdaptor { + return &qa.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeHTTPRequester, func() nodes.NodeAdaptor { + return &httprequester.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeLoop, func() nodes.NodeAdaptor { + return &loop.Config{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeKnowledgeIndexer, func() nodes.NodeAdaptor { + return &knowledge.IndexerConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeKnowledgeRetriever, func() nodes.NodeAdaptor { + return &knowledge.RetrieveConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeKnowledgeDeleter, func() nodes.NodeAdaptor { + return &knowledge.DeleterConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeDatabaseInsert, func() nodes.NodeAdaptor { + return &database.InsertConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeDatabaseUpdate, func() nodes.NodeAdaptor { + return &database.UpdateConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeDatabaseQuery, func() nodes.NodeAdaptor { + return &database.QueryConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeDatabaseDelete, func() nodes.NodeAdaptor { + return &database.DeleteConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeDatabaseCustomSQL, func() nodes.NodeAdaptor { + return &database.CustomSQLConfig{} + }) + nodes.RegisterNodeAdaptor(entity.NodeTypeLLM, func() nodes.NodeAdaptor { + return &llm.Config{} + }) + // register branch adaptors + nodes.RegisterBranchAdaptor(entity.NodeTypeSelector, func() nodes.BranchAdaptor { + return &selector.Config{} + }) + nodes.RegisterBranchAdaptor(entity.NodeTypeIntentDetector, func() nodes.BranchAdaptor { + return &intentdetector.Config{} + }) + nodes.RegisterBranchAdaptor(entity.NodeTypeQuestionAnswer, func() nodes.BranchAdaptor { + return &qa.Config{} + }) } diff --git a/backend/domain/workflow/internal/canvas/adaptor/type_convert.go b/backend/domain/workflow/internal/canvas/adaptor/type_convert.go deleted file mode 100644 index c2f539fd..00000000 --- a/backend/domain/workflow/internal/canvas/adaptor/type_convert.go +++ /dev/null @@ -1,1230 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package adaptor - -import ( - "fmt" - "regexp" - "strconv" - "strings" - - einoCompose "github.com/cloudwego/eino/compose" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" - "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" - "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" - "github.com/coze-dev/coze-studio/backend/pkg/sonic" - "github.com/coze-dev/coze-studio/backend/types/errno" -) - -func CanvasVariableToTypeInfo(v *vo.Variable) (*vo.TypeInfo, error) { - tInfo := &vo.TypeInfo{ - Required: v.Required, - Desc: v.Description, - } - - switch v.Type { - case vo.VariableTypeString: - switch v.AssistType { - case vo.AssistTypeTime: - tInfo.Type = vo.DataTypeTime - case vo.AssistTypeNotSet: - tInfo.Type = vo.DataTypeString - default: - fileType, ok := assistTypeToFileType(v.AssistType) - if ok { - tInfo.Type = vo.DataTypeFile - tInfo.FileType = &fileType - } else { - return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType) - } - } - case vo.VariableTypeInteger: - tInfo.Type = vo.DataTypeInteger - case vo.VariableTypeFloat: - tInfo.Type = vo.DataTypeNumber - case vo.VariableTypeBoolean: - tInfo.Type = vo.DataTypeBoolean - case vo.VariableTypeObject: - tInfo.Type = vo.DataTypeObject - tInfo.Properties = make(map[string]*vo.TypeInfo) - if v.Schema != nil { - for _, subVAny := range v.Schema.([]any) { - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := CanvasVariableToTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.Properties[subV.Name] = subTInfo - } - } - case vo.VariableTypeList: - tInfo.Type = vo.DataTypeArray - subVAny := v.Schema - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := CanvasVariableToTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.ElemTypeInfo = subTInfo - - default: - return nil, fmt.Errorf("unsupported variable type: %s", v.Type) - } - - return tInfo, nil -} - -func CanvasBlockInputToTypeInfo(b *vo.BlockInput) (tInfo *vo.TypeInfo, err error) { - defer func() { - if err != nil { - err = vo.WrapIfNeeded(errno.ErrSchemaConversionFail, err) - } - }() - - tInfo = &vo.TypeInfo{} - - if b == nil { - return tInfo, nil - } - - switch b.Type { - case vo.VariableTypeString: - switch b.AssistType { - case vo.AssistTypeTime: - tInfo.Type = vo.DataTypeTime - case vo.AssistTypeNotSet: - tInfo.Type = vo.DataTypeString - default: - fileType, ok := assistTypeToFileType(b.AssistType) - if ok { - tInfo.Type = vo.DataTypeFile - tInfo.FileType = &fileType - } else { - return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType) - } - } - case vo.VariableTypeInteger: - tInfo.Type = vo.DataTypeInteger - case vo.VariableTypeFloat: - tInfo.Type = vo.DataTypeNumber - case vo.VariableTypeBoolean: - tInfo.Type = vo.DataTypeBoolean - case vo.VariableTypeObject: - tInfo.Type = vo.DataTypeObject - tInfo.Properties = make(map[string]*vo.TypeInfo) - if b.Schema != nil { - for _, subVAny := range b.Schema.([]any) { - if b.Value.Type == vo.BlockInputValueTypeRef { - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := CanvasVariableToTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.Properties[subV.Name] = subTInfo - } else if b.Value.Type == vo.BlockInputValueTypeObjectRef { - subV, err := parseParam(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := CanvasBlockInputToTypeInfo(subV.Input) - if err != nil { - return nil, err - } - tInfo.Properties[subV.Name] = subTInfo - } - } - } - case vo.VariableTypeList: - tInfo.Type = vo.DataTypeArray - subVAny := b.Schema - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := CanvasVariableToTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.ElemTypeInfo = subTInfo - default: - return nil, fmt.Errorf("unsupported variable type: %s", b.Type) - } - - return tInfo, nil -} - -func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) (sources []*vo.FieldInfo, err error) { - value := b.Value - if value == nil { - return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type) - } - - switch value.Type { - case vo.BlockInputValueTypeObjectRef: - sc := b.Schema - if sc == nil { - return nil, fmt.Errorf("input %v has no schema, type= %s", path, b.Type) - } - - paramList, ok := sc.([]any) - if !ok { - return nil, fmt.Errorf("input %v schema not []any, type= %T", path, sc) - } - - for i := range paramList { - paramAny := paramList[i] - param, err := parseParam(paramAny) - if err != nil { - return nil, err - } - - copied := make([]string, len(path)) - copy(copied, path) - subFieldInfo, err := CanvasBlockInputToFieldInfo(param.Input, append(copied, param.Name), parentNode) - if err != nil { - return nil, err - } - sources = append(sources, subFieldInfo...) - } - return sources, nil - case vo.BlockInputValueTypeLiteral: - content := value.Content - if content == nil { - return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type) - } - - switch b.Type { - case vo.VariableTypeObject: - m := make(map[string]any) - if err = sonic.UnmarshalString(content.(string), &m); err != nil { - return nil, err - } - content = m - case vo.VariableTypeList: - l := make([]any, 0) - if err = sonic.UnmarshalString(content.(string), &l); err != nil { - return nil, err - } - content = l - case vo.VariableTypeInteger: - switch content.(type) { - case string: - content, err = strconv.ParseInt(content.(string), 10, 64) - if err != nil { - return nil, err - } - case int64: - content = content.(int64) - case float64: - content = int64(content.(float64)) - default: - return nil, fmt.Errorf("unsupported variable type fot integer: %s", b.Type) - } - case vo.VariableTypeFloat: - switch content.(type) { - case string: - content, err = strconv.ParseFloat(content.(string), 64) - if err != nil { - return nil, err - } - case int64: - content = float64(content.(int64)) - case float64: - content = content.(float64) - default: - return nil, fmt.Errorf("unsupported variable type for float: %s", b.Type) - } - case vo.VariableTypeBoolean: - switch content.(type) { - case string: - content, err = strconv.ParseBool(content.(string)) - if err != nil { - return nil, err - } - case bool: - content = content.(bool) - default: - return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type) - } - default: - } - return []*vo.FieldInfo{ - { - Path: path, - Source: vo.FieldSource{ - Val: content, - }, - }, - }, nil - case vo.BlockInputValueTypeRef: - content := value.Content - if content == nil { - return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type) - } - - ref, err := parseBlockInputRef(content) - if err != nil { - return nil, err - } - - fieldSource, err := CanvasBlockInputRefToFieldSource(ref) - if err != nil { - return nil, err - } - - if parentNode != nil { - if fieldSource.Ref != nil && len(fieldSource.Ref.FromNodeKey) > 0 && fieldSource.Ref.FromNodeKey == vo.NodeKey(parentNode.ID) { - varRoot := fieldSource.Ref.FromPath[0] - for _, p := range parentNode.Data.Inputs.VariableParameters { - if p.Name == varRoot { - fieldSource.Ref.FromNodeKey = "" - pi := vo.ParentIntermediate - fieldSource.Ref.VariableType = &pi - } - } - } - } - - return []*vo.FieldInfo{ - { - Path: path, - Source: *fieldSource, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported value type: %s for blockInput type= %s", value.Type, b.Type) - } -} - -func parseBlockInputRef(content any) (*vo.BlockInputReference, error) { - if bi, ok := content.(*vo.BlockInputReference); ok { - return bi, nil - } - - m, ok := content.(map[string]any) - if !ok { - return nil, fmt.Errorf("invalid content type: %T when parse BlockInputRef", content) - } - - marshaled, err := sonic.Marshal(m) - if err != nil { - return nil, err - } - - p := &vo.BlockInputReference{} - if err := sonic.Unmarshal(marshaled, p); err != nil { - return nil, err - } - - return p, nil -} - -func parseParam(v any) (*vo.Param, error) { - if pa, ok := v.(*vo.Param); ok { - return pa, nil - } - - m, ok := v.(map[string]any) - if !ok { - return nil, fmt.Errorf("invalid content type: %T when parse Param", v) - } - - marshaled, err := sonic.Marshal(m) - if err != nil { - return nil, err - } - - p := &vo.Param{} - if err := sonic.Unmarshal(marshaled, p); err != nil { - return nil, err - } - - return p, nil -} - -func CanvasBlockInputRefToFieldSource(r *vo.BlockInputReference) (*vo.FieldSource, error) { - switch r.Source { - case vo.RefSourceTypeBlockOutput: - if len(r.BlockID) == 0 { - return nil, fmt.Errorf("invalid BlockInputReference = %+v, BlockID is empty when source is block output", r) - } - - parts := strings.Split(r.Name, ".") // an empty r.Name signals an all-to-all mapping - return &vo.FieldSource{ - Ref: &vo.Reference{ - FromNodeKey: vo.NodeKey(r.BlockID), - FromPath: parts, - }, - }, nil - case vo.RefSourceTypeGlobalApp, vo.RefSourceTypeGlobalSystem, vo.RefSourceTypeGlobalUser: - if len(r.Path) == 0 { - return nil, fmt.Errorf("invalid BlockInputReference = %+v, Path is empty when source is variables", r) - } - - var varType vo.GlobalVarType - switch r.Source { - case vo.RefSourceTypeGlobalApp: - varType = vo.GlobalAPP - case vo.RefSourceTypeGlobalSystem: - varType = vo.GlobalSystem - case vo.RefSourceTypeGlobalUser: - varType = vo.GlobalUser - default: - return nil, fmt.Errorf("invalid BlockInputReference = %+v, Source is invalid", r) - } - - return &vo.FieldSource{ - Ref: &vo.Reference{ - VariableType: &varType, - FromPath: r.Path, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported ref source type: %s", r.Source) - } -} - -func assistTypeToFileType(a vo.AssistType) (vo.FileSubType, bool) { - switch a { - case vo.AssistTypeNotSet: - return "", false - case vo.AssistTypeTime: - return "", false - case vo.AssistTypeImage: - return vo.FileTypeImage, true - case vo.AssistTypeAudio: - return vo.FileTypeAudio, true - case vo.AssistTypeVideo: - return vo.FileTypeVideo, true - case vo.AssistTypeDefault: - return vo.FileTypeDefault, true - case vo.AssistTypeDoc: - return vo.FileTypeDocument, true - case vo.AssistTypeExcel: - return vo.FileTypeExcel, true - case vo.AssistTypeCode: - return vo.FileTypeCode, true - case vo.AssistTypePPT: - return vo.FileTypePPT, true - case vo.AssistTypeTXT: - return vo.FileTypeTxt, true - case vo.AssistTypeSvg: - return vo.FileTypeSVG, true - case vo.AssistTypeVoice: - return vo.FileTypeVoice, true - case vo.AssistTypeZip: - return vo.FileTypeZip, true - default: - panic("impossible") - } -} - -func LLMParamsToLLMParam(params vo.LLMParam) (*model.LLMParams, error) { - p := &model.LLMParams{} - for _, param := range params { - switch param.Name { - case "temperature": - strVal := param.Input.Value.Content.(string) - floatVal, err := strconv.ParseFloat(strVal, 64) - if err != nil { - return nil, err - } - p.Temperature = &floatVal - case "maxTokens": - strVal := param.Input.Value.Content.(string) - intVal, err := strconv.Atoi(strVal) - if err != nil { - return nil, err - } - p.MaxTokens = intVal - case "responseFormat": - strVal := param.Input.Value.Content.(string) - int64Val, err := strconv.ParseInt(strVal, 10, 64) - if err != nil { - return nil, err - } - p.ResponseFormat = model.ResponseFormat(int64Val) - case "modleName": - strVal := param.Input.Value.Content.(string) - p.ModelName = strVal - case "modelType": - strVal := param.Input.Value.Content.(string) - int64Val, err := strconv.ParseInt(strVal, 10, 64) - if err != nil { - return nil, err - } - p.ModelType = int64Val - case "prompt": - strVal := param.Input.Value.Content.(string) - p.Prompt = strVal - case "enableChatHistory": - boolVar := param.Input.Value.Content.(bool) - p.EnableChatHistory = boolVar - case "systemPrompt": - strVal := param.Input.Value.Content.(string) - p.SystemPrompt = strVal - case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty": - // do nothing - case "topP": - strVal := param.Input.Value.Content.(string) - floatVar, err := strconv.ParseFloat(strVal, 64) - if err != nil { - return nil, err - } - p.TopP = &floatVar - default: - return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name) - } - } - - return p, nil -} - -func qaLLMParamsToLLMParams(params vo.QALLMParam) (*model.LLMParams, error) { - p := &model.LLMParams{} - p.ModelName = params.ModelName - p.ModelType = params.ModelType - p.Temperature = ¶ms.Temperature - p.MaxTokens = params.MaxTokens - p.TopP = ¶ms.TopP - p.ResponseFormat = params.ResponseFormat - p.SystemPrompt = params.SystemPrompt - return p, nil -} - -func qaAnswerTypeToAnswerType(t vo.QAAnswerType) (qa.AnswerType, error) { - switch t { - case vo.QAAnswerTypeOption: - return qa.AnswerByChoices, nil - case vo.QAAnswerTypeText: - return qa.AnswerDirectly, nil - default: - return "", fmt.Errorf("invalid QAAnswerType: %s", t) - } -} - -func qaOptionTypeToChoiceType(t vo.QAOptionType) (qa.ChoiceType, error) { - switch t { - case vo.QAOptionTypeStatic: - return qa.FixedChoices, nil - case vo.QAOptionTypeDynamic: - return qa.DynamicChoices, nil - default: - return "", fmt.Errorf("invalid QAOptionType: %s", t) - } -} - -func SetInputsForNodeSchema(n *vo.Node, ns *compose.NodeSchema) error { - inputParams := n.Data.Inputs.InputParameters - if len(inputParams) == 0 { - return nil - } - - for _, param := range inputParams { - name := param.Name - tInfo, err := CanvasBlockInputToTypeInfo(param.Input) - if err != nil { - return err - } - - ns.SetInputType(name, tInfo) - - sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent()) - if err != nil { - return err - } - - ns.AddInputSource(sources...) - } - - return nil -} - -func SetDatabaseInputsForNodeSchema(n *vo.Node, ns *compose.NodeSchema) (err error) { - selectParam := n.Data.Inputs.SelectParam - if selectParam != nil { - err = applyDBConditionToSchema(ns, selectParam.Condition, n.Parent()) - if err != nil { - return err - } - } - - insertParam := n.Data.Inputs.InsertParam - if insertParam != nil { - err = applyInsetFieldInfoToSchema(ns, insertParam.FieldInfo, n.Parent()) - if err != nil { - return err - } - } - - deleteParam := n.Data.Inputs.DeleteParam - if deleteParam != nil { - err = applyDBConditionToSchema(ns, &deleteParam.Condition, n.Parent()) - if err != nil { - return err - } - } - - updateParam := n.Data.Inputs.UpdateParam - if updateParam != nil { - err = applyDBConditionToSchema(ns, &updateParam.Condition, n.Parent()) - if err != nil { - return err - } - err = applyInsetFieldInfoToSchema(ns, updateParam.FieldInfo, n.Parent()) - if err != nil { - return err - } - } - return nil -} - -var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"\]`) - -func SetHttpRequesterInputsForNodeSchema(n *vo.Node, ns *compose.NodeSchema, implicitNodeDependencies []*vo.ImplicitNodeDependency) (err error) { - inputs := n.Data.Inputs - implicitPathVars := make(map[string]bool) - addImplicitVarsSources := func(prefix string, vars []string) error { - for _, v := range vars { - if strings.HasPrefix(v, "block_output_") { - paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".") - if len(paths) < 2 { - return fmt.Errorf("invalid implicit var : %s", v) - } - for _, dep := range implicitNodeDependencies { - if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") { - pathValue := prefix + crypto.MD5HexValue(v) - if _, visited := implicitPathVars[pathValue]; visited { - continue - } - implicitPathVars[pathValue] = true - ns.SetInputType(pathValue, dep.TypeInfo) - ns.AddInputSource(&vo.FieldInfo{ - Path: []string{pathValue}, - Source: vo.FieldSource{ - Ref: &vo.Reference{ - FromNodeKey: vo.NodeKey(dep.NodeID), - FromPath: dep.FieldPath, - }, - }, - }) - } - } - } - if strings.HasPrefix(v, "global_variable_") { - matches := globalVariableRegex.FindStringSubmatch(v) - if len(matches) < 2 { - continue - } - - var varType vo.GlobalVarType - if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) { - varType = vo.GlobalAPP - } else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) { - varType = vo.GlobalUser - } else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) { - varType = vo.GlobalSystem - } else { - return fmt.Errorf("invalid global variable type: %s", v) - } - - source := vo.FieldSource{ - Ref: &vo.Reference{ - VariableType: &varType, - FromPath: []string{matches[1]}, - }, - } - - ns.AddInputSource(&vo.FieldInfo{ - Path: []string{prefix + crypto.MD5HexValue(v)}, - Source: source, - }) - - } - } - return nil - - } - - urlVars := extractBracesContent(inputs.APIInfo.URL) - err = addImplicitVarsSources("__apiInfo_url_", urlVars) - if err != nil { - return err - } - - err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent()) - if err != nil { - return err - } - - err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent()) - if err != nil { - return err - } - - if inputs.Auth != nil && inputs.Auth.AuthOpen { - authData := inputs.Auth.AuthData - const bearerTokenKey = "__auth_authData_bearerTokenData_token" - if inputs.Auth.AuthType == "BEARER_AUTH" { - bearTokenParam := authData.BearerTokenData[0] - tInfo, err := CanvasBlockInputToTypeInfo(bearTokenParam.Input) - if err != nil { - return err - } - ns.SetInputType(bearerTokenKey, tInfo) - sources, err := CanvasBlockInputToFieldInfo(bearTokenParam.Input, einoCompose.FieldPath{bearerTokenKey}, n.Parent()) - if err != nil { - return err - } - ns.AddInputSource(sources...) - - } - - if inputs.Auth.AuthType == "CUSTOM_AUTH" { - const ( - customDataDataKey = "__auth_authData_customData_data_Key" - customDataDataValue = "__auth_authData_customData_data_Value" - ) - dataParams := authData.CustomData.Data - keyParam := dataParams[0] - keyTypeInfo, err := CanvasBlockInputToTypeInfo(keyParam.Input) - if err != nil { - return err - } - ns.SetInputType(customDataDataKey, keyTypeInfo) - sources, err := CanvasBlockInputToFieldInfo(keyParam.Input, einoCompose.FieldPath{customDataDataKey}, n.Parent()) - if err != nil { - return err - } - ns.AddInputSource(sources...) - - valueParam := dataParams[1] - valueTypeInfo, err := CanvasBlockInputToTypeInfo(valueParam.Input) - if err != nil { - return err - } - ns.SetInputType(customDataDataValue, valueTypeInfo) - sources, err = CanvasBlockInputToFieldInfo(valueParam.Input, einoCompose.FieldPath{customDataDataValue}, n.Parent()) - if err != nil { - return err - } - ns.AddInputSource(sources...) - - } - - } - - switch httprequester.BodyType(inputs.Body.BodyType) { - case httprequester.BodyTypeFormData: - err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent()) - if err != nil { - return err - } - case httprequester.BodyTypeFormURLEncoded: - err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent()) - if err != nil { - return err - } - case httprequester.BodyTypeBinary: - const fileURLName = "__body_bodyData_binary_fileURL" - fileURLInput := inputs.Body.BodyData.Binary.FileURL - ns.SetInputType(fileURLName, &vo.TypeInfo{ - Type: vo.DataTypeString, - }) - sources, err := CanvasBlockInputToFieldInfo(fileURLInput, einoCompose.FieldPath{fileURLName}, n.Parent()) - if err != nil { - return err - } - ns.AddInputSource(sources...) - case httprequester.BodyTypeJSON: - jsonVars := extractBracesContent(inputs.Body.BodyData.Json) - err = addImplicitVarsSources("__body_bodyData_json_", jsonVars) - if err != nil { - return err - } - case httprequester.BodyTypeRawText: - rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText) - err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars) - if err != nil { - return err - } - - } - - return nil -} - -func applyDBConditionToSchema(ns *compose.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error { - if condition.ConditionList == nil { - return nil - } - - for idx, params := range condition.ConditionList { - var right *vo.Param - for _, param := range params { - if param == nil { - continue - } - if param.Name == "right" { - right = param - break - } - } - - if right == nil { - continue - } - name := fmt.Sprintf("__condition_right_%d", idx) - tInfo, err := CanvasBlockInputToTypeInfo(right.Input) - if err != nil { - return err - } - ns.SetInputType(name, tInfo) - sources, err := CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode) - if err != nil { - return err - } - ns.AddInputSource(sources...) - - } - - return nil - -} - -func applyInsetFieldInfoToSchema(ns *compose.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error { - if len(fieldInfo) == 0 { - return nil - } - for _, params := range fieldInfo { - // Each FieldInfo is list params, containing two elements. - // The first is to set the name of the field and the second is the corresponding value. - p0 := params[0] - p1 := params[1] - name := p0.Input.Value.Content.(string) // must string type - tInfo, err := CanvasBlockInputToTypeInfo(p1.Input) - if err != nil { - return err - } - name = "__setting_field_" + name - ns.SetInputType(name, tInfo) - sources, err := CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode) - if err != nil { - return err - } - ns.AddInputSource(sources...) - } - return nil - -} - -func applyParamsToSchema(ns *compose.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error { - for i := range params { - param := params[i] - name := param.Name - tInfo, err := CanvasBlockInputToTypeInfo(param.Input) - if err != nil { - return err - } - - fieldName := prefix + crypto.MD5HexValue(name) - ns.SetInputType(fieldName, tInfo) - sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{fieldName}, parentNode) - if err != nil { - return err - } - ns.AddInputSource(sources...) - } - return nil -} - -func SetOutputTypesForNodeSchema(n *vo.Node, ns *compose.NodeSchema) error { - for _, vAny := range n.Data.Outputs { - v, err := vo.ParseVariable(vAny) - if err != nil { - return err - } - - tInfo, err := CanvasVariableToTypeInfo(v) - if err != nil { - return err - } - if v.ReadOnly { - if v.Name == "errorBody" { // reserved output fields when exception happens - continue - } - } - ns.SetOutputType(v.Name, tInfo) - } - - return nil -} - -func SetOutputsForNodeSchema(n *vo.Node, ns *compose.NodeSchema) error { - for _, vAny := range n.Data.Outputs { - param, err := parseParam(vAny) - if err != nil { - return err - } - name := param.Name - tInfo, err := CanvasBlockInputToTypeInfo(param.Input) - if err != nil { - return err - } - - ns.SetOutputType(name, tInfo) - - sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent()) - if err != nil { - return err - } - - ns.AddOutputSource(sources...) - } - - return nil -} - -func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (selector.Operator, error) { - switch o { - case vo.Equal: - return selector.OperatorEqual, nil - case vo.NotEqual: - return selector.OperatorNotEqual, nil - case vo.LengthGreaterThan: - return selector.OperatorLengthGreater, nil - case vo.LengthGreaterThanEqual: - return selector.OperatorLengthGreaterOrEqual, nil - case vo.LengthLessThan: - return selector.OperatorLengthLesser, nil - case vo.LengthLessThanEqual: - return selector.OperatorLengthLesserOrEqual, nil - case vo.Contain: - if leftType.Type == vo.DataTypeObject { - return selector.OperatorContainKey, nil - } - return selector.OperatorContain, nil - case vo.NotContain: - if leftType.Type == vo.DataTypeObject { - return selector.OperatorNotContainKey, nil - } - return selector.OperatorNotContain, nil - case vo.Empty: - return selector.OperatorEmpty, nil - case vo.NotEmpty: - return selector.OperatorNotEmpty, nil - case vo.True: - return selector.OperatorIsTrue, nil - case vo.False: - return selector.OperatorIsFalse, nil - case vo.GreaterThan: - return selector.OperatorGreater, nil - case vo.GreaterThanEqual: - return selector.OperatorGreaterOrEqual, nil - case vo.LessThan: - return selector.OperatorLesser, nil - case vo.LessThanEqual: - return selector.OperatorLesserOrEqual, nil - default: - return "", fmt.Errorf("unsupported operator type: %d", o) - } -} - -func ToLoopType(l vo.LoopType) (loop.Type, error) { - switch l { - case vo.LoopTypeArray: - return loop.ByArray, nil - case vo.LoopTypeCount: - return loop.ByIteration, nil - case vo.LoopTypeInfinite: - return loop.Infinite, nil - default: - return "", fmt.Errorf("unsupported loop type: %s", l) - } -} - -func ConvertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) { - switch logicType { - case vo.DatabaseLogicAnd: - return database.ClauseRelationAND, nil - case vo.DatabaseLogicOr: - return database.ClauseRelationOR, nil - default: - return "", fmt.Errorf("logic type %v is invalid", logicType) - - } -} - -func OperationToOperator(s string) (database.Operator, error) { - switch s { - case "EQUAL": - return database.OperatorEqual, nil - case "NOT_EQUAL": - return database.OperatorNotEqual, nil - case "GREATER_THAN": - return database.OperatorGreater, nil - case "LESS_THAN": - return database.OperatorLesser, nil - case "GREATER_EQUAL": - return database.OperatorGreaterOrEqual, nil - case "LESS_EQUAL": - return database.OperatorLesserOrEqual, nil - case "IN": - return database.OperatorIn, nil - case "NOT_IN": - return database.OperatorNotIn, nil - case "IS_NULL": - return database.OperatorIsNull, nil - case "IS_NOT_NULL": - return database.OperatorIsNotNull, nil - case "LIKE": - return database.OperatorLike, nil - case "NOT_LIKE": - return database.OperatorNotLike, nil - } - return "", fmt.Errorf("not a valid Operation string") -} - -func ConvertAuthType(auth string) (httprequester.AuthType, error) { - switch auth { - case "CUSTOM_AUTH": - return httprequester.Custom, nil - case "BEARER_AUTH": - return httprequester.BearToken, nil - default: - return httprequester.AuthType(0), fmt.Errorf("invalid auth type") - } -} - -func ConvertLocation(l string) (httprequester.Location, error) { - switch l { - case "header": - return httprequester.Header, nil - case "query": - return httprequester.QueryParam, nil - default: - return 0, fmt.Errorf("invalid location") - - } - -} - -func ConvertParsingType(p string) (knowledge.ParseMode, error) { - switch p { - case "fast": - return knowledge.FastParseMode, nil - case "accurate": - return knowledge.AccurateParseMode, nil - default: - return "", fmt.Errorf("invalid parsingType: %s", p) - } -} - -func ConvertChunkType(p string) (knowledge.ChunkType, error) { - switch p { - case "custom": - return knowledge.ChunkTypeCustom, nil - case "default": - return knowledge.ChunkTypeDefault, nil - default: - return "", fmt.Errorf("invalid ChunkType: %s", p) - } -} -func ConvertRetrievalSearchType(s int64) (knowledge.SearchType, error) { - switch s { - case 0: - return knowledge.SearchTypeSemantic, nil - case 1: - return knowledge.SearchTypeHybrid, nil - case 20: - return knowledge.SearchTypeFullText, nil - default: - return "", fmt.Errorf("invalid RetrievalSearchType %v", s) - } -} - -func ConvertCodeLanguage(l int64) (coderunner.Language, error) { - switch l { - case 5: - return coderunner.JavaScript, nil - case 3: - return coderunner.Python, nil - default: - return "", fmt.Errorf("invalid language: %d", l) - - } -} - -func BlockInputToNamedTypeInfo(name string, b *vo.BlockInput) (*vo.NamedTypeInfo, error) { - tInfo := &vo.NamedTypeInfo{ - Name: name, - } - if b == nil { - return tInfo, nil - } - switch b.Type { - case vo.VariableTypeString: - switch b.AssistType { - case vo.AssistTypeTime: - tInfo.Type = vo.DataTypeTime - case vo.AssistTypeNotSet: - tInfo.Type = vo.DataTypeString - default: - fileType, ok := assistTypeToFileType(b.AssistType) - if ok { - tInfo.Type = vo.DataTypeFile - tInfo.FileType = &fileType - } else { - return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType) - } - } - case vo.VariableTypeInteger: - tInfo.Type = vo.DataTypeInteger - case vo.VariableTypeFloat: - tInfo.Type = vo.DataTypeNumber - case vo.VariableTypeBoolean: - tInfo.Type = vo.DataTypeBoolean - case vo.VariableTypeObject: - tInfo.Type = vo.DataTypeObject - if b.Schema != nil { - tInfo.Properties = make([]*vo.NamedTypeInfo, 0, len(b.Schema.([]any))) - for _, subVAny := range b.Schema.([]any) { - if b.Value.Type == vo.BlockInputValueTypeRef { - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subNInfo, err := VariableToNamedTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.Properties = append(tInfo.Properties, subNInfo) - } else if b.Value.Type == vo.BlockInputValueTypeObjectRef { - subV, err := parseParam(subVAny) - if err != nil { - return nil, err - } - subNInfo, err := BlockInputToNamedTypeInfo(subV.Name, subV.Input) - if err != nil { - return nil, err - } - tInfo.Properties = append(tInfo.Properties, subNInfo) - } - } - } - case vo.VariableTypeList: - tInfo.Type = vo.DataTypeArray - subVAny := b.Schema - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subNInfo, err := VariableToNamedTypeInfo(subV) - if err != nil { - return nil, err - } - tInfo.ElemTypeInfo = subNInfo - default: - return nil, fmt.Errorf("unsupported variable type: %s", b.Type) - } - - return tInfo, nil -} - -func VariableToNamedTypeInfo(v *vo.Variable) (*vo.NamedTypeInfo, error) { - nInfo := &vo.NamedTypeInfo{ - Required: v.Required, - Name: v.Name, - Desc: v.Description, - } - - switch v.Type { - case vo.VariableTypeString: - switch v.AssistType { - case vo.AssistTypeTime: - nInfo.Type = vo.DataTypeTime - case vo.AssistTypeNotSet: - nInfo.Type = vo.DataTypeString - default: - fileType, ok := assistTypeToFileType(v.AssistType) - if ok { - nInfo.Type = vo.DataTypeFile - nInfo.FileType = &fileType - } else { - return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType) - } - } - case vo.VariableTypeInteger: - nInfo.Type = vo.DataTypeInteger - case vo.VariableTypeFloat: - nInfo.Type = vo.DataTypeNumber - case vo.VariableTypeBoolean: - nInfo.Type = vo.DataTypeBoolean - case vo.VariableTypeObject: - nInfo.Type = vo.DataTypeObject - if v.Schema != nil { - nInfo.Properties = make([]*vo.NamedTypeInfo, 0) - for _, subVAny := range v.Schema.([]any) { - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := VariableToNamedTypeInfo(subV) - if err != nil { - return nil, err - } - nInfo.Properties = append(nInfo.Properties, subTInfo) - - } - } - case vo.VariableTypeList: - nInfo.Type = vo.DataTypeArray - subVAny := v.Schema - subV, err := vo.ParseVariable(subVAny) - if err != nil { - return nil, err - } - subTInfo, err := VariableToNamedTypeInfo(subV) - if err != nil { - return nil, err - } - nInfo.ElemTypeInfo = subTInfo - - default: - return nil, fmt.Errorf("unsupported variable type: %s", v.Type) - } - - return nInfo, nil -} diff --git a/backend/domain/workflow/internal/canvas/convert/type_convert.go b/backend/domain/workflow/internal/canvas/convert/type_convert.go new file mode 100644 index 00000000..cda4f596 --- /dev/null +++ b/backend/domain/workflow/internal/canvas/convert/type_convert.go @@ -0,0 +1,663 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package convert + +import ( + "fmt" + "strconv" + "strings" + + einoCompose "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +func CanvasVariableToTypeInfo(v *vo.Variable) (*vo.TypeInfo, error) { + tInfo := &vo.TypeInfo{ + Required: v.Required, + Desc: v.Description, + } + + switch v.Type { + case vo.VariableTypeString: + switch v.AssistType { + case vo.AssistTypeTime: + tInfo.Type = vo.DataTypeTime + case vo.AssistTypeNotSet: + tInfo.Type = vo.DataTypeString + default: + fileType, ok := assistTypeToFileType(v.AssistType) + if ok { + tInfo.Type = vo.DataTypeFile + tInfo.FileType = &fileType + } else { + return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType) + } + } + case vo.VariableTypeInteger: + tInfo.Type = vo.DataTypeInteger + case vo.VariableTypeFloat: + tInfo.Type = vo.DataTypeNumber + case vo.VariableTypeBoolean: + tInfo.Type = vo.DataTypeBoolean + case vo.VariableTypeObject: + tInfo.Type = vo.DataTypeObject + tInfo.Properties = make(map[string]*vo.TypeInfo) + if v.Schema != nil { + for _, subVAny := range v.Schema.([]any) { + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := CanvasVariableToTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.Properties[subV.Name] = subTInfo + } + } + case vo.VariableTypeList: + tInfo.Type = vo.DataTypeArray + subVAny := v.Schema + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := CanvasVariableToTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.ElemTypeInfo = subTInfo + + default: + return nil, fmt.Errorf("unsupported variable type: %s", v.Type) + } + + return tInfo, nil +} + +func CanvasBlockInputToTypeInfo(b *vo.BlockInput) (tInfo *vo.TypeInfo, err error) { + defer func() { + if err != nil { + err = vo.WrapIfNeeded(errno.ErrSchemaConversionFail, err) + } + }() + + tInfo = &vo.TypeInfo{} + + if b == nil { + return tInfo, nil + } + + switch b.Type { + case vo.VariableTypeString: + switch b.AssistType { + case vo.AssistTypeTime: + tInfo.Type = vo.DataTypeTime + case vo.AssistTypeNotSet: + tInfo.Type = vo.DataTypeString + default: + fileType, ok := assistTypeToFileType(b.AssistType) + if ok { + tInfo.Type = vo.DataTypeFile + tInfo.FileType = &fileType + } else { + return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType) + } + } + case vo.VariableTypeInteger: + tInfo.Type = vo.DataTypeInteger + case vo.VariableTypeFloat: + tInfo.Type = vo.DataTypeNumber + case vo.VariableTypeBoolean: + tInfo.Type = vo.DataTypeBoolean + case vo.VariableTypeObject: + tInfo.Type = vo.DataTypeObject + tInfo.Properties = make(map[string]*vo.TypeInfo) + if b.Schema != nil { + for _, subVAny := range b.Schema.([]any) { + if b.Value.Type == vo.BlockInputValueTypeRef { + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := CanvasVariableToTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.Properties[subV.Name] = subTInfo + } else if b.Value.Type == vo.BlockInputValueTypeObjectRef { + subV, err := parseParam(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := CanvasBlockInputToTypeInfo(subV.Input) + if err != nil { + return nil, err + } + tInfo.Properties[subV.Name] = subTInfo + } + } + } + case vo.VariableTypeList: + tInfo.Type = vo.DataTypeArray + subVAny := b.Schema + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := CanvasVariableToTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.ElemTypeInfo = subTInfo + default: + return nil, fmt.Errorf("unsupported variable type: %s", b.Type) + } + + return tInfo, nil +} + +func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) (sources []*vo.FieldInfo, err error) { + value := b.Value + if value == nil { + return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type) + } + + switch value.Type { + case vo.BlockInputValueTypeObjectRef: + sc := b.Schema + if sc == nil { + return nil, fmt.Errorf("input %v has no schema, type= %s", path, b.Type) + } + + paramList, ok := sc.([]any) + if !ok { + return nil, fmt.Errorf("input %v schema not []any, type= %T", path, sc) + } + + for i := range paramList { + paramAny := paramList[i] + param, err := parseParam(paramAny) + if err != nil { + return nil, err + } + + copied := make([]string, len(path)) + copy(copied, path) + subFieldInfo, err := CanvasBlockInputToFieldInfo(param.Input, append(copied, param.Name), parentNode) + if err != nil { + return nil, err + } + sources = append(sources, subFieldInfo...) + } + return sources, nil + case vo.BlockInputValueTypeLiteral: + content := value.Content + if content == nil { + return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type) + } + + switch b.Type { + case vo.VariableTypeObject: + m := make(map[string]any) + if err = sonic.UnmarshalString(content.(string), &m); err != nil { + return nil, err + } + content = m + case vo.VariableTypeList: + l := make([]any, 0) + if err = sonic.UnmarshalString(content.(string), &l); err != nil { + return nil, err + } + content = l + case vo.VariableTypeInteger: + switch content.(type) { + case string: + content, err = strconv.ParseInt(content.(string), 10, 64) + if err != nil { + return nil, err + } + case int64: + content = content.(int64) + case float64: + content = int64(content.(float64)) + default: + return nil, fmt.Errorf("unsupported variable type fot integer: %s", b.Type) + } + case vo.VariableTypeFloat: + switch content.(type) { + case string: + content, err = strconv.ParseFloat(content.(string), 64) + if err != nil { + return nil, err + } + case int64: + content = float64(content.(int64)) + case float64: + content = content.(float64) + default: + return nil, fmt.Errorf("unsupported variable type for float: %s", b.Type) + } + case vo.VariableTypeBoolean: + switch content.(type) { + case string: + content, err = strconv.ParseBool(content.(string)) + if err != nil { + return nil, err + } + case bool: + content = content.(bool) + default: + return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type) + } + default: + } + return []*vo.FieldInfo{ + { + Path: path, + Source: vo.FieldSource{ + Val: content, + }, + }, + }, nil + case vo.BlockInputValueTypeRef: + content := value.Content + if content == nil { + return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type) + } + + ref, err := parseBlockInputRef(content) + if err != nil { + return nil, err + } + + fieldSource, err := CanvasBlockInputRefToFieldSource(ref) + if err != nil { + return nil, err + } + + if parentNode != nil { + if fieldSource.Ref != nil && len(fieldSource.Ref.FromNodeKey) > 0 && fieldSource.Ref.FromNodeKey == vo.NodeKey(parentNode.ID) { + varRoot := fieldSource.Ref.FromPath[0] + if parentNode.Data.Inputs.Loop != nil { + for _, p := range parentNode.Data.Inputs.VariableParameters { + if p.Name == varRoot { + fieldSource.Ref.FromNodeKey = "" + pi := vo.ParentIntermediate + fieldSource.Ref.VariableType = &pi + } + } + } + } + } + + return []*vo.FieldInfo{ + { + Path: path, + Source: *fieldSource, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported value type: %s for blockInput type= %s", value.Type, b.Type) + } +} + +func parseBlockInputRef(content any) (*vo.BlockInputReference, error) { + if bi, ok := content.(*vo.BlockInputReference); ok { + return bi, nil + } + + m, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid content type: %T when parse BlockInputRef", content) + } + + marshaled, err := sonic.Marshal(m) + if err != nil { + return nil, err + } + + p := &vo.BlockInputReference{} + if err := sonic.Unmarshal(marshaled, p); err != nil { + return nil, err + } + + return p, nil +} + +func parseParam(v any) (*vo.Param, error) { + if pa, ok := v.(*vo.Param); ok { + return pa, nil + } + + m, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid content type: %T when parse Param", v) + } + + marshaled, err := sonic.Marshal(m) + if err != nil { + return nil, err + } + + p := &vo.Param{} + if err := sonic.Unmarshal(marshaled, p); err != nil { + return nil, err + } + + return p, nil +} + +func CanvasBlockInputRefToFieldSource(r *vo.BlockInputReference) (*vo.FieldSource, error) { + switch r.Source { + case vo.RefSourceTypeBlockOutput: + if len(r.BlockID) == 0 { + return nil, fmt.Errorf("invalid BlockInputReference = %+v, BlockID is empty when source is block output", r) + } + + parts := strings.Split(r.Name, ".") // an empty r.Name signals an all-to-all mapping + return &vo.FieldSource{ + Ref: &vo.Reference{ + FromNodeKey: vo.NodeKey(r.BlockID), + FromPath: parts, + }, + }, nil + case vo.RefSourceTypeGlobalApp, vo.RefSourceTypeGlobalSystem, vo.RefSourceTypeGlobalUser: + if len(r.Path) == 0 { + return nil, fmt.Errorf("invalid BlockInputReference = %+v, Path is empty when source is variables", r) + } + + var varType vo.GlobalVarType + switch r.Source { + case vo.RefSourceTypeGlobalApp: + varType = vo.GlobalAPP + case vo.RefSourceTypeGlobalSystem: + varType = vo.GlobalSystem + case vo.RefSourceTypeGlobalUser: + varType = vo.GlobalUser + default: + return nil, fmt.Errorf("invalid BlockInputReference = %+v, Source is invalid", r) + } + + return &vo.FieldSource{ + Ref: &vo.Reference{ + VariableType: &varType, + FromPath: r.Path, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported ref source type: %s", r.Source) + } +} + +func assistTypeToFileType(a vo.AssistType) (vo.FileSubType, bool) { + switch a { + case vo.AssistTypeNotSet: + return "", false + case vo.AssistTypeTime: + return "", false + case vo.AssistTypeImage: + return vo.FileTypeImage, true + case vo.AssistTypeAudio: + return vo.FileTypeAudio, true + case vo.AssistTypeVideo: + return vo.FileTypeVideo, true + case vo.AssistTypeDefault: + return vo.FileTypeDefault, true + case vo.AssistTypeDoc: + return vo.FileTypeDocument, true + case vo.AssistTypeExcel: + return vo.FileTypeExcel, true + case vo.AssistTypeCode: + return vo.FileTypeCode, true + case vo.AssistTypePPT: + return vo.FileTypePPT, true + case vo.AssistTypeTXT: + return vo.FileTypeTxt, true + case vo.AssistTypeSvg: + return vo.FileTypeSVG, true + case vo.AssistTypeVoice: + return vo.FileTypeVoice, true + case vo.AssistTypeZip: + return vo.FileTypeZip, true + default: + panic("impossible") + } +} + +func SetInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error { + if n.Data.Inputs == nil { + return nil + } + + inputParams := n.Data.Inputs.InputParameters + if len(inputParams) == 0 { + return nil + } + + for _, param := range inputParams { + name := param.Name + tInfo, err := CanvasBlockInputToTypeInfo(param.Input) + if err != nil { + return err + } + + ns.SetInputType(name, tInfo) + + sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent()) + if err != nil { + return err + } + + ns.AddInputSource(sources...) + } + + return nil +} + +func SetOutputTypesForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error { + for _, vAny := range n.Data.Outputs { + v, err := vo.ParseVariable(vAny) + if err != nil { + return err + } + + tInfo, err := CanvasVariableToTypeInfo(v) + if err != nil { + return err + } + if v.ReadOnly { + if v.Name == "errorBody" { // reserved output fields when exception happens + continue + } + } + ns.SetOutputType(v.Name, tInfo) + } + + return nil +} + +func SetOutputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error { + for _, vAny := range n.Data.Outputs { + param, err := parseParam(vAny) + if err != nil { + return err + } + name := param.Name + tInfo, err := CanvasBlockInputToTypeInfo(param.Input) + if err != nil { + return err + } + + ns.SetOutputType(name, tInfo) + + sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent()) + if err != nil { + return err + } + + ns.AddOutputSource(sources...) + } + + return nil +} + +func BlockInputToNamedTypeInfo(name string, b *vo.BlockInput) (*vo.NamedTypeInfo, error) { + tInfo := &vo.NamedTypeInfo{ + Name: name, + } + if b == nil { + return tInfo, nil + } + switch b.Type { + case vo.VariableTypeString: + switch b.AssistType { + case vo.AssistTypeTime: + tInfo.Type = vo.DataTypeTime + case vo.AssistTypeNotSet: + tInfo.Type = vo.DataTypeString + default: + fileType, ok := assistTypeToFileType(b.AssistType) + if ok { + tInfo.Type = vo.DataTypeFile + tInfo.FileType = &fileType + } else { + return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType) + } + } + case vo.VariableTypeInteger: + tInfo.Type = vo.DataTypeInteger + case vo.VariableTypeFloat: + tInfo.Type = vo.DataTypeNumber + case vo.VariableTypeBoolean: + tInfo.Type = vo.DataTypeBoolean + case vo.VariableTypeObject: + tInfo.Type = vo.DataTypeObject + if b.Schema != nil { + tInfo.Properties = make([]*vo.NamedTypeInfo, 0, len(b.Schema.([]any))) + for _, subVAny := range b.Schema.([]any) { + if b.Value.Type == vo.BlockInputValueTypeRef { + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subNInfo, err := VariableToNamedTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.Properties = append(tInfo.Properties, subNInfo) + } else if b.Value.Type == vo.BlockInputValueTypeObjectRef { + subV, err := parseParam(subVAny) + if err != nil { + return nil, err + } + subNInfo, err := BlockInputToNamedTypeInfo(subV.Name, subV.Input) + if err != nil { + return nil, err + } + tInfo.Properties = append(tInfo.Properties, subNInfo) + } + } + } + case vo.VariableTypeList: + tInfo.Type = vo.DataTypeArray + subVAny := b.Schema + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subNInfo, err := VariableToNamedTypeInfo(subV) + if err != nil { + return nil, err + } + tInfo.ElemTypeInfo = subNInfo + default: + return nil, fmt.Errorf("unsupported variable type: %s", b.Type) + } + + return tInfo, nil +} + +func VariableToNamedTypeInfo(v *vo.Variable) (*vo.NamedTypeInfo, error) { + nInfo := &vo.NamedTypeInfo{ + Required: v.Required, + Name: v.Name, + Desc: v.Description, + } + + switch v.Type { + case vo.VariableTypeString: + switch v.AssistType { + case vo.AssistTypeTime: + nInfo.Type = vo.DataTypeTime + case vo.AssistTypeNotSet: + nInfo.Type = vo.DataTypeString + default: + fileType, ok := assistTypeToFileType(v.AssistType) + if ok { + nInfo.Type = vo.DataTypeFile + nInfo.FileType = &fileType + } else { + return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType) + } + } + case vo.VariableTypeInteger: + nInfo.Type = vo.DataTypeInteger + case vo.VariableTypeFloat: + nInfo.Type = vo.DataTypeNumber + case vo.VariableTypeBoolean: + nInfo.Type = vo.DataTypeBoolean + case vo.VariableTypeObject: + nInfo.Type = vo.DataTypeObject + if v.Schema != nil { + nInfo.Properties = make([]*vo.NamedTypeInfo, 0) + for _, subVAny := range v.Schema.([]any) { + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := VariableToNamedTypeInfo(subV) + if err != nil { + return nil, err + } + nInfo.Properties = append(nInfo.Properties, subTInfo) + + } + } + case vo.VariableTypeList: + nInfo.Type = vo.DataTypeArray + subVAny := v.Schema + subV, err := vo.ParseVariable(subVAny) + if err != nil { + return nil, err + } + subTInfo, err := VariableToNamedTypeInfo(subV) + if err != nil { + return nil, err + } + nInfo.ElemTypeInfo = subTInfo + + default: + return nil, fmt.Errorf("unsupported variable type: %s", v.Type) + } + + return nInfo, nil +} diff --git a/backend/domain/workflow/internal/canvas/validate/canvas_validate.go b/backend/domain/workflow/internal/canvas/validate/canvas_validate.go index 1c31b6cc..5b028b6a 100644 --- a/backend/domain/workflow/internal/canvas/validate/canvas_validate.go +++ b/backend/domain/workflow/internal/canvas/validate/canvas_validate.go @@ -24,8 +24,11 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -123,7 +126,7 @@ func (cv *CanvasValidator) ValidateConnections(ctx context.Context) (issues []*I return issues, nil } -func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issue, err error) { +func (cv *CanvasValidator) CheckRefVariable(_ context.Context) (issues []*Issue, err error) { issues = make([]*Issue, 0) var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error { @@ -237,7 +240,7 @@ func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issu return issues, nil } -func (cv *CanvasValidator) ValidateNestedFlows(ctx context.Context) (issues []*Issue, err error) { +func (cv *CanvasValidator) ValidateNestedFlows(_ context.Context) (issues []*Issue, err error) { issues = make([]*Issue, 0) for nodeID, node := range cv.reachability.reachableNodes { if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 { @@ -265,13 +268,13 @@ func (cv *CanvasValidator) CheckGlobalVariables(ctx context.Context) (issues []* nVars := make([]*nodeVars, 0) for _, node := range cv.cfg.Canvas.Nodes { - if node.Type == vo.BlockTypeBotComment { + if node.Type == entity.NodeTypeComment.IDStr() { continue } - if node.Type == vo.BlockTypeBotAssignVariable { + if node.Type == entity.NodeTypeVariableAssigner.IDStr() { v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)} for _, p := range node.Data.Inputs.InputParameters { - v.vars[p.Name], err = adaptor.CanvasBlockInputToTypeInfo(p.Left) + v.vars[p.Name], err = convert.CanvasBlockInputToTypeInfo(p.Left) if err != nil { return nil, err } @@ -338,7 +341,7 @@ func (cv *CanvasValidator) CheckSubWorkFlowTerminatePlanType(ctx context.Context var collectSubWorkFlowNodes func(nodes []*vo.Node) collectSubWorkFlowNodes = func(nodes []*vo.Node) { for _, n := range nodes { - if n.Type == vo.BlockTypeBotSubWorkflow { + if n.Type == entity.NodeTypeSubWorkflow.IDStr() { subWfMap = append(subWfMap, n) wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64) if err != nil { @@ -465,62 +468,28 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er selectorPorts := make(map[string]map[string]bool) for nodeID, node := range nodeMap { - switch node.Type { - case vo.BlockTypeCondition: - branches := node.Data.Inputs.Branches + if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil && + node.Data.Inputs.SettingOnError.ProcessType != nil && + *node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch { if _, exists := selectorPorts[nodeID]; !exists { selectorPorts[nodeID] = make(map[string]bool) } - selectorPorts[nodeID]["false"] = true - for index := range branches { - if index == 0 { - selectorPorts[nodeID]["true"] = true - } else { - selectorPorts[nodeID][fmt.Sprintf("true_%v", index)] = true - } - } - case vo.BlockTypeBotIntent: - intents := node.Data.Inputs.Intents - if _, exists := selectorPorts[nodeID]; !exists { - selectorPorts[nodeID] = make(map[string]bool) - } - for index := range intents { - selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true - } - selectorPorts[nodeID]["default"] = true - if node.Data.Inputs.SettingOnError != nil && node.Data.Inputs.SettingOnError.ProcessType != nil && - *node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch { - selectorPorts[nodeID]["branch_error"] = true - } - case vo.BlockTypeQuestion: - if node.Data.Inputs.QA.AnswerType == vo.QAAnswerTypeOption { - if _, exists := selectorPorts[nodeID]; !exists { - selectorPorts[nodeID] = make(map[string]bool) - } - if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic { - for index := range node.Data.Inputs.QA.Options { - selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true - } - } - - if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic { - selectorPorts[nodeID][fmt.Sprintf("branch_%v", 0)] = true - } - } - default: - if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil && - node.Data.Inputs.SettingOnError.ProcessType != nil && - *node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch { - if _, exists := selectorPorts[nodeID]; !exists { - selectorPorts[nodeID] = make(map[string]bool) - } - selectorPorts[nodeID]["branch_error"] = true - selectorPorts[nodeID]["default"] = true - } else { - outDegree[node.ID] = 0 - } + selectorPorts[nodeID][schema.PortBranchError] = true + selectorPorts[nodeID][schema.PortDefault] = true } + ba, ok := nodes.GetBranchAdaptor(entity.IDStrToNodeType(node.Type)) + if ok { + expects := ba.ExpectPorts(ctx, node) + if len(expects) > 0 { + if _, exists := selectorPorts[nodeID]; !exists { + selectorPorts[nodeID] = make(map[string]bool) + } + for _, e := range expects { + selectorPorts[nodeID][e] = true + } + } + } } for _, edge := range c.Edges { @@ -544,8 +513,8 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er for nodeID, node := range nodeMap { nodeName := node.Data.Meta.Title - switch node.Type { - case vo.BlockTypeBotStart: + switch et := entity.IDStrToNodeType(node.Type); et { + case entity.NodeTypeEntry: if outDegree[nodeID] == 0 { issues = append(issues, &Issue{ NodeErr: &NodeErr{ @@ -555,13 +524,9 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er Message: `node "start" not connected`, }) } - case vo.BlockTypeBotEnd: + case entity.NodeTypeExit: default: if ports, isSelector := selectorPorts[nodeID]; isSelector { - selectorIssues := &Issue{NodeErr: &NodeErr{ - NodeID: node.ID, - NodeName: nodeName, - }} message := "" for port := range ports { if portOutDegree[nodeID][port] == 0 { @@ -569,12 +534,15 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er } } if len(message) > 0 { - selectorIssues.Message = message + selectorIssues := &Issue{NodeErr: &NodeErr{ + NodeID: node.ID, + NodeName: nodeName, + }, Message: message} issues = append(issues, selectorIssues) } } else { // Break, continue without checking out degrees - if node.Type == vo.BlockTypeBotBreak || node.Type == vo.BlockTypeBotContinue { + if et == entity.NodeTypeBreak || et == entity.NodeTypeContinue { continue } if outDegree[nodeID] == 0 { @@ -585,7 +553,6 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er }, Message: fmt.Sprintf(`node "%v" not connected`, nodeName), }) - } } } @@ -602,7 +569,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) { return nil, err } - startNode, endNode, err := findStartAndEndNodes(c.Nodes) + startNode, _, err := findStartAndEndNodes(c.Nodes) if err != nil { return nil, err } @@ -612,7 +579,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) { edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID) } - reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode, endNode) + reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode) if err != nil { return nil, err } @@ -635,12 +602,12 @@ func processNestedReachability(c *vo.Canvas, r *reachability) error { Nodes: append([]*vo.Node{ { ID: node.ID, - Type: vo.BlockTypeBotStart, + Type: entity.NodeTypeEntry.IDStr(), Data: node.Data, }, { ID: node.ID, - Type: vo.BlockTypeBotEnd, + Type: entity.NodeTypeExit.IDStr(), }, }, node.Blocks...), Edges: node.Edges, @@ -663,9 +630,9 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) { for _, node := range nodes { switch node.Type { - case vo.BlockTypeBotStart: + case entity.NodeTypeEntry.IDStr(): startNode = node - case vo.BlockTypeBotEnd: + case entity.NodeTypeExit.IDStr(): endNode = node } } @@ -680,7 +647,7 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) { return startNode, endNode, nil } -func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node, endNode *vo.Node) (map[string]*vo.Node, error) { +func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node) (map[string]*vo.Node, error) { result := make(map[string]*vo.Node) result[startNode.ID] = startNode diff --git a/backend/domain/workflow/internal/compose/branch.go b/backend/domain/workflow/internal/compose/branch.go deleted file mode 100644 index 90e8f32b..00000000 --- a/backend/domain/workflow/internal/compose/branch.go +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package compose - -import ( - "context" - "errors" - "fmt" - - "github.com/cloudwego/eino/compose" - "github.com/spf13/cast" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" -) - -func (s *NodeSchema) OutputPortCount() (int, bool) { - var hasExceptionPort bool - if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil && - *s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch { - hasExceptionPort = true - } - - switch s.Type { - case entity.NodeTypeSelector: - return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort - case entity.NodeTypeQuestionAnswer: - if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices { - if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices { - return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort - } else { - return 2, hasExceptionPort - } - } - return 1, hasExceptionPort - case entity.NodeTypeIntentDetector: - intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any)) - return len(intents) + 1, hasExceptionPort - default: - return 1, hasExceptionPort - } -} - -type BranchMapping struct { - Normal []map[string]bool - Exception map[string]bool -} - -const ( - DefaultBranch = "default" - BranchFmt = "branch_%d" -) - -func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) { - if bMapping == nil { - return nil, errors.New("no branch mapping") - } - - endNodes := make(map[string]bool) - for i := range bMapping.Normal { - for k := range bMapping.Normal[i] { - endNodes[k] = true - } - } - - if bMapping.Exception != nil { - for k := range bMapping.Exception { - endNodes[k] = true - } - } - - switch s.Type { - case entity.NodeTypeSelector: - condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { - choice := in[selector.SelectKey].(int) - if choice < 0 || choice > len(bMapping.Normal) { - return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice) - } - - choices := make(map[string]bool, len((bMapping.Normal)[choice])) - for k := range (bMapping.Normal)[choice] { - choices[k] = true - } - - return choices, nil - } - return compose.NewGraphMultiBranch(condition, endNodes), nil - case entity.NodeTypeQuestionAnswer: - conf := s.Configs.(map[string]any) - if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices { - choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf) - condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { - optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey}) - if !ok { - return nil, fmt.Errorf("failed to take option id from input map: %v", in) - } - - if optionID.(string) == "other" { - return (bMapping.Normal)[len(bMapping.Normal)-1], nil - } - - if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0 - return (bMapping.Normal)[0], nil - } - - optionIDInt, ok := qa.AlphabetToInt(optionID.(string)) - if !ok { - return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID) - } - - if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) { - return nil, fmt.Errorf("failed to take option id from input map: %v", in) - } - - return (bMapping.Normal)[optionIDInt], nil - } - return compose.NewGraphMultiBranch(condition, endNodes), nil - } - return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key) - - case entity.NodeTypeIntentDetector: - condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { - isSuccess, ok := in["isSuccess"] - if ok && isSuccess != nil && !isSuccess.(bool) { - return bMapping.Exception, nil - } - - classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"}) - if !ok { - return nil, fmt.Errorf("failed to take classification id from input map: %v", in) - } - - // Intent detector the node default branch uses classificationId=0. But currently scene, the implementation uses default as the last element of the array. - // Therefore, when classificationId=0, it needs to be converted into the node corresponding to the last index of the array. - // Other options also need to reduce the index by 1. - id, err := cast.ToInt64E(classificationId) - if err != nil { - return nil, err - } - realID := id - 1 - - if realID >= int64(len(bMapping.Normal)) { - return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId) - } - - if realID < 0 { - realID = int64(len(bMapping.Normal)) - 1 - } - - return (bMapping.Normal)[realID], nil - } - return compose.NewGraphMultiBranch(condition, endNodes), nil - default: - condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { - isSuccess, ok := in["isSuccess"] - if ok && isSuccess != nil && !isSuccess.(bool) { - return bMapping.Exception, nil - } - - return (bMapping.Normal)[0], nil - } - return compose.NewGraphMultiBranch(condition, endNodes), nil - } -} diff --git a/backend/domain/workflow/internal/compose/callbacks.go b/backend/domain/workflow/internal/compose/callbacks.go deleted file mode 100644 index 0ad8d74f..00000000 --- a/backend/domain/workflow/internal/compose/callbacks.go +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package compose - -import ( - "context" - "fmt" - "strconv" - "strings" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" -) - -type selectorCallbackField struct { - Key string `json:"key"` - Type vo.DataType `json:"type"` - Value any `json:"value"` -} - -type selectorCondition struct { - Left selectorCallbackField `json:"left"` - Operator vo.OperatorType `json:"operator"` - Right *selectorCallbackField `json:"right,omitempty"` -} - -type selectorBranch struct { - Conditions []*selectorCondition `json:"conditions"` - Logic vo.LogicType `json:"logic"` - Name string `json:"name"` -} - -func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) { - return func(_ context.Context, in map[string]any) (map[string]any, error) { - config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs) - count := len(config) - - output := make([]*selectorBranch, count) - - for _, source := range s.InputSources { - targetPath := source.Path - if len(targetPath) == 2 { - indexStr := targetPath[0] - index, err := strconv.Atoi(indexStr) - if err != nil { - return nil, err - } - - branch := output[index] - if branch == nil { - output[index] = &selectorBranch{ - Conditions: []*selectorCondition{ - { - Operator: config[index].Single.ToCanvasOperatorType(), - }, - }, - Logic: selector.ClauseRelationAND.ToVOLogicType(), - } - } - - if targetPath[1] == selector.LeftKey { - leftV, ok := nodes.TakeMapValue(in, targetPath) - if !ok { - return nil, fmt.Errorf("failed to take left value of %s", targetPath) - } - if source.Source.Ref.VariableType != nil { - if *source.Source.Ref.VariableType == vo.ParentIntermediate { - parentNodeKey, ok := sc.Hierarchy[s.Key] - if !ok { - return nil, fmt.Errorf("failed to find parent node key of %s", s.Key) - } - parentNode := sc.GetNode(parentNodeKey) - output[index].Conditions[0].Left = selectorCallbackField{ - Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."), - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, - Value: leftV, - } - } else { - output[index].Conditions[0].Left = selectorCallbackField{ - Key: "", - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, - Value: leftV, - } - } - } else { - output[index].Conditions[0].Left = selectorCallbackField{ - Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."), - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, - Value: leftV, - } - } - } else if targetPath[1] == selector.RightKey { - rightV, ok := nodes.TakeMapValue(in, targetPath) - if !ok { - return nil, fmt.Errorf("failed to take right value of %s", targetPath) - } - output[index].Conditions[0].Right = &selectorCallbackField{ - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, - Value: rightV, - } - } - } else if len(targetPath) == 3 { - indexStr := targetPath[0] - index, err := strconv.Atoi(indexStr) - if err != nil { - return nil, err - } - - multi := config[index].Multi - - branch := output[index] - if branch == nil { - output[index] = &selectorBranch{ - Conditions: make([]*selectorCondition, len(multi.Clauses)), - Logic: multi.Relation.ToVOLogicType(), - } - } - - clauseIndexStr := targetPath[1] - clauseIndex, err := strconv.Atoi(clauseIndexStr) - if err != nil { - return nil, err - } - - clause := multi.Clauses[clauseIndex] - - if output[index].Conditions[clauseIndex] == nil { - output[index].Conditions[clauseIndex] = &selectorCondition{ - Operator: clause.ToCanvasOperatorType(), - } - } - - if targetPath[2] == selector.LeftKey { - leftV, ok := nodes.TakeMapValue(in, targetPath) - if !ok { - return nil, fmt.Errorf("failed to take left value of %s", targetPath) - } - if source.Source.Ref.VariableType != nil { - if *source.Source.Ref.VariableType == vo.ParentIntermediate { - parentNodeKey, ok := sc.Hierarchy[s.Key] - if !ok { - return nil, fmt.Errorf("failed to find parent node key of %s", s.Key) - } - parentNode := sc.GetNode(parentNodeKey) - output[index].Conditions[clauseIndex].Left = selectorCallbackField{ - Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."), - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, - Value: leftV, - } - } else { - output[index].Conditions[clauseIndex].Left = selectorCallbackField{ - Key: "", - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, - Value: leftV, - } - } - } else { - output[index].Conditions[clauseIndex].Left = selectorCallbackField{ - Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."), - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, - Value: leftV, - } - } - } else if targetPath[2] == selector.RightKey { - rightV, ok := nodes.TakeMapValue(in, targetPath) - if !ok { - return nil, fmt.Errorf("failed to take right value of %s", targetPath) - } - output[index].Conditions[clauseIndex].Right = &selectorCallbackField{ - Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, - Value: rightV, - } - } - } - } - - return map[string]any{"branches": output}, nil - } -} diff --git a/backend/domain/workflow/internal/compose/designate_option.go b/backend/domain/workflow/internal/compose/designate_option.go index fb039068..b53a011f 100644 --- a/backend/domain/workflow/internal/compose/designate_option.go +++ b/backend/domain/workflow/internal/compose/designate_option.go @@ -31,7 +31,9 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) @@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, rootHandler := execute.NewRootWorkflowHandler( wb, executeID, - workflowSC.requireCheckPoint, + workflowSC.RequireCheckpoint(), eventChan, resumedEvent, exeCfg, @@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, var nodeOpt einoCompose.Option if ns.Type == entity.NodeTypeExit { nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, - ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs))) + ptr.Of(ns.Configs.(*exit.Config).TerminatePlan)) } else if ns.Type != entity.NodeTypeLambda { nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil) } @@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, } } - if workflowSC.requireCheckPoint { + if workflowSC.RequireCheckpoint() { opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10))) } @@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, parentHandler *execute.WorkflowHandler, - ns *NodeSchema, + ns *schema2.NodeSchema, pathPrefix ...string) (opts []einoCompose.Option, err error) { var ( resumeEvent = r.interruptEvent @@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, var nodeOpt einoCompose.Option if subNS.Type == entity.NodeTypeExit { nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, - ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs))) + ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan)) } else { nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil) } @@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, return opts, nil } -func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event, +func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event, sw *schema.StreamWriter[*entity.Message]) ( opts []einoCompose.Option, err error) { // this is a LLM node. @@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan panic("impossible: llmToolCallbackOptions is called on a non-LLM node") } - fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs) + cfg := ns.Configs.(*llm.Config) + fcParams := cfg.FCParam if fcParams != nil { if fcParams.WorkflowFCParam != nil { // TODO: try to avoid getting the workflow tool all over again @@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan toolHandler := execute.NewToolHandler(eventChan, funcInfo) opt := einoCompose.WithCallbacks(toolHandler) - opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key)) + opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key)) opts = append(opts, opt) } } @@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan toolHandler := execute.NewToolHandler(eventChan, funcInfo) opt := einoCompose.WithCallbacks(toolHandler) - opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key)) + opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key)) opts = append(opts, opt) } } diff --git a/backend/domain/workflow/internal/compose/field_fill.go b/backend/domain/workflow/internal/compose/field_fill.go index 90371ac8..b25b1b08 100644 --- a/backend/domain/workflow/internal/compose/field_fill.go +++ b/backend/domain/workflow/internal/compose/field_fill.go @@ -25,12 +25,13 @@ import ( "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) // outputValueFiller will fill the output value with nil if the key is not present in the output map. // if a node emits stream as output, the node needs to handle these absent keys in stream themselves. -func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) { +func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) { if len(s.OutputTypes) == 0 { return func(ctx context.Context, output map[string]any) (map[string]any, error) { return output, nil @@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st // inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map. // if a node accepts stream as input, the node needs to handle these absent keys in stream themselves. -func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) { +func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) { if len(s.InputTypes) == 0 { return func(ctx context.Context, input map[string]any) (map[string]any, error) { return input, nil @@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri } } -func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context, +func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] { fn := func(ctx context.Context, i map[string]any) (map[string]any, error) { newI := make(map[string]any) diff --git a/backend/domain/workflow/internal/compose/field_fill_test.go b/backend/domain/workflow/internal/compose/field_fill_test.go index 642afb05..f6cc7eef 100644 --- a/backend/domain/workflow/internal/compose/field_fill_test.go +++ b/backend/domain/workflow/internal/compose/field_fill_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) func TestNodeSchema_OutputValueFiller(t *testing.T) { @@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &NodeSchema{ + s := &schema.NodeSchema{ OutputTypes: tt.fields.Outputs, } - got, err := s.outputValueFiller()(context.Background(), tt.fields.In) + got, err := outputValueFiller(s)(context.Background(), tt.fields.In) if len(tt.wantErr) > 0 { assert.Error(t, err) diff --git a/backend/domain/workflow/internal/compose/node_builder.go b/backend/domain/workflow/internal/compose/node_builder.go new file mode 100644 index 00000000..9f36d5f4 --- /dev/null +++ b/backend/domain/workflow/internal/compose/node_builder.go @@ -0,0 +1,118 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + "github.com/coze-dev/coze-studio/backend/pkg/errorx" + "github.com/coze-dev/coze-studio/backend/pkg/safego" + "github.com/coze-dev/coze-studio/backend/types/errno" +) + +type Node struct { + Lambda *compose.Lambda +} + +// New instantiates the actual node type from NodeSchema. +func New(ctx context.Context, s *schema.NodeSchema, + inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node + sc *schema.WorkflowSchema, // the workflow this NodeSchema is in + deps *dependencyInfo, // the dependency for this node pre-calculated by workflow engine +) (_ *Node, err error) { + defer func() { + if panicErr := recover(); panicErr != nil { + err = safego.NewPanicErr(panicErr, debug.Stack()) + } + + if err != nil { + err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error())) + } + }() + + var fullSources map[string]*schema.SourceInfo + if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware { + if fullSources, err = GetFullSources(s, sc, deps); err != nil { + return nil, err + } + s.FullSources = fullSources + } + + // if NodeSchema's Configs implements NodeBuilder, will use it to build the node + nb, ok := s.Configs.(schema.NodeBuilder) + if ok { + opts := []schema.BuildOption{ + schema.WithWorkflowSchema(sc), + schema.WithInnerWorkflow(inner), + } + + // build the actual InvokableNode, etc. + n, err := nb.Build(ctx, s, opts...) + if err != nil { + return nil, err + } + + // wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda + return toNode(s, n), nil + } + + switch s.Type { + case entity.NodeTypeLambda: + if s.Lambda == nil { + return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda") + } + + return &Node{Lambda: s.Lambda}, nil + case entity.NodeTypeSubWorkflow: + subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint()) + if err != nil { + return nil, err + } + + return toNode(s, subWorkflow), nil + default: + panic(fmt.Sprintf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type)) + } +} + +func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) { + var opts []WorkflowOption + opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID)) + if requireCheckpoint { + opts = append(opts, WithParentRequireCheckpoint()) + } + if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 { + opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow)) + } + wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...) + if err != nil { + return nil, err + } + + return &subworkflow.SubWorkflow{ + Runner: wf.Runner, + }, nil +} diff --git a/backend/domain/workflow/internal/compose/node_runner.go b/backend/domain/workflow/internal/compose/node_runner.go index 57b606ed..ae47a15d 100644 --- a/backend/domain/workflow/internal/compose/node_runner.go +++ b/backend/domain/workflow/internal/compose/node_runner.go @@ -33,6 +33,8 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" @@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct { maxRetry int64 errProcessType vo.ErrorProcessType dataOnErr func(ctx context.Context) map[string]any - callbackEnabled bool preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) streamPreProcessors []func(ctx context.Context, @@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct { init []func(context.Context) (context.Context, error) i compose.Invoke[map[string]any, map[string]any, O] s compose.Stream[map[string]any, map[string]any, O] + c compose.Collect[map[string]any, map[string]any, O] t compose.Transform[map[string]any, map[string]any, O] } -func newNodeRunConfig[O any](ns *NodeSchema, +func newNodeRunConfig[O any](ns *schema2.NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], + c compose.Collect[map[string]any, map[string]any, O], t compose.Transform[map[string]any, map[string]any, O], opts *newNodeOptions) *nodeRunConfig[O] { meta := entity.NodeMetaByNodeType(ns.Type) @@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema, keyFinishedMarkerTrimmer(), } if meta.PreFillZero { - preProcessors = append(preProcessors, ns.inputValueFiller()) + preProcessors = append(preProcessors, inputValueFiller(ns)) } var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) if meta.PostFillNil { - postProcessors = append(postProcessors, ns.outputValueFiller()) + postProcessors = append(postProcessors, outputValueFiller(ns)) } streamPreProcessors := []func(ctx context.Context, @@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema, }, } if meta.PreFillZero { - streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller()) + streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns)) + } + + if meta.UseCtxCache { + opts.init = append([]func(ctx context.Context) (context.Context, error){ + func(ctx context.Context) (context.Context, error) { + return ctxcache.Init(ctx), nil + }, + }, opts.init...) } opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) { @@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema, maxRetry: maxRetry, errProcessType: errProcessType, dataOnErr: dataOnErr, - callbackEnabled: meta.CallbackEnabled, preProcessors: preProcessors, postProcessors: postProcessors, streamPreProcessors: streamPreProcessors, @@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema, init: opts.init, i: i, s: s, + c: c, t: t, } } -func newNodeRunConfigWOOpt(ns *NodeSchema, +func newNodeRunConfigWOOpt(ns *schema2.NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], s compose.StreamWOOpt[map[string]any, map[string]any], + c compose.CollectWOOpt[map[string]any, map[string]any], t compose.TransformWOOpts[map[string]any, map[string]any], opts *newNodeOptions) *nodeRunConfig[any] { var ( iWO compose.Invoke[map[string]any, map[string]any, any] sWO compose.Stream[map[string]any, map[string]any, any] + cWO compose.Collect[map[string]any, map[string]any, any] tWO compose.Transform[map[string]any, map[string]any, any] ) @@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema, } } + if c != nil { + cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) { + return c(ctx, in) + } + } + if t != nil { tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) { return t(ctx, input) } } - return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts) + return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts) } type newNodeOptions struct { @@ -180,57 +199,100 @@ type newNodeOptions struct { init []func(context.Context) (context.Context, error) } -type newNodeOption func(*newNodeOptions) +func toNode(ns *schema2.NodeSchema, r any) *Node { + iWOpt, _ := r.(nodes.InvokableNodeWOpt) + sWOpt, _ := r.(nodes.StreamableNodeWOpt) + cWOpt, _ := r.(nodes.CollectableNodeWOpt) + tWOpt, _ := r.(nodes.TransformableNodeWOpt) + iWOOpt, _ := r.(nodes.InvokableNode) + sWOOpt, _ := r.(nodes.StreamableNode) + cWOOpt, _ := r.(nodes.CollectableNode) + tWOOpt, _ := r.(nodes.TransformableNode) -func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption { - return func(opts *newNodeOptions) { - opts.callbackInputConverter = f + var wOpt, wOOpt bool + if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil { + wOpt = true } -} -func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption { - return func(opts *newNodeOptions) { - opts.callbackOutputConverter = f + if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil { + wOOpt = true + } + + if wOpt && wOOpt { + panic("a node's different streaming methods needs to be consistent: " + + "they should ALL have NodeOption or None should have them") + } + + if !wOpt && !wOOpt { + panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode") } -} -func withInit(f func(context.Context) (context.Context, error)) newNodeOption { - return func(opts *newNodeOptions) { - opts.init = append(opts.init, f) - } -} -func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node { options := &newNodeOptions{} - for _, opt := range opts { - opt(options) + ci, ok := r.(nodes.CallbackInputConverted) + if ok { + options.callbackInputConverter = ci.ToCallbackInput } - return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode() -} - -func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node { - options := &newNodeOptions{} - for _, opt := range opts { - opt(options) + co, ok := r.(nodes.CallbackOutputConverted) + if ok { + options.callbackOutputConverter = co.ToCallbackOutput } - return newNodeRunConfig(ns, i, nil, nil, options).toNode() -} - -func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], - t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node { - options := &newNodeOptions{} - for _, opt := range opts { - opt(options) + init, ok := r.(nodes.Initializer) + if ok { + options.init = append(options.init, init.Init) } - return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode() -} -func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node { - options := &newNodeOptions{} - for _, opt := range opts { - opt(options) + if wOpt { + var ( + i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption] + s compose.Stream[map[string]any, map[string]any, nodes.NodeOption] + c compose.Collect[map[string]any, map[string]any, nodes.NodeOption] + t compose.Transform[map[string]any, map[string]any, nodes.NodeOption] + ) + + if iWOpt != nil { + i = iWOpt.Invoke + } + + if sWOpt != nil { + s = sWOpt.Stream + } + + if cWOpt != nil { + c = cWOpt.Collect + } + + if tWOpt != nil { + t = tWOpt.Transform + } + + return newNodeRunConfig(ns, i, s, c, t, options).toNode() } - return newNodeRunConfig(ns, i, s, nil, options).toNode() + + var ( + i compose.InvokeWOOpt[map[string]any, map[string]any] + s compose.StreamWOOpt[map[string]any, map[string]any] + c compose.CollectWOOpt[map[string]any, map[string]any] + t compose.TransformWOOpts[map[string]any, map[string]any] + ) + + if iWOOpt != nil { + i = iWOOpt.Invoke + } + + if sWOOpt != nil { + s = sWOOpt.Stream + } + + if cWOOpt != nil { + c = cWOOpt.Collect + } + + if tWOOpt != nil { + t = tWOOpt.Transform + } + + return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode() } func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) { @@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema. func (nc *nodeRunConfig[O]) toNode() *Node { var opts []compose.LambdaOpt opts = append(opts, compose.WithLambdaType(string(nc.nodeType))) + opts = append(opts, compose.WithLambdaCallbackEnable(true)) - if nc.callbackEnabled { - opts = append(opts, compose.WithLambdaCallbackEnable(true)) - } l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...) if err != nil { panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err)) @@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C } func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) { - if !r.callbackEnabled { - return ctx, nil - } if r.callbackInputConverter != nil { convertedInput, err := r.callbackInputConverter(ctx, input) if err != nil { @@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) ( context.Context, *schema.StreamReader[map[string]any], error) { - if !r.callbackEnabled { - return ctx, input, nil - } - if r.callbackInputConverter != nil { copied := input.Copy(2) realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) { @@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade } func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error { - if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault { + if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData { output["isSuccess"] = true } - if !r.callbackEnabled { - return nil - } - if r.callbackOutputConverter != nil { convertedOutput, err := r.callbackOutputConverter(ctx, output) if err != nil { @@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) ( *schema.StreamReader[map[string]any], error) { - if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault { + if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData { flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}}) output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output}) } - if !r.callbackEnabled { - return output, nil - } - if r.callbackOutputConverter != nil { copied := output.Copy(2) realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) { @@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) { if r.interrupted { - if r.callbackEnabled { - _ = callbacks.OnError(ctx, err) - } + _ = callbacks.OnError(ctx, err) return nil, false } @@ -653,22 +696,20 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, msg := sErr.Msg() switch r.errProcessType { - case vo.ErrorProcessTypeDefault: + case vo.ErrorProcessTypeReturnDefaultData: d := r.dataOnErr(ctx) d["errorBody"] = map[string]any{ "errorMessage": msg, "errorCode": code, } d["isSuccess"] = false - if r.callbackEnabled { - sErr = sErr.ChangeErrLevel(vo.LevelWarn) - sOutput := &nodes.StructuredCallbackOutput{ - Output: d, - RawOutput: d, - Error: sErr, - } - _ = callbacks.OnEnd(ctx, sOutput) + sErr = sErr.ChangeErrLevel(vo.LevelWarn) + sOutput := &nodes.StructuredCallbackOutput{ + Output: d, + RawOutput: d, + Error: sErr, } + _ = callbacks.OnEnd(ctx, sOutput) return d, true case vo.ErrorProcessTypeExceptionBranch: s := make(map[string]any) @@ -677,20 +718,16 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, "errorCode": code, } s["isSuccess"] = false - if r.callbackEnabled { - sErr = sErr.ChangeErrLevel(vo.LevelWarn) - sOutput := &nodes.StructuredCallbackOutput{ - Output: s, - RawOutput: s, - Error: sErr, - } - _ = callbacks.OnEnd(ctx, sOutput) + sErr = sErr.ChangeErrLevel(vo.LevelWarn) + sOutput := &nodes.StructuredCallbackOutput{ + Output: s, + RawOutput: s, + Error: sErr, } + _ = callbacks.OnEnd(ctx, sOutput) return s, true default: - if r.callbackEnabled { - _ = callbacks.OnError(ctx, sErr) - } + _ = callbacks.OnError(ctx, sErr) return nil, false } } diff --git a/backend/domain/workflow/internal/compose/node_schema.go b/backend/domain/workflow/internal/compose/node_schema.go deleted file mode 100644 index 1d5a301b..00000000 --- a/backend/domain/workflow/internal/compose/node_schema.go +++ /dev/null @@ -1,580 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package compose - -import ( - "context" - "fmt" - "runtime/debug" - - "github.com/cloudwego/eino/compose" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" - "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" - "github.com/coze-dev/coze-studio/backend/pkg/errorx" - "github.com/coze-dev/coze-studio/backend/pkg/safego" - "github.com/coze-dev/coze-studio/backend/types/errno" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" -) - -type NodeSchema struct { - Key vo.NodeKey `json:"key"` - Name string `json:"name"` - - Type entity.NodeType `json:"type"` - - // Configs are node specific configurations with pre-defined config key and config value. - // Will not participate in request-time field mapping, nor as node's static values. - // In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them. - Configs any `json:"configs,omitempty"` - - InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"` - InputSources []*vo.FieldInfo `json:"input_sources,omitempty"` - - OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"` - OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop - - ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes - StreamConfigs *StreamConfig `json:"stream_configs,omitempty"` - - SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"` - SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"` - - Lambda *compose.Lambda // not serializable, used for internal test. -} - -type ExceptionConfig struct { - TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout - MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry - ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error - DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs -} - -type StreamConfig struct { - // whether this node has the ability to produce genuine streaming output. - // not include nodes that only passes stream down as they receives them - CanGeneratesStream bool `json:"can_generates_stream,omitempty"` - // whether this node prioritize streaming input over none-streaming input. - // not include nodes that can accept both and does not have preference. - RequireStreamingInput bool `json:"can_process_stream,omitempty"` -} - -type Node struct { - Lambda *compose.Lambda -} - -func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any], - sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) { - defer func() { - if panicErr := recover(); panicErr != nil { - err = safego.NewPanicErr(panicErr, debug.Stack()) - } - - if err != nil { - err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error())) - } - }() - - if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware { - if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil { - return nil, err - } - } - - switch s.Type { - case entity.NodeTypeLambda: - if s.Lambda == nil { - return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda") - } - - return &Node{Lambda: s.Lambda}, nil - case entity.NodeTypeLLM: - conf, err := s.ToLLMConfig(ctx) - if err != nil { - return nil, err - } - - l, err := llm.New(ctx, conf) - if err != nil { - return nil, err - } - - return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil - case entity.NodeTypeSelector: - conf := s.ToSelectorConfig() - - sl, err := selector.NewSelector(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil - case entity.NodeTypeBatch: - if inner == nil { - return nil, fmt.Errorf("inner workflow must not be nil when creating batch node") - } - - conf, err := s.ToBatchConfig(inner) - if err != nil { - return nil, err - } - - b, err := batch.NewBatch(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil - case entity.NodeTypeVariableAggregator: - conf, err := s.ToVariableAggregatorConfig() - if err != nil { - return nil, err - } - - va, err := variableaggregator.NewVariableAggregator(ctx, conf) - if err != nil { - return nil, err - } - - return invokableTransformableNode(s, va.Invoke, va.Transform, - withCallbackInputConverter(va.ToCallbackInput), - withCallbackOutputConverter(va.ToCallbackOutput), - withInit(va.Init)), nil - case entity.NodeTypeTextProcessor: - conf, err := s.ToTextProcessorConfig() - if err != nil { - return nil, err - } - - tp, err := textprocessor.NewTextProcessor(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, tp.Invoke), nil - case entity.NodeTypeHTTPRequester: - conf, err := s.ToHTTPRequesterConfig() - if err != nil { - return nil, err - } - - hr, err := httprequester.NewHTTPRequester(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil - case entity.NodeTypeContinue: - i := func(ctx context.Context, in map[string]any) (map[string]any, error) { - return map[string]any{}, nil - } - return invokableNode(s, i), nil - case entity.NodeTypeBreak: - b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{}) - if err != nil { - return nil, err - } - return invokableNode(s, b.DoBreak), nil - case entity.NodeTypeVariableAssigner: - handler := variable.GetVariableHandler() - - conf, err := s.ToVariableAssignerConfig(handler) - if err != nil { - return nil, err - } - va, err := variableassigner.NewVariableAssigner(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, va.Assign), nil - case entity.NodeTypeVariableAssignerWithinLoop: - conf, err := s.ToVariableAssignerInLoopConfig() - if err != nil { - return nil, err - } - - va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, va.Assign), nil - case entity.NodeTypeLoop: - conf, err := s.ToLoopConfig(inner) - if err != nil { - return nil, err - } - l, err := loop.NewLoop(ctx, conf) - if err != nil { - return nil, err - } - return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil - case entity.NodeTypeQuestionAnswer: - conf, err := s.ToQAConfig(ctx) - if err != nil { - return nil, err - } - qA, err := qa.NewQuestionAnswer(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil - case entity.NodeTypeInputReceiver: - conf, err := s.ToInputReceiverConfig() - if err != nil { - return nil, err - } - inputR, err := receiver.New(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil - case entity.NodeTypeOutputEmitter: - conf, err := s.ToOutputEmitterConfig(sc) - if err != nil { - return nil, err - } - - e, err := emitter.New(ctx, conf) - if err != nil { - return nil, err - } - - return invokableTransformableNode(s, e.Emit, e.EmitStream), nil - case entity.NodeTypeEntry: - conf, err := s.ToEntryConfig(ctx) - if err != nil { - return nil, err - } - e, err := entry.NewEntry(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, e.Invoke), nil - case entity.NodeTypeExit: - terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs) - if terminalPlan == vo.ReturnVariables { - i := func(ctx context.Context, in map[string]any) (map[string]any, error) { - if in == nil { - return map[string]any{}, nil - } - return in, nil - } - return invokableNode(s, i), nil - } - - conf, err := s.ToOutputEmitterConfig(sc) - if err != nil { - return nil, err - } - - e, err := emitter.New(ctx, conf) - if err != nil { - return nil, err - } - - return invokableTransformableNode(s, e.Emit, e.EmitStream), nil - case entity.NodeTypeDatabaseCustomSQL: - conf, err := s.ToDatabaseCustomSQLConfig() - if err != nil { - return nil, err - } - - sqlER, err := database.NewCustomSQL(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, sqlER.Execute), nil - case entity.NodeTypeDatabaseQuery: - conf, err := s.ToDatabaseQueryConfig() - if err != nil { - return nil, err - } - - query, err := database.NewQuery(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil - case entity.NodeTypeDatabaseInsert: - conf, err := s.ToDatabaseInsertConfig() - if err != nil { - return nil, err - } - - insert, err := database.NewInsert(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil - case entity.NodeTypeDatabaseUpdate: - conf, err := s.ToDatabaseUpdateConfig() - if err != nil { - return nil, err - } - update, err := database.NewUpdate(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil - case entity.NodeTypeDatabaseDelete: - conf, err := s.ToDatabaseDeleteConfig() - if err != nil { - return nil, err - } - del, err := database.NewDelete(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil - case entity.NodeTypeKnowledgeIndexer: - conf, err := s.ToKnowledgeIndexerConfig() - if err != nil { - return nil, err - } - w, err := knowledge.NewKnowledgeIndexer(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, w.Store), nil - case entity.NodeTypeKnowledgeRetriever: - conf, err := s.ToKnowledgeRetrieveConfig() - if err != nil { - return nil, err - } - r, err := knowledge.NewKnowledgeRetrieve(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.Retrieve), nil - case entity.NodeTypeKnowledgeDeleter: - conf, err := s.ToKnowledgeDeleterConfig() - if err != nil { - return nil, err - } - r, err := knowledge.NewKnowledgeDeleter(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.Delete), nil - case entity.NodeTypeCodeRunner: - conf, err := s.ToCodeRunnerConfig() - if err != nil { - return nil, err - } - r, err := code.NewCodeRunner(ctx, conf) - if err != nil { - return nil, err - } - initFn := func(ctx context.Context) (context.Context, error) { - return ctxcache.Init(ctx), nil - } - return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil - case entity.NodeTypePlugin: - conf, err := s.ToPluginConfig() - if err != nil { - return nil, err - } - r, err := plugin.NewPlugin(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.Invoke), nil - case entity.NodeTypeCreateConversation: - conf, err := s.ToCreateConversationConfig() - if err != nil { - return nil, err - } - r, err := conversation.NewCreateConversation(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.Create), nil - case entity.NodeTypeMessageList: - conf, err := s.ToMessageListConfig() - if err != nil { - return nil, err - } - r, err := conversation.NewMessageList(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.List), nil - case entity.NodeTypeClearMessage: - conf, err := s.ToClearMessageConfig() - if err != nil { - return nil, err - } - r, err := conversation.NewClearMessage(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, r.Clear), nil - case entity.NodeTypeIntentDetector: - conf, err := s.ToIntentDetectorConfig(ctx) - if err != nil { - return nil, err - } - r, err := intentdetector.NewIntentDetector(ctx, conf) - if err != nil { - return nil, err - } - - return invokableNode(s, r.Invoke), nil - case entity.NodeTypeSubWorkflow: - conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint) - if err != nil { - return nil, err - } - r, err := subworkflow.NewSubWorkflow(ctx, conf) - if err != nil { - return nil, err - } - return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil - case entity.NodeTypeJsonSerialization: - conf, err := s.ToJsonSerializationConfig() - if err != nil { - return nil, err - } - js, err := json.NewJsonSerializer(ctx, conf) - if err != nil { - return nil, err - } - return invokableNode(s, js.Invoke), nil - case entity.NodeTypeJsonDeserialization: - conf, err := s.ToJsonDeserializationConfig() - if err != nil { - return nil, err - } - jd, err := json.NewJsonDeserializer(ctx, conf) - if err != nil { - return nil, err - } - initFn := func(ctx context.Context) (context.Context, error) { - return ctxcache.Init(ctx), nil - } - return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil - default: - panic("not implemented") - } -} - -func (s *NodeSchema) IsEnableUserQuery() bool { - if s == nil { - return false - } - if s.Type != entity.NodeTypeEntry { - return false - } - - if len(s.OutputSources) == 0 { - return false - } - - for _, source := range s.OutputSources { - fieldPath := source.Path - if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") { - return true - } - } - - return false - -} - -func (s *NodeSchema) IsEnableChatHistory() bool { - if s == nil { - return false - } - - switch s.Type { - - case entity.NodeTypeLLM: - llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs) - return llmParam.EnableChatHistory - case entity.NodeTypeIntentDetector: - llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs) - return llmParam.EnableChatHistory - default: - return false - } - -} - -func (s *NodeSchema) IsRefGlobalVariable() bool { - for _, source := range s.InputSources { - if source.IsRefGlobalVariable() { - return true - } - } - for _, source := range s.OutputSources { - if source.IsRefGlobalVariable() { - return true - } - } - return false -} - -func (s *NodeSchema) requireCheckpoint() bool { - if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver { - return true - } - - if s.Type == entity.NodeTypeLLM { - fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs) - if fcParams != nil && fcParams.WorkflowFCParam != nil { - return true - } - } - - if s.Type == entity.NodeTypeSubWorkflow { - s.SubWorkflowSchema.Init() - if s.SubWorkflowSchema.requireCheckPoint { - return true - } - } - - return false -} diff --git a/backend/domain/workflow/internal/compose/state.go b/backend/domain/workflow/internal/compose/state.go index bd4b691a..2b85bcf0 100644 --- a/backend/domain/workflow/internal/compose/state.go +++ b/backend/domain/workflow/internal/compose/state.go @@ -32,8 +32,10 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -46,9 +48,9 @@ type State struct { InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"` NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"` - ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"` - SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"` - GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"` + ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"` + SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"` + GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"` ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"` LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"` @@ -71,8 +73,8 @@ func init() { _ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage") _ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state") _ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info") - _ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info") - _ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type") + _ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info") + _ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type") _ = compose.RegisterSerializableType[compose.FieldPath]("field_path") _ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic") _ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan") @@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int { return s.GroupChoices[nodeKey] } -func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) { +func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) { choices, ok := s.GroupChoices[nodeKey] if !ok { - return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey) + return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey) } choice, ok := choices[group] if !ok { - return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group) + return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group) } if choice == -1 { // this group picks none of the elements - return nodes.FieldNotStream, nil + return schema2.FieldNotStream, nil } sInfos, ok := s.SourceInfos[nodeKey] if !ok { - return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey) + return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey) } groupInfo, ok := sInfos[group] if !ok { - return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group) + return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group) } if groupInfo.SubSources == nil { - return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey) + return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey) } subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)] if !ok { - return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice) + return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice) } - if subInfo.FieldType != nodes.FieldMaybeStream { + if subInfo.FieldType != schema2.FieldMaybeStream { return subInfo.FieldType, nil } @@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0]) } -func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) { - result := make(map[string]nodes.FieldStreamType) +func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) { + result := make(map[string]schema2.FieldStreamType) choices, ok := s.GroupChoices[nodeKey] if !ok { return result, nil @@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] { InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent), NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState), ExecutedNodes: make(map[vo.NodeKey]bool), - SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo), + SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo), GroupChoices: make(map[vo.NodeKey]map[string]int), ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent), LLMToResumeData: make(map[vo.NodeKey]string), @@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] { } } -func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { +func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt { var ( handlers []compose.StatePreHandler[map[string]any, *State] streamHandlers []compose.StreamStatePreHandler[map[string]any, *State] @@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { } return in, nil }) - } else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop { + } else if entity.NodeMetaByNodeType(s.Type).IsComposite { handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) { if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later state.Inputs[s.Key] = in @@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { } if len(handlers) > 0 || !stream { - handlerForVars := s.statePreHandlerForVars() + handlerForVars := statePreHandlerForVars(s) if handlerForVars != nil { handlers = append(handlers, handlerForVars) } @@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { if s.Type == entity.NodeTypeVariableAggregator { streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) { - state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs) + state.SourceInfos[s.Key] = s.FullSources return in, nil }) } - handlerForVars := s.streamStatePreHandlerForVars() + handlerForVars := streamStatePreHandlerForVars(s) if handlerForVars != nil { streamHandlers = append(streamHandlers, handlerForVars) } @@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { return nil } -func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] { +func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] { // checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input var vars []*vo.FieldInfo for _, input := range s.InputSources { @@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string } } -func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] { +func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] { // checkout the node's inputs, if it has any variables, get the variables and merge them with the input var vars []*vo.FieldInfo for _, input := range s.InputSources { @@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle } } -func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] { +func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] { // if it does not have source info, do not add this pre handler if s.Configs == nil { return nil @@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter: return nil case entity.NodeTypeExit: - terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs) + terminatePlan := s.Configs.(*exit.Config).TerminatePlan if terminatePlan != vo.ReturnVariables { return nil } @@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState // all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation. } - sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs) + sourceInfo := s.FullSources if len(sourceInfo) == 0 { return nil } @@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState var ( anyStream bool - checker func(source *nodes.SourceInfo) bool + checker func(source *schema2.SourceInfo) bool ) - checker = func(source *nodes.SourceInfo) bool { - if source.FieldType != nodes.FieldNotStream { + checker = func(source *schema2.SourceInfo) bool { + if source.FieldType != schema2.FieldNotStream { return true } for _, subSource := range source.SubSources { @@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) { resolved := map[string]resolvedStreamSource{} - var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) - resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) { + var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) + resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) { if source.IsIntermediate { result = &resolvedStreamSource{ intermediate: true, @@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState } streamType := source.FieldType - if streamType == nodes.FieldMaybeStream { + if streamType == schema2.FieldMaybeStream { streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0]) if err != nil { return nil, err } } - if streamType == nodes.FieldNotStream { + if streamType == schema2.FieldNotStream { return nil, nil } @@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState } } -func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt { +func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt { var ( handlers []compose.StatePostHandler[map[string]any, *State] streamHandlers []compose.StreamStatePostHandler[map[string]any, *State] @@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt { return out, nil }) - forVars := s.streamStatePostHandlerForVars() + forVars := streamStatePostHandlerForVars(s) if forVars != nil { streamHandlers = append(streamHandlers, forVars) } @@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt { return out, nil }) - forVars := s.statePostHandlerForVars() + forVars := statePostHandlerForVars(s) if forVars != nil { handlers = append(handlers, forVars) } @@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt { return compose.WithStatePostHandler(handler) } -func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] { +func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] { // checkout the node's output sources, if it has any variable, // use the state's variableHandler to get the variables and set them to the output var vars []*vo.FieldInfo @@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri } } -func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] { +func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] { // checkout the node's output sources, if it has any variables, get the variables and merge them with the output var vars []*vo.FieldInfo for _, output := range s.OutputSources { diff --git a/backend/domain/workflow/internal/compose/stream.go b/backend/domain/workflow/internal/compose/stream.go index c2dea6ad..fb5376d4 100644 --- a/backend/domain/workflow/internal/compose/stream.go +++ b/backend/domain/workflow/internal/compose/stream.go @@ -21,19 +21,20 @@ import ( "github.com/cloudwego/eino/compose" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) -// SetFullSources calculates REAL input sources for a node. +// GetFullSources calculates REAL input sources for a node. // It may be different from a NodeSchema's InputSources because of the following reasons: // 1. a inner node under composite node may refer to a field from a node in its parent workflow, // this is instead routed to and sourced from the inner workflow's start node. // 2. at the same time, the composite node needs to delegate the input source to the inner workflow. // 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources. -func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error { - fullSource := make(map[string]*nodes.SourceInfo) +func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) ( + map[string]*schema.SourceInfo, error) { + fullSource := make(map[string]*schema.SourceInfo) var fieldInfos []vo.FieldInfo for _, s := range dep.staticValues { fieldInfos = append(fieldInfos, vo.FieldInfo{ @@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen tInfo = tInfo.Properties[path[j]] } if current, ok := currentSource[path[j]]; !ok { - currentSource[path[j]] = &nodes.SourceInfo{ + currentSource[path[j]] = &schema.SourceInfo{ IsIntermediate: true, - FieldType: nodes.FieldNotStream, + FieldType: schema.FieldNotStream, TypeInfo: tInfo, - SubSources: make(map[string]*nodes.SourceInfo), + SubSources: make(map[string]*schema.SourceInfo), } } else if !current.IsIntermediate { - return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1]) + return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1]) } currentSource = currentSource[path[j]].SubSources @@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen // static values or variables if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" { - currentSource[lastPath] = &nodes.SourceInfo{ + currentSource[lastPath] = &schema.SourceInfo{ IsIntermediate: false, - FieldType: nodes.FieldNotStream, + FieldType: schema.FieldNotStream, TypeInfo: tInfo, } continue @@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen fromNodeKey := fInfo.Source.Ref.FromNodeKey var ( - streamType nodes.FieldStreamType + streamType schema.FieldStreamType err error ) if len(fromNodeKey) > 0 { if fromNodeKey == compose.START { - streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform + streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform } else { - fromNode, ok := allNS[fromNodeKey] - if !ok { - return fmt.Errorf("node %s not found", fromNodeKey) + fromNode := sc.GetNode(fromNodeKey) + if fromNode == nil { + return nil, fmt.Errorf("node %s not found", fromNodeKey) } - streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS) + streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc) if err != nil { - return err + return nil, err } } } - currentSource[lastPath] = &nodes.SourceInfo{ + currentSource[lastPath] = &schema.SourceInfo{ IsIntermediate: false, FieldType: streamType, FromNodeKey: fromNodeKey, @@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen } } - s.Configs.(map[string]any)["FullSources"] = fullSource - return nil -} - -func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) { - if s.Type == entity.NodeTypeExit { - if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming { - if len(path) == 1 && path[0] == "output" { - return nodes.FieldIsStream, nil - } - } - - return nodes.FieldNotStream, nil - } else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly? - subSC := s.SubWorkflowSchema - subExit := subSC.GetNode(entity.ExitNodeKey) - subStreamType, err := subExit.IsStreamingField(path, nil) - if err != nil { - return nodes.FieldNotStream, err - } - - return subStreamType, nil - } else if s.Type == entity.NodeTypeVariableAggregator { - if len(path) == 2 { // asking about a specific index within a group - for _, fInfo := range s.InputSources { - if len(fInfo.Path) == len(path) { - equal := true - for i := range fInfo.Path { - if fInfo.Path[i] != path[i] { - equal = false - break - } - } - - if equal { - if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" { - return nodes.FieldNotStream, nil - } - fromNodeKey := fInfo.Source.Ref.FromNodeKey - fromNode, ok := allNS[fromNodeKey] - if !ok { - return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey) - } - return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS) - } - } - } - } else if len(path) == 1 { // asking about the entire group - var streamCount, notStreamCount int - for _, fInfo := range s.InputSources { - if fInfo.Path[0] == path[0] { // belong to the group - if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 { - fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey] - if !ok { - return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey) - } - subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS) - if err != nil { - return nodes.FieldNotStream, err - } - - if subStreamType == nodes.FieldMaybeStream { - return nodes.FieldMaybeStream, nil - } else if subStreamType == nodes.FieldIsStream { - streamCount++ - } else { - notStreamCount++ - } - } - } - } - - if streamCount > 0 && notStreamCount == 0 { - return nodes.FieldIsStream, nil - } - - if streamCount == 0 && notStreamCount > 0 { - return nodes.FieldNotStream, nil - } - - return nodes.FieldMaybeStream, nil - } - } - - if s.Type != entity.NodeTypeLLM { - return nodes.FieldNotStream, nil - } - - if len(path) != 1 { - return nodes.FieldNotStream, nil - } - - outputs := s.OutputTypes - if len(outputs) != 1 && len(outputs) != 2 { - return nodes.FieldNotStream, nil - } - - var outputKey string - for key, output := range outputs { - if output.Type != vo.DataTypeString { - return nodes.FieldNotStream, nil - } - - if key != "reasoning_content" { - if len(outputKey) > 0 { - return nodes.FieldNotStream, nil - } - outputKey = key - } - } - - field := path[0] - if field == "reasoning_content" || field == outputKey { - return nodes.FieldIsStream, nil - } - - return nodes.FieldNotStream, nil + return fullSource, nil } diff --git a/backend/domain/workflow/internal/compose/test/batch_test.go b/backend/domain/workflow/internal/compose/test/batch_test.go index ee8bbba5..e4e63a79 100644 --- a/backend/domain/workflow/internal/compose/test/batch_test.go +++ b/backend/domain/workflow/internal/compose/test/batch_test.go @@ -28,6 +28,9 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) func TestBatch(t *testing.T) { @@ -52,7 +55,7 @@ func TestBatch(t *testing.T) { return in, nil } - lambdaNode1 := &compose2.NodeSchema{ + lambdaNode1 := &schema.NodeSchema{ Key: "lambda", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda1), @@ -86,7 +89,7 @@ func TestBatch(t *testing.T) { }, }, } - lambdaNode2 := &compose2.NodeSchema{ + lambdaNode2 := &schema.NodeSchema{ Key: "index", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda2), @@ -103,7 +106,7 @@ func TestBatch(t *testing.T) { }, } - lambdaNode3 := &compose2.NodeSchema{ + lambdaNode3 := &schema.NodeSchema{ Key: "consumer", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda3), @@ -135,23 +138,22 @@ func TestBatch(t *testing.T) { }, } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - ns := &compose2.NodeSchema{ - Key: "batch_node_key", - Type: entity.NodeTypeBatch, + ns := &schema.NodeSchema{ + Key: "batch_node_key", + Type: entity.NodeTypeBatch, + Configs: &batch.Config{}, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"array_1"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"array_1"}, }, }, @@ -160,7 +162,7 @@ func TestBatch(t *testing.T) { Path: compose.FieldPath{"array_2"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"array_2"}, }, }, @@ -214,11 +216,11 @@ func TestBatch(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -246,18 +248,18 @@ func TestBatch(t *testing.T) { return map[string]any{"success": true}, nil } - parentLambdaNode := &compose2.NodeSchema{ + parentLambdaNode := &schema.NodeSchema{ Key: "parent_predecessor_1", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(parentLambda), } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, parentLambdaNode, ns, - exit, + exitN, lambdaNode1, lambdaNode2, lambdaNode3, @@ -267,7 +269,7 @@ func TestBatch(t *testing.T) { "index": "batch_node_key", "consumer": "batch_node_key", }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { FromNode: entity.EntryNodeKey, ToNode: "parent_predecessor_1", diff --git a/backend/domain/workflow/internal/compose/test/llm_test.go b/backend/domain/workflow/internal/compose/test/llm_test.go index 76b1713d..4969ca8b 100644 --- a/backend/domain/workflow/internal/compose/test/llm_test.go +++ b/backend/domain/workflow/internal/compose/test/llm_test.go @@ -40,7 +40,11 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/internal/testutil" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" @@ -108,22 +112,20 @@ func TestLLM(t *testing.T) { } } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - llmNode := &compose2.NodeSchema{ + llmNode := &schema2.NodeSchema{ Key: "llm_node_key", Type: entity.NodeTypeLLM, - Configs: map[string]any{ - "SystemPrompt": "{{sys_prompt}}", - "UserPrompt": "{{query}}", - "OutputFormat": llm.FormatText, - "LLMParams": &model.LLMParams{ + Configs: &llm.Config{ + SystemPrompt: "{{sys_prompt}}", + UserPrompt: "{{query}}", + OutputFormat: llm.FormatText, + LLMParams: &model.LLMParams{ ModelName: modelName, }, }, @@ -132,7 +134,7 @@ func TestLLM(t *testing.T) { Path: compose.FieldPath{"sys_prompt"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"sys_prompt"}, }, }, @@ -141,7 +143,7 @@ func TestLLM(t *testing.T) { Path: compose.FieldPath{"query"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"query"}, }, }, @@ -162,11 +164,11 @@ func TestLLM(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -181,20 +183,20 @@ func TestLLM(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, llmNode, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: llmNode.Key, }, { FromNode: llmNode.Key, - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -228,27 +230,20 @@ func TestLLM(t *testing.T) { } } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - llmNode := &compose2.NodeSchema{ + llmNode := &schema2.NodeSchema{ Key: "llm_node_key", Type: entity.NodeTypeLLM, - Configs: map[string]any{ - "SystemPrompt": "you are a helpful assistant", - "UserPrompt": "what's the largest country in the world and it's area size in square kilometers?", - "OutputFormat": llm.FormatJSON, - "IgnoreException": true, - "DefaultOutput": map[string]any{ - "country_name": "unknown", - "area_size": int64(0), - }, - "LLMParams": &model.LLMParams{ + Configs: &llm.Config{ + SystemPrompt: "you are a helpful assistant", + UserPrompt: "what's the largest country in the world and it's area size in square kilometers?", + OutputFormat: llm.FormatJSON, + LLMParams: &model.LLMParams{ ModelName: modelName, }, }, @@ -264,11 +259,11 @@ func TestLLM(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -292,20 +287,20 @@ func TestLLM(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, llmNode, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: llmNode.Key, }, { FromNode: llmNode.Key, - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -337,22 +332,20 @@ func TestLLM(t *testing.T) { } } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - llmNode := &compose2.NodeSchema{ + llmNode := &schema2.NodeSchema{ Key: "llm_node_key", Type: entity.NodeTypeLLM, - Configs: map[string]any{ - "SystemPrompt": "you are a helpful assistant", - "UserPrompt": "list the top 5 largest countries in the world", - "OutputFormat": llm.FormatMarkdown, - "LLMParams": &model.LLMParams{ + Configs: &llm.Config{ + SystemPrompt: "you are a helpful assistant", + UserPrompt: "list the top 5 largest countries in the world", + OutputFormat: llm.FormatMarkdown, + LLMParams: &model.LLMParams{ ModelName: modelName, }, }, @@ -363,11 +356,11 @@ func TestLLM(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -382,20 +375,20 @@ func TestLLM(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, llmNode, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: llmNode.Key, }, { FromNode: llmNode.Key, - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -456,22 +449,20 @@ func TestLLM(t *testing.T) { } } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - openaiNode := &compose2.NodeSchema{ + openaiNode := &schema2.NodeSchema{ Key: "openai_llm_node_key", Type: entity.NodeTypeLLM, - Configs: map[string]any{ - "SystemPrompt": "you are a helpful assistant", - "UserPrompt": "plan a 10 day family visit to China.", - "OutputFormat": llm.FormatText, - "LLMParams": &model.LLMParams{ + Configs: &llm.Config{ + SystemPrompt: "you are a helpful assistant", + UserPrompt: "plan a 10 day family visit to China.", + OutputFormat: llm.FormatText, + LLMParams: &model.LLMParams{ ModelName: modelName, }, }, @@ -482,14 +473,14 @@ func TestLLM(t *testing.T) { }, } - deepseekNode := &compose2.NodeSchema{ + deepseekNode := &schema2.NodeSchema{ Key: "deepseek_llm_node_key", Type: entity.NodeTypeLLM, - Configs: map[string]any{ - "SystemPrompt": "you are a helpful assistant", - "UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.", - "OutputFormat": llm.FormatText, - "LLMParams": &model.LLMParams{ + Configs: &llm.Config{ + SystemPrompt: "you are a helpful assistant", + UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.", + OutputFormat: llm.FormatText, + LLMParams: &model.LLMParams{ ModelName: modelName, }, }, @@ -503,12 +494,11 @@ func TestLLM(t *testing.T) { }, } - emitterNode := &compose2.NodeSchema{ + emitterNode := &schema2.NodeSchema{ Key: "emitter_node_key", Type: entity.NodeTypeOutputEmitter, - Configs: map[string]any{ - "Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix", - "Mode": nodes.Streaming, + Configs: &emitter.Config{ + Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix", }, InputSources: []*vo.FieldInfo{ { @@ -542,7 +532,7 @@ func TestLLM(t *testing.T) { Path: compose.FieldPath{"inputObj"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"inputObj"}, }, }, @@ -551,7 +541,7 @@ func TestLLM(t *testing.T) { Path: compose.FieldPath{"input2"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"input2"}, }, }, @@ -559,11 +549,11 @@ func TestLLM(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.UseAnswerContent, + Configs: &exit.Config{ + TerminatePlan: vo.UseAnswerContent, }, InputSources: []*vo.FieldInfo{ { @@ -596,17 +586,17 @@ func TestLLM(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, openaiNode, deepseekNode, emitterNode, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: openaiNode.Key, }, { @@ -614,7 +604,7 @@ func TestLLM(t *testing.T) { ToNode: emitterNode.Key, }, { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: deepseekNode.Key, }, { @@ -623,7 +613,7 @@ func TestLLM(t *testing.T) { }, { FromNode: emitterNode.Key, - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } diff --git a/backend/domain/workflow/internal/compose/test/loop_test.go b/backend/domain/workflow/internal/compose/test/loop_test.go index 944adc78..94023fa1 100644 --- a/backend/domain/workflow/internal/compose/test/loop_test.go +++ b/backend/domain/workflow/internal/compose/test/loop_test.go @@ -26,15 +26,20 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" + _break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break" + _continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) func TestLoop(t *testing.T) { t.Run("by iteration", func(t *testing.T) { // start-> loop_node_key[innerNode->continue] -> end - innerNode := &compose2.NodeSchema{ + innerNode := &schema.NodeSchema{ Key: "innerNode", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { @@ -54,31 +59,30 @@ func TestLoop(t *testing.T) { }, } - continueNode := &compose2.NodeSchema{ - Key: "continueNode", - Type: entity.NodeTypeContinue, + continueNode := &schema.NodeSchema{ + Key: "continueNode", + Type: entity.NodeTypeContinue, + Configs: &_continue.Config{}, } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - loopNode := &compose2.NodeSchema{ + loopNode := &schema.NodeSchema{ Key: "loop_node_key", Type: entity.NodeTypeLoop, - Configs: map[string]any{ - "LoopType": loop.ByIteration, + Configs: &loop.Config{ + LoopType: loop.ByIteration, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{loop.Count}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"count"}, }, }, @@ -97,11 +101,11 @@ func TestLoop(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -116,11 +120,11 @@ func TestLoop(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, loopNode, - exit, + exitN, innerNode, continueNode, }, @@ -128,7 +132,7 @@ func TestLoop(t *testing.T) { "innerNode": "loop_node_key", "continueNode": "loop_node_key", }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { FromNode: "loop_node_key", ToNode: "innerNode", @@ -142,12 +146,12 @@ func TestLoop(t *testing.T) { ToNode: "loop_node_key", }, { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "loop_node_key", }, { FromNode: "loop_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -168,7 +172,7 @@ func TestLoop(t *testing.T) { t.Run("infinite", func(t *testing.T) { // start-> loop_node_key[innerNode->break] -> end - innerNode := &compose2.NodeSchema{ + innerNode := &schema.NodeSchema{ Key: "innerNode", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { @@ -188,24 +192,23 @@ func TestLoop(t *testing.T) { }, } - breakNode := &compose2.NodeSchema{ - Key: "breakNode", - Type: entity.NodeTypeBreak, + breakNode := &schema.NodeSchema{ + Key: "breakNode", + Type: entity.NodeTypeBreak, + Configs: &_break.Config{}, } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - loopNode := &compose2.NodeSchema{ + loopNode := &schema.NodeSchema{ Key: "loop_node_key", Type: entity.NodeTypeLoop, - Configs: map[string]any{ - "LoopType": loop.Infinite, + Configs: &loop.Config{ + LoopType: loop.Infinite, }, OutputSources: []*vo.FieldInfo{ { @@ -220,11 +223,11 @@ func TestLoop(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -239,11 +242,11 @@ func TestLoop(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, loopNode, - exit, + exitN, innerNode, breakNode, }, @@ -251,7 +254,7 @@ func TestLoop(t *testing.T) { "innerNode": "loop_node_key", "breakNode": "loop_node_key", }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { FromNode: "loop_node_key", ToNode: "innerNode", @@ -265,12 +268,12 @@ func TestLoop(t *testing.T) { ToNode: "loop_node_key", }, { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "loop_node_key", }, { FromNode: "loop_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -290,14 +293,14 @@ func TestLoop(t *testing.T) { t.Run("by array", func(t *testing.T) { // start-> loop_node_key[innerNode->variable_assign] -> end - innerNode := &compose2.NodeSchema{ + innerNode := &schema.NodeSchema{ Key: "innerNode", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { item1 := in["item1"].(string) item2 := in["item2"].(string) count := in["count"].(int) - return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil + return map[string]any{"total": count + len(item1) + len(item2)}, nil }), InputSources: []*vo.FieldInfo{ { @@ -330,16 +333,18 @@ func TestLoop(t *testing.T) { }, } - assigner := &compose2.NodeSchema{ + assigner := &schema.NodeSchema{ Key: "assigner", Type: entity.NodeTypeVariableAssignerWithinLoop, - Configs: []*variableassigner.Pair{ - { - Left: vo.Reference{ - FromPath: compose.FieldPath{"count"}, - VariableType: ptr.Of(vo.ParentIntermediate), + Configs: &variableassigner.InLoopConfig{ + Pairs: []*variableassigner.Pair{ + { + Left: vo.Reference{ + FromPath: compose.FieldPath{"count"}, + VariableType: ptr.Of(vo.ParentIntermediate), + }, + Right: compose.FieldPath{"total"}, }, - Right: compose.FieldPath{"total"}, }, }, InputSources: []*vo.FieldInfo{ @@ -355,19 +360,17 @@ func TestLoop(t *testing.T) { }, } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -382,12 +385,13 @@ func TestLoop(t *testing.T) { }, } - loopNode := &compose2.NodeSchema{ + loopNode := &schema.NodeSchema{ Key: "loop_node_key", Type: entity.NodeTypeLoop, - Configs: map[string]any{ - "LoopType": loop.ByArray, - "IntermediateVars": map[string]*vo.TypeInfo{ + Configs: &loop.Config{ + LoopType: loop.ByArray, + InputArrays: []string{"items1", "items2"}, + IntermediateVars: map[string]*vo.TypeInfo{ "count": { Type: vo.DataTypeInteger, }, @@ -408,7 +412,7 @@ func TestLoop(t *testing.T) { Path: compose.FieldPath{"items1"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"items1"}, }, }, @@ -417,7 +421,7 @@ func TestLoop(t *testing.T) { Path: compose.FieldPath{"items2"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"items2"}, }, }, @@ -442,11 +446,11 @@ func TestLoop(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, loopNode, - exit, + exitN, innerNode, assigner, }, @@ -454,7 +458,7 @@ func TestLoop(t *testing.T) { "innerNode": "loop_node_key", "assigner": "loop_node_key", }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { FromNode: "loop_node_key", ToNode: "innerNode", @@ -468,12 +472,12 @@ func TestLoop(t *testing.T) { ToNode: "loop_node_key", }, { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "loop_node_key", }, { FromNode: "loop_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } diff --git a/backend/domain/workflow/internal/compose/test/question_answer_test.go b/backend/domain/workflow/internal/compose/test/question_answer_test.go index 291f76a0..4ad1b5e6 100644 --- a/backend/domain/workflow/internal/compose/test/question_answer_test.go +++ b/backend/domain/workflow/internal/compose/test/question_answer_test.go @@ -43,8 +43,11 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen" storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage" @@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) { mockey.Mock(workflow.GetRepository).Return(repo).Build() t.Run("answer directly, no structured output", func(t *testing.T) { - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }} + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, + } - ns := &compose2.NodeSchema{ + ns := &schema2.NodeSchema{ Key: "qa_node_key", Type: entity.NodeTypeQuestionAnswer, - Configs: map[string]any{ - "QuestionTpl": "{{input}}", - "AnswerType": qa.AnswerDirectly, + Configs: &qa.Config{ + QuestionTpl: "{{input}}", + AnswerType: qa.AnswerDirectly, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"input"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"query"}, }, }, @@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, ns, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "qa_node_key", }, { FromNode: "qa_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) { mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1) } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - ns := &compose2.NodeSchema{ + ns := &schema2.NodeSchema{ Key: "qa_node_key", Type: entity.NodeTypeQuestionAnswer, - Configs: map[string]any{ - "QuestionTpl": "{{input}}", - "AnswerType": qa.AnswerByChoices, - "ChoiceType": qa.FixedChoices, - "FixedChoices": []string{"{{choice1}}", "{{choice2}}"}, - "LLMParams": &model.LLMParams{}, + Configs: &qa.Config{ + QuestionTpl: "{{input}}", + AnswerType: qa.AnswerByChoices, + ChoiceType: qa.FixedChoices, + FixedChoices: []string{"{{choice1}}", "{{choice2}}"}, + LLMParams: &model.LLMParams{}, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"input"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"query"}, }, }, @@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) { Path: compose.FieldPath{"choice1"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"choice1"}, }, }, @@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) { Path: compose.FieldPath{"choice2"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"choice2"}, }, }, @@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) { }, } - lambda := &compose2.NodeSchema{ + lambda := &schema2.NodeSchema{ Key: "lambda", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { @@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) { }), } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, ns, - exit, + exitN, lambda, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "qa_node_key", }, { FromNode: "qa_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, FromPort: ptr.Of("branch_0"), }, { FromNode: "qa_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, FromPort: ptr.Of("branch_1"), }, { @@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) { }, { FromNode: "lambda", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } + branches, err := schema2.BuildBranches(ws.Connections) + assert.NoError(t, err) + ws.Branches = branches + ws.Init() wf, err := compose2.NewWorkflow(context.Background(), ws) @@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) { }) t.Run("answer with dynamic choices", func(t *testing.T) { - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - ns := &compose2.NodeSchema{ + ns := &schema2.NodeSchema{ Key: "qa_node_key", Type: entity.NodeTypeQuestionAnswer, - Configs: map[string]any{ - "QuestionTpl": "{{input}}", - "AnswerType": qa.AnswerByChoices, - "ChoiceType": qa.DynamicChoices, + Configs: &qa.Config{ + QuestionTpl: "{{input}}", + AnswerType: qa.AnswerByChoices, + ChoiceType: qa.DynamicChoices, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"input"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"query"}, }, }, @@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) { Path: compose.FieldPath{qa.DynamicChoicesKey}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"choices"}, }, }, @@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) { }, } - lambda := &compose2.NodeSchema{ + lambda := &schema2.NodeSchema{ Key: "lambda", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { @@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) { }), } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, ns, - exit, + exitN, lambda, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "qa_node_key", }, { FromNode: "qa_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, FromPort: ptr.Of("branch_0"), }, { FromNode: "lambda", - ToNode: exit.Key, + ToNode: exitN.Key, }, { FromNode: "qa_node_key", @@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) { }, } + branches, err := schema2.BuildBranches(ws.Connections) + assert.NoError(t, err) + ws.Branches = branches + ws.Init() wf, err := compose2.NewWorkflow(context.Background(), ws) @@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) { mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1) } - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema2.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - ns := &compose2.NodeSchema{ + ns := &schema2.NodeSchema{ Key: "qa_node_key", Type: entity.NodeTypeQuestionAnswer, - Configs: map[string]any{ - "QuestionTpl": "{{input}}", - "AnswerType": qa.AnswerDirectly, - "ExtractFromAnswer": true, - "AdditionalSystemPromptTpl": "{{prompt}}", - "MaxAnswerCount": 2, - "LLMParams": &model.LLMParams{}, + Configs: &qa.Config{ + QuestionTpl: "{{input}}", + AnswerType: qa.AnswerDirectly, + ExtractFromAnswer: true, + AdditionalSystemPromptTpl: "{{prompt}}", + MaxAnswerCount: 2, + LLMParams: &model.LLMParams{}, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"input"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"query"}, }, }, @@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) { Path: compose.FieldPath{"prompt"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"prompt"}, }, }, @@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) { }, } - exit := &compose2.NodeSchema{ + exitN := &schema2.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema2.WorkflowSchema{ + Nodes: []*schema2.NodeSchema{ + entryN, ns, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema2.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "qa_node_key", }, { FromNode: "qa_node_key", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } diff --git a/backend/domain/workflow/internal/compose/test/workflow_test.go b/backend/domain/workflow/internal/compose/test/workflow_test.go index 8ec39f16..59f25189 100644 --- a/backend/domain/workflow/internal/compose/test/workflow_test.go +++ b/backend/domain/workflow/internal/compose/test/workflow_test.go @@ -26,26 +26,28 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) func TestAddSelector(t *testing.T) { // start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }} + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, + } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) { }, nil } - lambdaNode1 := &compose2.NodeSchema{ + lambdaNode1 := &schema.NodeSchema{ Key: "lambda1", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda1), @@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) { }, nil } - LambdaNode2 := &compose2.NodeSchema{ + LambdaNode2 := &schema.NodeSchema{ Key: "lambda2", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda2), @@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) { }, nil } - lambdaNode3 := &compose2.NodeSchema{ + lambdaNode3 := &schema.NodeSchema{ Key: "lambda3", Type: entity.NodeTypeLambda, Lambda: compose.InvokableLambda(lambda3), } - ns := &compose2.NodeSchema{ + ns := &schema.NodeSchema{ Key: "selector", Type: entity.NodeTypeSelector, - Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{ + Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{ { Single: ptr.Of(selector.OperatorEqual), }, @@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) { Path: compose.FieldPath{"0", selector.LeftKey}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"key1"}, }, }, @@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) { Path: compose.FieldPath{"1", "0", selector.LeftKey}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"key2"}, }, }, @@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) { Path: compose.FieldPath{"1", "0", selector.RightKey}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"key3"}, }, }, @@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) { Path: compose.FieldPath{"1", "1", selector.LeftKey}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"key4"}, }, }, @@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, ns, lambdaNode1, LambdaNode2, lambdaNode3, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "selector", }, { @@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) { }, { FromNode: "selector", - ToNode: exit.Key, + ToNode: exitN.Key, FromPort: ptr.Of("default"), }, { FromNode: "lambda1", - ToNode: exit.Key, + ToNode: exitN.Key, }, { FromNode: "lambda2", - ToNode: exit.Key, + ToNode: exitN.Key, }, { FromNode: "lambda3", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } + branches, err := schema.BuildBranches(ws.Connections) + assert.NoError(t, err) + ws.Branches = branches + ws.Init() ctx := context.Background() @@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) { } func TestVariableAggregator(t *testing.T) { - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) { }, } - ns := &compose2.NodeSchema{ + ns := &schema.NodeSchema{ Key: "va", Type: entity.NodeTypeVariableAggregator, - Configs: map[string]any{ - "MergeStrategy": variableaggregator.FirstNotNullValue, - "GroupToLen": map[string]int{ + Configs: &variableaggregator.Config{ + MergeStrategy: variableaggregator.FirstNotNullValue, + GroupLen: map[string]int{ "Group1": 1, "Group2": 1, }, - "GroupOrder": []string{ + GroupOrder: []string{ "Group1", "Group2", }, @@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) { Path: compose.FieldPath{"Group1", "0"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Str1"}, }, }, @@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) { Path: compose.FieldPath{"Group2", "0"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Int1"}, }, }, @@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ - entry, + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ + entryN, ns, - exit, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "va", }, { FromNode: "va", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) { func TestTextProcessor(t *testing.T) { t.Run("split", func(t *testing.T) { - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) { }, } - ns := &compose2.NodeSchema{ + ns := &schema.NodeSchema{ Key: "tp", Type: entity.NodeTypeTextProcessor, - Configs: map[string]any{ - "Type": textprocessor.SplitText, - "Separators": []string{"|"}, + Configs: &textprocessor.Config{ + Type: textprocessor.SplitText, + Separators: []string{"|"}, }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"String"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Str"}, }, }, @@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ ns, - entry, - exit, + entryN, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "tp", }, { FromNode: "tp", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } @@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) { }) t.Run("concat", func(t *testing.T) { - entry := &compose2.NodeSchema{ - Key: entity.EntryNodeKey, - Type: entity.NodeTypeEntry, - Configs: map[string]any{ - "DefaultValues": map[string]any{}, - }, + entryN := &schema.NodeSchema{ + Key: entity.EntryNodeKey, + Type: entity.NodeTypeEntry, + Configs: &entry.Config{}, } - exit := &compose2.NodeSchema{ + exitN := &schema.NodeSchema{ Key: entity.ExitNodeKey, Type: entity.NodeTypeExit, - Configs: map[string]any{ - "TerminalPlan": vo.ReturnVariables, + Configs: &exit.Config{ + TerminatePlan: vo.ReturnVariables, }, InputSources: []*vo.FieldInfo{ { @@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) { }, } - ns := &compose2.NodeSchema{ + ns := &schema.NodeSchema{ Key: "tp", Type: entity.NodeTypeTextProcessor, - Configs: map[string]any{ - "Type": textprocessor.ConcatText, - "Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}", - "ConcatChar": "\t", + Configs: &textprocessor.Config{ + Type: textprocessor.ConcatText, + Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}", + ConcatChar: "\t", }, InputSources: []*vo.FieldInfo{ { Path: compose.FieldPath{"String1"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Str1"}, }, }, @@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) { Path: compose.FieldPath{"String2"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Str2"}, }, }, @@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) { Path: compose.FieldPath{"String3"}, Source: vo.FieldSource{ Ref: &vo.Reference{ - FromNodeKey: entry.Key, + FromNodeKey: entryN.Key, FromPath: compose.FieldPath{"Str3"}, }, }, @@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) { }, } - ws := &compose2.WorkflowSchema{ - Nodes: []*compose2.NodeSchema{ + ws := &schema.WorkflowSchema{ + Nodes: []*schema.NodeSchema{ ns, - entry, - exit, + entryN, + exitN, }, - Connections: []*compose2.Connection{ + Connections: []*schema.Connection{ { - FromNode: entry.Key, + FromNode: entryN.Key, ToNode: "tp", }, { FromNode: "tp", - ToNode: exit.Key, + ToNode: exitN.Key, }, }, } diff --git a/backend/domain/workflow/internal/compose/to_node.go b/backend/domain/workflow/internal/compose/to_node.go deleted file mode 100644 index fa1cb161..00000000 --- a/backend/domain/workflow/internal/compose/to_node.go +++ /dev/null @@ -1,652 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package compose - -import ( - "context" - "errors" - "fmt" - "runtime/debug" - "strconv" - "time" - - einomodel "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" - - workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" - workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" - crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code" - crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation" - crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" - crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" - crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" - "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" - "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" - "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" - "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" - "github.com/coze-dev/coze-studio/backend/pkg/safego" -) - -func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) { - return &entry.Config{ - DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs), - OutputTypes: s.OutputTypes, - }, nil -} - -func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) { - llmConf := &llm.Config{ - SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs), - UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs), - OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs), - InputFields: s.InputTypes, - OutputFields: s.OutputTypes, - FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs), - } - - llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs) - - if llmParams == nil { - return nil, fmt.Errorf("llm node llmParams is required") - } - var ( - err error - chatModel, fallbackM einomodel.BaseChatModel - info, fallbackI *modelmgr.Model - modelWithInfo llm.ModelWithInfo - ) - - chatModel, info, err = model.GetManager().GetModel(ctx, llmParams) - if err != nil { - return nil, err - } - - metaConfigs := s.ExceptionConfigs - if metaConfigs != nil && metaConfigs.MaxRetry > 0 { - backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs) - if backupModelParams != nil { - fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams) - if err != nil { - return nil, err - } - } - } - - if fallbackM == nil { - modelWithInfo = llm.NewModel(chatModel, info) - } else { - modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI) - } - llmConf.ChatModel = modelWithInfo - - fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs) - if fcParams != nil { - if fcParams.WorkflowFCParam != nil { - for _, wf := range fcParams.WorkflowFCParam.WorkflowList { - wfIDStr := wf.WorkflowID - wfID, err := strconv.ParseInt(wfIDStr, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr) - } - - workflowToolConfig := vo.WorkflowToolConfig{} - if wf.FCSetting != nil { - workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters - workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters - } - - locator := vo.FromDraft - if wf.WorkflowVersion != "" { - locator = vo.FromSpecificVersion - } - - wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{ - ID: wfID, - QType: locator, - Version: wf.WorkflowVersion, - }, workflowToolConfig) - if err != nil { - return nil, err - } - llmConf.Tools = append(llmConf.Tools, wfTool) - if wfTool.TerminatePlan() == vo.UseAnswerContent { - if llmConf.ToolsReturnDirectly == nil { - llmConf.ToolsReturnDirectly = make(map[string]bool) - } - toolInfo, err := wfTool.Info(ctx) - if err != nil { - return nil, err - } - llmConf.ToolsReturnDirectly[toolInfo.Name] = true - } - } - } - - if fcParams.PluginFCParam != nil { - pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest) - for _, p := range fcParams.PluginFCParam.PluginList { - pid, err := strconv.ParseInt(p.PluginID, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID) - } - toolID, err := strconv.ParseInt(p.ApiId, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID) - } - - var ( - requestParameters []*workflow3.APIParameter - responseParameters []*workflow3.APIParameter - ) - if p.FCSetting != nil { - requestParameters = p.FCSetting.RequestParameters - responseParameters = p.FCSetting.ResponseParameters - } - - if req, ok := pluginToolsInvokableReq[pid]; ok { - req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{ - ToolID: toolID, - RequestAPIParametersConfig: requestParameters, - ResponseAPIParametersConfig: responseParameters, - } - } else { - pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{ - PluginEntity: crossplugin.Entity{ - PluginID: pid, - PluginVersion: ptr.Of(p.PluginVersion), - }, - ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{ - toolID: { - ToolID: toolID, - RequestAPIParametersConfig: requestParameters, - ResponseAPIParametersConfig: responseParameters, - }, - }, - IsDraft: p.IsDraft, - } - pluginToolsInvokableReq[pid] = pluginToolsInfoRequest - } - } - inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList)) - for _, req := range pluginToolsInvokableReq { - toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req) - if err != nil { - return nil, err - } - for _, t := range toolMap { - inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t)) - } - } - if len(inInvokableTools) > 0 { - llmConf.Tools = inInvokableTools - } - } - - if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 { - kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel() - if kwChatModel == nil { - return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured") - } - knowledgeOperator := crossknowledge.GetKnowledgeOperator() - setting := fcParams.KnowledgeFCParam.GlobalSetting - cfg := &llm.KnowledgeRecallConfig{ - ChatModel: kwChatModel, - Retriever: knowledgeOperator, - } - searchType, err := totRetrievalSearchType(setting.SearchMode) - if err != nil { - return nil, err - } - cfg.RetrievalStrategy = &llm.RetrievalStrategy{ - RetrievalStrategy: &crossknowledge.RetrievalStrategy{ - TopK: ptr.Of(setting.TopK), - MinScore: ptr.Of(setting.MinScore), - SearchType: searchType, - EnableNL2SQL: setting.UseNL2SQL, - EnableQueryRewrite: setting.UseRewrite, - EnableRerank: setting.UseRerank, - }, - NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode), - NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt, - } - - knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList)) - for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList { - kid, err := strconv.ParseInt(kw.ID, 10, 64) - if err != nil { - return nil, err - } - knowledgeIDs = append(knowledgeIDs, kid) - } - - detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{ - KnowledgeIDs: knowledgeIDs, - }) - if err != nil { - return nil, err - } - cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails - llmConf.KnowledgeRecallConfig = cfg - } - - } - - return llmConf, nil -} - -func (s *NodeSchema) ToSelectorConfig() *selector.Config { - return &selector.Config{ - Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs), - } -} - -func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) { - conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs) - - for i, oneConf := range conf { - if oneConf.Single != nil { - left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey}) - if !ok { - return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i) - } - - right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey}) - if ok { - out = append(out, selector.Operants{Left: left, Right: right}) - } else { - out = append(out, selector.Operants{Left: left}) - } - } else if oneConf.Multi != nil { - multiClause := make([]*selector.Operants, 0) - for j := range oneConf.Multi.Clauses { - left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey}) - if !ok { - return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j) - } - right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey}) - if ok { - multiClause = append(multiClause, &selector.Operants{Left: left, Right: right}) - } else { - multiClause = append(multiClause, &selector.Operants{Left: left}) - } - } - out = append(out, selector.Operants{Multi: multiClause}) - } else { - return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf) - } - } - - return out, nil -} - -func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) { - conf := &batch.Config{ - BatchNodeKey: s.Key, - InnerWorkflow: inner, - Outputs: s.OutputSources, - } - - for key, tInfo := range s.InputTypes { - if tInfo.Type != vo.DataTypeArray { - continue - } - - conf.InputArrays = append(conf.InputArrays, key) - } - - return conf, nil -} - -func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) { - return &variableaggregator.Config{ - MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy), - GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int), - FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs), - NodeKey: s.Key, - InputSources: s.InputSources, - GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs), - }, nil -} - -func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) { - converted = make(map[string]map[int]any) - - for k, value := range in { - m, ok := value.(map[string]any) - if !ok { - panic(errors.New("value is not a map[string]any")) - } - converted[k] = make(map[int]any, len(m)) - for i, sv := range m { - index, err := strconv.Atoi(i) - if err != nil { - panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err)) - } - converted[k][index] = sv - } - } - - return converted -} - -func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] { - converter := func(input map[string]any) (output map[string]map[int]any, err error) { - defer func() { - if r := recover(); r != nil { - err = safego.NewPanicErr(r, debug.Stack()) - } - }() - return s.variableAggregatorInputConverter(input), nil - } - return schema.StreamReaderWithConvert(in, converter) -} - -func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) { - return &textprocessor.Config{ - Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type), - Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)), - ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)), - Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)), - FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs), - }, nil -} - -func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) { - return &json.SerializationConfig{ - InputTypes: s.InputTypes, - }, nil -} - -func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) { - return &json.DeserializationConfig{ - OutputFields: s.OutputTypes, - }, nil -} - -func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) { - return &httprequester.Config{ - URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs), - AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs), - BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs), - Method: mustGetKey[string]("Method", s.Configs), - Timeout: mustGetKey[time.Duration]("Timeout", s.Configs), - RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs), - MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs), - }, nil -} - -func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) { - return &variableassigner.Config{ - Pairs: s.Configs.([]*variableassigner.Pair), - Handler: handler, - }, nil -} - -func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) { - return &variableassigner.Config{ - Pairs: s.Configs.([]*variableassigner.Pair), - }, nil -} - -func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) { - conf := &loop.Config{ - LoopNodeKey: s.Key, - LoopType: mustGetKey[loop.Type]("LoopType", s.Configs), - IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs), - Outputs: s.OutputSources, - Inner: inner, - } - - for key, tInfo := range s.InputTypes { - if tInfo.Type != vo.DataTypeArray { - continue - } - - if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars - continue - } - - conf.InputArrays = append(conf.InputArrays, key) - } - - return conf, nil -} - -func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) { - conf := &qa.Config{ - QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs), - AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs), - ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs), - FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs), - ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs), - MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs), - AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs), - OutputFields: s.OutputTypes, - NodeKey: s.Key, - } - - llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs) - if llmParams != nil { - m, _, err := model.GetManager().GetModel(ctx, llmParams) - if err != nil { - return nil, err - } - - conf.Model = m - } - - return conf, nil -} - -func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) { - return &receiver.Config{ - OutputTypes: s.OutputTypes, - NodeKey: s.Key, - OutputSchema: mustGetKey[string]("OutputSchema", s.Configs), - }, nil -} - -func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) { - conf := &emitter.Config{ - Template: getKeyOrZero[string]("Template", s.Configs), - FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs), - } - - return conf, nil -} - -func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) { - return &database.CustomSQLConfig{ - DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs), - SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs), - OutputConfig: s.OutputTypes, - CustomSQLExecutor: crossdatabase.GetDatabaseOperator(), - }, nil -} - -func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) { - return &database.QueryConfig{ - DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs), - QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs), - OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs), - ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs), - OutputConfig: s.OutputTypes, - Limit: mustGetKey[int64]("Limit", s.Configs), - Op: crossdatabase.GetDatabaseOperator(), - }, nil -} - -func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) { - return &database.InsertConfig{ - DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs), - OutputConfig: s.OutputTypes, - Inserter: crossdatabase.GetDatabaseOperator(), - }, nil -} - -func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) { - return &database.DeleteConfig{ - DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs), - ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs), - OutputConfig: s.OutputTypes, - Deleter: crossdatabase.GetDatabaseOperator(), - }, nil -} - -func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) { - return &database.UpdateConfig{ - DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs), - ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs), - OutputConfig: s.OutputTypes, - Updater: crossdatabase.GetDatabaseOperator(), - }, nil -} - -func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) { - return &knowledge.IndexerConfig{ - KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs), - ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs), - ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs), - KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(), - }, nil -} - -func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) { - return &knowledge.RetrieveConfig{ - KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs), - RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs), - Retriever: crossknowledge.GetKnowledgeOperator(), - }, nil -} - -func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) { - return &knowledge.DeleterConfig{ - KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs), - KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(), - }, nil -} - -func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) { - return &plugin.Config{ - PluginID: mustGetKey[int64]("PluginID", s.Configs), - ToolID: mustGetKey[int64]("ToolID", s.Configs), - PluginVersion: mustGetKey[string]("PluginVersion", s.Configs), - PluginService: crossplugin.GetPluginService(), - }, nil -} - -func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) { - return &code.Config{ - Code: mustGetKey[string]("Code", s.Configs), - Language: mustGetKey[coderunner.Language]("Language", s.Configs), - OutputConfig: s.OutputTypes, - Runner: crosscode.GetCodeRunner(), - }, nil -} - -func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) { - return &conversation.CreateConversationConfig{ - Creator: crossconversation.ConversationManagerImpl, - }, nil -} - -func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) { - return &conversation.ClearMessageConfig{ - Clearer: crossconversation.ConversationManagerImpl, - }, nil -} - -func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) { - return &conversation.MessageListConfig{ - Lister: crossconversation.ConversationManagerImpl, - }, nil -} - -func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) { - cfg := &intentdetector.Config{ - Intents: mustGetKey[[]string]("Intents", s.Configs), - SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs), - IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs), - } - - llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs) - m, _, err := model.GetManager().GetModel(ctx, llmParams) - if err != nil { - return nil, err - } - cfg.ChatModel = m - - return cfg, nil -} - -func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) { - var opts []WorkflowOption - opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs))) - if requireCheckpoint { - opts = append(opts, WithParentRequireCheckpoint()) - } - if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 { - opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow)) - } - wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...) - if err != nil { - return nil, err - } - - return &subworkflow.Config{ - Runner: wf.Runner, - }, nil -} - -func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) { - switch s { - case 0: - return crossknowledge.SearchTypeSemantic, nil - case 1: - return crossknowledge.SearchTypeHybrid, nil - case 20: - return crossknowledge.SearchTypeFullText, nil - default: - return "", fmt.Errorf("invalid retrieval search type %v", s) - } -} diff --git a/backend/domain/workflow/internal/compose/utils.go b/backend/domain/workflow/internal/compose/utils.go deleted file mode 100644 index e1f677a1..00000000 --- a/backend/domain/workflow/internal/compose/utils.go +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package compose - -import ( - "fmt" - "reflect" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" -) - -func getKeyOrZero[T any](key string, cfg any) T { - var zero T - if cfg == nil { - return zero - } - - m, ok := cfg.(map[string]any) - if !ok { - panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg))) - } - - if len(m) == 0 { - return zero - } - - if v, ok := m[key]; ok { - return v.(T) - } - - return zero -} - -func mustGetKey[T any](key string, cfg any) T { - if cfg == nil { - panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key)) - } - - m, ok := cfg.(map[string]any) - if !ok { - panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg))) - } - - if _, ok := m[key]; !ok { - panic(fmt.Sprintf("key %s does not exist in map: %v", key, m)) - } - - v, ok := m[key].(T) - if !ok { - panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key]))) - } - - return v -} - -func (s *NodeSchema) SetConfigKV(key string, value any) { - if s.Configs == nil { - s.Configs = make(map[string]any) - } - - s.Configs.(map[string]any)[key] = value -} - -func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) { - if s.InputTypes == nil { - s.InputTypes = make(map[string]*vo.TypeInfo) - } - s.InputTypes[key] = t -} - -func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) { - s.InputSources = append(s.InputSources, info...) -} - -func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) { - if s.OutputTypes == nil { - s.OutputTypes = make(map[string]*vo.TypeInfo) - } - s.OutputTypes[key] = t -} - -func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) { - s.OutputSources = append(s.OutputSources, info...) -} - -func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) { - if s.Type != entity.NodeTypeSubWorkflow { - return 0, "", false - } - - return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true -} diff --git a/backend/domain/workflow/internal/compose/workflow.go b/backend/domain/workflow/internal/compose/workflow.go index 391b0219..872d61be 100644 --- a/backend/domain/workflow/internal/compose/workflow.go +++ b/backend/domain/workflow/internal/compose/workflow.go @@ -29,6 +29,8 @@ import ( workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/safego" ) @@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any] type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials *workflow hierarchy map[vo.NodeKey]vo.NodeKey - connections []*Connection + connections []*schema.Connection requireCheckpoint bool entry *compose.WorkflowNode inner bool @@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to input map[string]*vo.TypeInfo output map[string]*vo.TypeInfo terminatePlan vo.TerminatePlan - schema *WorkflowSchema + schema *schema.WorkflowSchema } type workflowOptions struct { @@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption { } } -func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) { +func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) { sc.Init() wf := &Workflow{ @@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption schema: sc, } - wf.streamRun = sc.requireStreaming - wf.requireCheckpoint = sc.requireCheckPoint + wf.streamRun = sc.RequireStreaming() + wf.requireCheckpoint = sc.RequireCheckpoint() wfOpts := &workflowOptions{} for _, opt := range opts { @@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption processedNodeKey[child.Key] = struct{}{} } } - // add all nodes other than composite nodes and their children for _, ns := range sc.Nodes { if _, ok := processedNodeKey[ns.Key]; !ok { @@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption } if ns.Type == entity.NodeTypeExit { - wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs) + wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan } } @@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10))) } - fanInConfigs := sc.fanInMergeConfigs() + fanInConfigs := sc.FanInMergeConfigs() if len(fanInConfigs) > 0 { compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs)) } @@ -199,12 +200,12 @@ type innerWorkflowInfo struct { carryOvers map[vo.NodeKey][]*compose.FieldMapping } -func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error { +func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error { _, err := w.addNodeInternal(ctx, ns, nil) return err } -func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error { +func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error { inner, err := w.getInnerWorkflow(ctx, cNode) if err != nil { return err @@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e return err } -func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) { +func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) { return w.addNodeInternal(ctx, cNode, nil) } -func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) { +func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) { key := ns.Key var deps *dependencyInfo @@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i innerWorkflow = inner.inner } - ins, err := ns.New(ctx, innerWorkflow, w.schema, deps) + ins, err := New(ctx, ns, innerWorkflow, w.schema, deps) if err != nil { return nil, err } @@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i var opts []compose.GraphAddNodeOpt opts = append(opts, compose.WithNodeName(string(ns.Key))) - preHandler := ns.StatePreHandler(w.streamRun) + preHandler := statePreHandler(ns, w.streamRun) if preHandler != nil { opts = append(opts, preHandler) } - postHandler := ns.StatePostHandler(w.streamRun) + postHandler := statePostHandler(ns, w.streamRun) if postHandler != nil { opts = append(opts, postHandler) } @@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i w.entry = wNode } - outputPortCount, hasExceptionPort := ns.OutputPortCount() - if outputPortCount > 1 || hasExceptionPort { - bMapping, err := w.resolveBranch(key, outputPortCount) - if err != nil { - return nil, err - } + b := w.schema.GetBranch(ns.Key) + if b != nil { + if b.OnlyException() { + _ = w.AddBranch(string(key), b.GetExceptionBranch()) + } else { + bb, ok := ns.Configs.(schema.BranchBuilder) + if !ok { + return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type) + } - branch, err := ns.GetBranch(bMapping) - if err != nil { - return nil, err - } + br, err := b.GetFullBranch(ctx, bb) + if err != nil { + return nil, err + } - _ = w.AddBranch(string(key), branch) + _ = w.AddBranch(string(key), br) + } } return deps.inputsForParent, nil @@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti return w.workflow.Compile(ctx, opts...) } -func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) { - innerNodes := make(map[vo.NodeKey]*NodeSchema) +func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) { + innerNodes := make(map[vo.NodeKey]*schema.NodeSchema) for _, n := range cNode.Children { innerNodes[n.Key] = n } // trim the connections, only keep the connections that are related to the inner workflow // ignore the cases when we have nested inner workflows, because we do not support nested composite nodes - innerConnections := make([]*Connection, 0) + innerConnections := make([]*schema.Connection, 0) for i := range w.schema.Connections { conn := w.schema.Connections[i] if _, ok := innerNodes[conn.FromNode]; ok { @@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping) // For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c. // As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array, // we will not try to drill, instead, just take value using a.b.c. -func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error { +func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error { for nKey, fms := range d.inputs { if nKey == compose.START { // reference to START node would NEVER need to do array drill down continue @@ -638,55 +643,6 @@ type variableInfo struct { toPath compose.FieldPath } -func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) { - m := make([]map[string]bool, portCount) - var exception map[string]bool - - for _, conn := range w.connections { - if conn.FromNode != n { - continue - } - - if conn.FromPort == nil { - continue - } - - if *conn.FromPort == "default" { // default condition - if m[portCount-1] == nil { - m[portCount-1] = make(map[string]bool) - } - m[portCount-1][string(conn.ToNode)] = true - } else if *conn.FromPort == "branch_error" { - if exception == nil { - exception = make(map[string]bool) - } - exception[string(conn.ToNode)] = true - } else { - if !strings.HasPrefix(*conn.FromPort, "branch_") { - return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort) - } - - index := (*conn.FromPort)[7:] - i, err := strconv.Atoi(index) - if err != nil { - return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort) - } - if i < 0 || i >= portCount { - return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount) - } - if m[i] == nil { - m[i] = make(map[string]bool) - } - m[i][string(conn.ToNode)] = true - } - } - - return &BranchMapping{ - Normal: m, - Exception: exception, - }, nil -} - func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) { var ( inputs = make(map[vo.NodeKey][]*compose.FieldMapping) @@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping) ) - connMap := make(map[vo.NodeKey]Connection) + connMap := make(map[vo.NodeKey]schema.Connection) for _, conn := range w.connections { if conn.ToNode != n { continue @@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field continue } - if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok { + if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok { if _, ok := connMap[fromNode]; ok { // direct dependency if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 { if inputFull == nil { @@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path)) } } - } else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok { + } else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok { firstNodesInInnerWorkflow := true for _, conn := range connMap { - if isInSameWorkflow(w.hierarchy, n, conn.FromNode) { + if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) { // there is another node 'conn.FromNode' that connects to this node, while also at the same level firstNodesInInnerWorkflow = false break @@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field continue } - if isBelowOneLevel(w.hierarchy, n, fromNodeKey) { + if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) { fromNodeKey = compose.START - } else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) { + } else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) { continue } @@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []* variableInfos []*variableInfo ) - connMap := make(map[vo.NodeKey]Connection) + connMap := make(map[vo.NodeKey]schema.Connection) for _, conn := range w.connections { if conn.ToNode != n { continue } - if isInSameWorkflow(w.hierarchy, conn.FromNode, n) { + if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) { continue } @@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []* swp.Source.Ref.FromPath, swp.Path) } - if ok := isParentOf(w.hierarchy, n, fromNode); ok { + if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok { if _, ok := connMap[fromNode]; ok { // direct dependency inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...))) } else { // indirect dependency diff --git a/backend/domain/workflow/internal/compose/workflow_from_node.go b/backend/domain/workflow/internal/compose/workflow_from_node.go index dcafe06d..9b2db09b 100644 --- a/backend/domain/workflow/internal/compose/workflow_from_node.go +++ b/backend/domain/workflow/internal/compose/workflow_from_node.go @@ -23,9 +23,10 @@ import ( workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) -func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) ( +func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) ( *Workflow, error) { sc.Init() ns := sc.GetNode(nodeKey) @@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod schema: sc, fromNode: true, streamRun: false, // single node run can only invoke - requireCheckpoint: sc.requireCheckPoint, + requireCheckpoint: sc.RequireCheckpoint(), input: ns.InputTypes, output: ns.OutputTypes, terminatePlan: vo.ReturnVariables, diff --git a/backend/domain/workflow/internal/compose/workflow_run.go b/backend/domain/workflow/internal/compose/workflow_run.go index 5ea7f22f..8089234e 100644 --- a/backend/domain/workflow/internal/compose/workflow_run.go +++ b/backend/domain/workflow/internal/compose/workflow_run.go @@ -32,6 +32,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/logs" @@ -42,7 +43,7 @@ type WorkflowRunner struct { basic *entity.WorkflowBasic input string resumeReq *entity.ResumeRequest - schema *WorkflowSchema + schema *schema2.WorkflowSchema streamWriter *schema.StreamWriter[*entity.Message] config vo.ExecuteConfig @@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp } } -func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner { +func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner { options := &workflowRunOptions{} for _, opt := range opts { opt(options) diff --git a/backend/domain/workflow/internal/compose/workflow_tool.go b/backend/domain/workflow/internal/compose/workflow_tool.go index e0d8486e..513292da 100644 --- a/backend/domain/workflow/internal/compose/workflow_tool.go +++ b/backend/domain/workflow/internal/compose/workflow_tool.go @@ -30,6 +30,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -41,7 +42,7 @@ type invokableWorkflow struct { invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error) terminatePlan vo.TerminatePlan wfEntity *entity.Workflow - sc *WorkflowSchema + sc *schema2.WorkflowSchema repo wf.Repository } @@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo, invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error), terminatePlan vo.TerminatePlan, wfEntity *entity.Workflow, - sc *WorkflowSchema, + sc *schema2.WorkflowSchema, repo wf.Repository, ) wf.ToolFromWorkflow { return &invokableWorkflow{ @@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st return "", err } - var entryNode *NodeSchema + var entryNode *schema2.NodeSchema for _, node := range i.sc.Nodes { if node.Type == entity.NodeTypeEntry { entryNode = node @@ -190,7 +191,7 @@ type streamableWorkflow struct { stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) terminatePlan vo.TerminatePlan wfEntity *entity.Workflow - sc *WorkflowSchema + sc *schema2.WorkflowSchema repo wf.Repository } @@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo, stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error), terminatePlan vo.TerminatePlan, wfEntity *entity.Workflow, - sc *WorkflowSchema, + sc *schema2.WorkflowSchema, repo wf.Repository, ) wf.ToolFromWorkflow { return &streamableWorkflow{ @@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON return nil, err } - var entryNode *NodeSchema + var entryNode *schema2.NodeSchema for _, node := range s.sc.Nodes { if node.Type == entity.NodeTypeEntry { entryNode = node diff --git a/backend/domain/workflow/internal/nodes/batch/batch.go b/backend/domain/workflow/internal/nodes/batch/batch.go index 95b96a2a..61c422b4 100644 --- a/backend/domain/workflow/internal/nodes/batch/batch.go +++ b/backend/domain/workflow/internal/nodes/batch/batch.go @@ -30,50 +30,108 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" ) type Batch struct { - config *Config - outputs map[string]*vo.FieldSource + outputs map[string]*vo.FieldSource + innerWorkflow compose.Runnable[map[string]any, map[string]any] + key vo.NodeKey + inputArrays []string } -type Config struct { - BatchNodeKey vo.NodeKey `json:"batch_node_key"` - InnerWorkflow compose.Runnable[map[string]any, map[string]any] +type Config struct{} - InputArrays []string `json:"input_arrays"` - Outputs []*vo.FieldInfo `json:"outputs"` -} - -func NewBatch(_ context.Context, config *Config) (*Batch, error) { - if config == nil { - return nil, errors.New("config is required") +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + if n.Parent() != nil { + return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID) } - if len(config.InputArrays) == 0 { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeBatch, + Name: n.Data.Meta.Title, + Configs: c, + } + + batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize, + compose.FieldPath{MaxBatchSizeKey}, nil) + if err != nil { + return nil, err + } + ns.AddInputSource(batchSizeField...) + concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize, + compose.FieldPath{ConcurrentSizeKey}, nil) + if err != nil { + return nil, err + } + ns.AddInputSource(concurrentSizeField...) + + batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize) + if err != nil { + return nil, err + } + ns.SetInputType(MaxBatchSizeKey, batchSizeType) + concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize) + if err != nil { + return nil, err + } + ns.SetInputType(ConcurrentSizeKey, concurrentSizeType) + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) { + var inputArrays []string + for key, tInfo := range ns.InputTypes { + if tInfo.Type != vo.DataTypeArray { + continue + } + + inputArrays = append(inputArrays, key) + } + + if len(inputArrays) == 0 { return nil, errors.New("need to have at least one incoming array for batch") } - if len(config.Outputs) == 0 { + if len(ns.OutputSources) == 0 { return nil, errors.New("need to have at least one output variable for batch") } - b := &Batch{ - config: config, - outputs: make(map[string]*vo.FieldSource), + bo := schema.GetBuildOptions(opts...) + if bo.Inner == nil { + return nil, errors.New("need to have inner workflow for batch") } - for i := range config.Outputs { - source := config.Outputs[i] + b := &Batch{ + outputs: make(map[string]*vo.FieldSource), + innerWorkflow: bo.Inner, + key: ns.Key, + inputArrays: inputArrays, + } + + for i := range ns.OutputSources { + source := ns.OutputSources[i] path := source.Path if len(path) != 1 { return nil, fmt.Errorf("invalid path %q", path) } + // from which inner node's which field does the batch's output fields come from b.outputs[path[0]] = &source.Source } @@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any { return out } -func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) ( +func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) ( out map[string]any, err error) { - arrays := make(map[string]any, len(b.config.InputArrays)) + arrays := make(map[string]any, len(b.inputArrays)) minLen := math.MaxInt64 - for _, arrayKey := range b.config.InputArrays { + for _, arrayKey := range b.inputArrays { a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) if !ok { return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey) @@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne } } - input[string(b.config.BatchNodeKey)+"#index"] = int64(i) + input[string(b.key)+"#index"] = int64(i) items := make(map[string]any) for arrayKey, array := range arrays { ele := reflect.ValueOf(array).Index(i).Interface() items[arrayKey] = []any{ele} - currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey + currentKey := string(b.key) + "#" + arrayKey // Recursively expand map[string]any elements var expand func(prefix string, val interface{}) @@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne return nil } - options := &nodes.NestedWorkflowOptions{} - for _, opt := range opts { - opt(options) - } - + options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...) var existingCState *nodes.NestedWorkflowState err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error { var e error - existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey) + existingCState, _, e = getter.GetNestedWorkflowState(b.key) if e != nil { return e } @@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne mu.Unlock() if subCheckpointID != "" { logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s", - i, b.config.BatchNodeKey, subCheckpointID) + i, b.key, subCheckpointID) ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID)) } @@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne // if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow // the output then needs to be concatenated. - taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...) + taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...) if err != nil { info, ok := compose.ExtractInterruptInfo(err) if !ok { @@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions iEvent := &entity.InterruptEvent{ - NodeKey: b.config.BatchNodeKey, + NodeKey: b.key, NodeType: entity.NodeTypeBatch, NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo } err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { - if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil { + if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil { return e } - return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent) + return setter.SetInterruptEvent(b.key, iEvent) }) if err != nil { return nil, err @@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne return nil, compose.InterruptAndRerun } else { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { - if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil { + if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil { return e } @@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne // although this invocation does not have new interruptions, // this batch node previously have interrupts yet to be resumed. // we overwrite the interrupt events, keeping only the interrupts yet to be resumed. - return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{ - NodeKey: b.config.BatchNodeKey, + return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{ + NodeKey: b.key, NodeType: entity.NodeTypeBatch, NestedInterruptInfo: existingCState.Index2InterruptInfo, }) @@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 { logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+ - "nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo)) + "nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo)) return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs } @@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne } func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { - trimmed := make(map[string]any, len(b.config.InputArrays)) - for _, arrayKey := range b.config.InputArrays { + trimmed := make(map[string]any, len(b.inputArrays)) + for _, arrayKey := range b.inputArrays { if v, ok := in[arrayKey]; ok { trimmed[arrayKey] = v } diff --git a/backend/domain/workflow/internal/nodes/code/code.go b/backend/domain/workflow/internal/nodes/code/code.go index e647c08a..23b9d9e6 100644 --- a/backend/domain/workflow/internal/nodes/code/code.go +++ b/backend/domain/workflow/internal/nodes/code/code.go @@ -25,6 +25,10 @@ import ( "golang.org/x/exp/maps" + code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" @@ -113,50 +117,77 @@ var pythonThirdPartyWhitelist = map[string]struct{}{ } type Config struct { - Code string - Language coderunner.Language - OutputConfig map[string]*vo.TypeInfo - Runner coderunner.Runner + Code string + Language coderunner.Language + + Runner coderunner.Runner } -type CodeRunner struct { - config *Config - importError error +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeCodeRunner, + Name: n.Data.Meta.Title, + Configs: c, + } + inputs := n.Data.Inputs + + code := inputs.Code + c.Code = code + + language, err := convertCodeLanguage(inputs.Language) + if err != nil { + return nil, err + } + c.Language = language + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) { - if cfg == nil { - return nil, errors.New("cfg is required") +func convertCodeLanguage(l int64) (coderunner.Language, error) { + switch l { + case 5: + return coderunner.JavaScript, nil + case 3: + return coderunner.Python, nil + default: + return "", fmt.Errorf("invalid language: %d", l) } +} - if cfg.Language == "" { - return nil, errors.New("language is required") - } +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { - if cfg.Code == "" { - return nil, errors.New("code is required") - } - - if cfg.Language != coderunner.Python { + if c.Language != coderunner.Python { return nil, errors.New("only support python language") } - if len(cfg.OutputConfig) == 0 { - return nil, errors.New("output config is required") - } + importErr := validatePythonImports(c.Code) - if cfg.Runner == nil { - return nil, errors.New("run coder is required") - } - - importErr := validatePythonImports(cfg.Code) - - return &CodeRunner{ - config: cfg, - importError: importErr, + return &Runner{ + code: c.Code, + language: c.Language, + outputConfig: ns.OutputTypes, + runner: code2.GetCodeRunner(), + importError: importErr, }, nil } +type Runner struct { + outputConfig map[string]*vo.TypeInfo + code string + language coderunner.Language + runner coderunner.Runner + importError error +} + func validatePythonImports(code string) error { imports := parsePythonImports(code) importErrors := make([]string, 0) @@ -191,11 +222,11 @@ func validatePythonImports(code string) error { return nil } -func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) { +func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) { if c.importError != nil { return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error())) } - response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input}) + response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input}) if err != nil { return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error())) } @@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map result := response.Result ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result) - output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig) + output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig) if err != nil { return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error())) } @@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map } -func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { +func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey) if !ok { return nil, errors.New("raw output config is required") diff --git a/backend/domain/workflow/internal/nodes/code/code_test.go b/backend/domain/workflow/internal/nodes/code/code_test.go index 69ffb762..6bc5de34 100644 --- a/backend/domain/workflow/internal/nodes/code/code_test.go +++ b/backend/domain/workflow/internal/nodes/code/code_test.go @@ -75,30 +75,29 @@ async def main(args:Args)->Output: mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) ctx := t.Context() - c := &CodeRunner{ - config: &Config{ - Language: coderunner.Python, - Code: codeTpl, - OutputConfig: map[string]*vo.TypeInfo{ - "key0": {Type: vo.DataTypeInteger}, - "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, - "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key31": &vo.TypeInfo{Type: vo.DataTypeString}, - "key32": &vo.TypeInfo{Type: vo.DataTypeString}, - "key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key341": &vo.TypeInfo{Type: vo.DataTypeString}, - "key342": &vo.TypeInfo{Type: vo.DataTypeString}, - }}, - }, - }, - "key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}}, + c := &Runner{ + language: coderunner.Python, + code: codeTpl, + outputConfig: map[string]*vo.TypeInfo{ + "key0": {Type: vo.DataTypeInteger}, + "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, + "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key31": {Type: vo.DataTypeString}, + "key32": {Type: vo.DataTypeString}, + "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key341": {Type: vo.DataTypeString}, + "key342": {Type: vo.DataTypeString}, + }}, }, - Runner: mockRunner, + }, + "key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}}, }, + runner: mockRunner, } - ret, err := c.RunCode(ctx, map[string]any{ + + ret, err := c.Invoke(ctx, map[string]any{ "input": "1123", }) @@ -145,38 +144,36 @@ async def main(args:Args)->Output: mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) ctx := t.Context() - c := &CodeRunner{ - config: &Config{ - Code: codeTpl, - Language: coderunner.Python, - OutputConfig: map[string]*vo.TypeInfo{ - "key0": {Type: vo.DataTypeInteger}, - "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, - "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key31": &vo.TypeInfo{Type: vo.DataTypeString}, - "key32": &vo.TypeInfo{Type: vo.DataTypeString}, - "key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key341": &vo.TypeInfo{Type: vo.DataTypeString}, - "key342": &vo.TypeInfo{Type: vo.DataTypeString}, - }}, + c := &Runner{ + code: codeTpl, + language: coderunner.Python, + outputConfig: map[string]*vo.TypeInfo{ + "key0": {Type: vo.DataTypeInteger}, + "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, + "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key31": {Type: vo.DataTypeString}, + "key32": {Type: vo.DataTypeString}, + "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key341": {Type: vo.DataTypeString}, + "key342": {Type: vo.DataTypeString}, }}, - "key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key31": &vo.TypeInfo{Type: vo.DataTypeString}, - "key32": &vo.TypeInfo{Type: vo.DataTypeString}, - "key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key341": &vo.TypeInfo{Type: vo.DataTypeString}, - "key342": &vo.TypeInfo{Type: vo.DataTypeString}, - }, - }}, + }}, + "key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key31": {Type: vo.DataTypeString}, + "key32": {Type: vo.DataTypeString}, + "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key341": {Type: vo.DataTypeString}, + "key342": {Type: vo.DataTypeString}, }, + }}, }, - Runner: mockRunner, }, + runner: mockRunner, } - ret, err := c.RunCode(ctx, map[string]any{ + ret, err := c.Invoke(ctx, map[string]any{ "input": "1123", }) @@ -219,30 +216,28 @@ async def main(args:Args)->Output: } mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) - c := &CodeRunner{ - config: &Config{ - Code: codeTpl, - Language: coderunner.Python, - OutputConfig: map[string]*vo.TypeInfo{ - "key0": {Type: vo.DataTypeInteger}, - "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key31": &vo.TypeInfo{Type: vo.DataTypeString}, - "key32": &vo.TypeInfo{Type: vo.DataTypeString}, - "key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - "key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "key341": &vo.TypeInfo{Type: vo.DataTypeString}, - "key342": &vo.TypeInfo{Type: vo.DataTypeString}, - "key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - }}, - }, - }, + c := &Runner{ + code: codeTpl, + language: coderunner.Python, + outputConfig: map[string]*vo.TypeInfo{ + "key0": {Type: vo.DataTypeInteger}, + "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key31": {Type: vo.DataTypeString}, + "key32": {Type: vo.DataTypeString}, + "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "key341": {Type: vo.DataTypeString}, + "key342": {Type: vo.DataTypeString}, + "key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + }}, + }, }, - Runner: mockRunner, }, + runner: mockRunner, } - ret, err := c.RunCode(ctx, map[string]any{ + ret, err := c.Invoke(ctx, map[string]any{ "input": "1123", }) diff --git a/backend/domain/workflow/internal/nodes/database/adapt.go b/backend/domain/workflow/internal/nodes/database/adapt.go new file mode 100644 index 00000000..acf601ec --- /dev/null +++ b/backend/domain/workflow/internal/nodes/database/adapt.go @@ -0,0 +1,236 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package database + +import ( + "fmt" + + einoCompose "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" +) + +func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) (err error) { + selectParam := n.Data.Inputs.SelectParam + if selectParam != nil { + err = applyDBConditionToSchema(ns, selectParam.Condition, n.Parent()) + if err != nil { + return err + } + } + + insertParam := n.Data.Inputs.InsertParam + if insertParam != nil { + err = applyInsetFieldInfoToSchema(ns, insertParam.FieldInfo, n.Parent()) + if err != nil { + return err + } + } + + deleteParam := n.Data.Inputs.DeleteParam + if deleteParam != nil { + err = applyDBConditionToSchema(ns, &deleteParam.Condition, n.Parent()) + if err != nil { + return err + } + } + + updateParam := n.Data.Inputs.UpdateParam + if updateParam != nil { + err = applyDBConditionToSchema(ns, &updateParam.Condition, n.Parent()) + if err != nil { + return err + } + err = applyInsetFieldInfoToSchema(ns, updateParam.FieldInfo, n.Parent()) + if err != nil { + return err + } + } + + return nil +} + +func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error { + if condition.ConditionList == nil { + return nil + } + + for idx, params := range condition.ConditionList { + var right *vo.Param + for _, param := range params { + if param == nil { + continue + } + if param.Name == "right" { + right = param + break + } + } + + if right == nil { + continue + } + name := fmt.Sprintf("__condition_right_%d", idx) + tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input) + if err != nil { + return err + } + ns.SetInputType(name, tInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode) + if err != nil { + return err + } + ns.AddInputSource(sources...) + } + + return nil +} + +func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error { + if len(fieldInfo) == 0 { + return nil + } + + for _, params := range fieldInfo { + // Each FieldInfo is list params, containing two elements. + // The first is to set the name of the field and the second is the corresponding value. + p0 := params[0] + p1 := params[1] + name := p0.Input.Value.Content.(string) // must string type + tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input) + if err != nil { + return err + } + name = "__setting_field_" + name + ns.SetInputType(name, tInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode) + if err != nil { + return err + } + ns.AddInputSource(sources...) + } + + return nil +} + +func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) { + clauseGroup := &database.ClauseGroup{} + if len(condition.ConditionList) == 1 { + params := condition.ConditionList[0] + clause, err := buildClauseFromParams(params) + if err != nil { + return nil, err + } + clauseGroup.Single = clause + } else { + relation, err := convertLogicTypeToRelation(condition.Logic) + if err != nil { + return nil, err + } + clauseGroup.Multi = &database.MultiClause{ + Clauses: make([]*database.Clause, 0, len(condition.ConditionList)), + Relation: relation, + } + for i := range condition.ConditionList { + params := condition.ConditionList[i] + clause, err := buildClauseFromParams(params) + if err != nil { + return nil, err + } + clauseGroup.Multi.Clauses = append(clauseGroup.Multi.Clauses, clause) + } + } + + return clauseGroup, nil +} + +func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) { + var left, operation *vo.Param + for _, p := range params { + if p == nil { + continue + } + if p.Name == "left" { + left = p + continue + } + if p.Name == "operation" { + operation = p + continue + } + } + if left == nil { + return nil, fmt.Errorf("left clause is required") + } + if operation == nil { + return nil, fmt.Errorf("operation clause is required") + } + operator, err := operationToOperator(operation.Input.Value.Content.(string)) + if err != nil { + return nil, err + } + clause := &database.Clause{ + Left: left.Input.Value.Content.(string), + Operator: operator, + } + + return clause, nil +} + +func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) { + switch logicType { + case vo.DatabaseLogicAnd: + return database.ClauseRelationAND, nil + case vo.DatabaseLogicOr: + return database.ClauseRelationOR, nil + default: + return "", fmt.Errorf("logic type %v is invalid", logicType) + } +} + +func operationToOperator(s string) (database.Operator, error) { + switch s { + case "EQUAL": + return database.OperatorEqual, nil + case "NOT_EQUAL": + return database.OperatorNotEqual, nil + case "GREATER_THAN": + return database.OperatorGreater, nil + case "LESS_THAN": + return database.OperatorLesser, nil + case "GREATER_EQUAL": + return database.OperatorGreaterOrEqual, nil + case "LESS_EQUAL": + return database.OperatorLesserOrEqual, nil + case "IN": + return database.OperatorIn, nil + case "NOT_IN": + return database.OperatorNotIn, nil + case "IS_NULL": + return database.OperatorIsNull, nil + case "IS_NOT_NULL": + return database.OperatorIsNotNull, nil + case "LIKE": + return database.OperatorLike, nil + case "NOT_LIKE": + return database.OperatorNotLike, nil + } + return "", fmt.Errorf("not a valid Operation string") +} diff --git a/backend/domain/workflow/internal/nodes/database/common.go b/backend/domain/workflow/internal/nodes/database/common.go index 8c586708..f0e16670 100644 --- a/backend/domain/workflow/internal/nodes/database/common.go +++ b/backend/domain/workflow/internal/nodes/database/common.go @@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database. return ret, nil } -func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) { +func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) { var ( rightValue any ok bool @@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa return conditionGroup, nil } -func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) { +func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) { conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input) if err != nil { return nil, err } fields := parseToInput(input) - inventory := &UpdateInventory{ + inventory := &updateInventory{ ConditionGroup: conditionGroup, Fields: fields, } diff --git a/backend/domain/workflow/internal/nodes/database/customsql.go b/backend/domain/workflow/internal/nodes/database/customsql.go index 8e475a35..3e05a57c 100644 --- a/backend/domain/workflow/internal/nodes/database/customsql.go +++ b/backend/domain/workflow/internal/nodes/database/customsql.go @@ -19,48 +19,89 @@ package database import ( "context" "errors" + "fmt" "reflect" + "strconv" "strings" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) type CustomSQLConfig struct { - DatabaseInfoID int64 - SQLTemplate string - OutputConfig map[string]*vo.TypeInfo - CustomSQLExecutor database.DatabaseOperator + DatabaseInfoID int64 + SQLTemplate string } -func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeDatabaseCustomSQL, + Name: n.Data.Meta.Title, + Configs: c, } - if cfg.DatabaseInfoID == 0 { + + dsList := n.Data.Inputs.DatabaseInfoList + if len(dsList) == 0 { + return nil, fmt.Errorf("database info is requird") + } + databaseInfo := dsList[0] + + dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) + if err != nil { + return nil, err + } + c.DatabaseInfoID = dsID + + sql := n.Data.Inputs.SQL + if len(sql) == 0 { + return nil, fmt.Errorf("sql is requird") + } + + c.SQLTemplate = sql + + if err = convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if c.DatabaseInfoID == 0 { return nil, errors.New("database info id is required and greater than 0") } - if cfg.SQLTemplate == "" { + if c.SQLTemplate == "" { return nil, errors.New("sql template is required") } - if cfg.CustomSQLExecutor == nil { - return nil, errors.New("custom sqler is required") - } + return &CustomSQL{ - config: cfg, + databaseInfoID: c.DatabaseInfoID, + sqlTemplate: c.SQLTemplate, + outputTypes: ns.OutputTypes, + customSQLExecutor: database.GetDatabaseOperator(), }, nil } type CustomSQL struct { - config *CustomSQLConfig + databaseInfoID int64 + sqlTemplate string + outputTypes map[string]*vo.TypeInfo + customSQLExecutor database.DatabaseOperator } -func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) { - +func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { req := &database.CustomSQLRequest{ - DatabaseInfoID: c.config.DatabaseInfoID, + DatabaseInfoID: c.databaseInfoID, IsDebugRun: isDebugExecute(ctx), UserID: getExecUserID(ctx), } @@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri } templateSQL := "" - templateParts := nodes.ParseTemplate(c.config.SQLTemplate) + templateParts := nodes.ParseTemplate(c.sqlTemplate) sqlParams := make([]database.SQLParam, 0, len(templateParts)) var nilError = errors.New("field is nil") for _, templatePart := range templateParts { @@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri templateSQL = strings.Replace(templateSQL, "`?`", "?", -1) req.SQL = templateSQL req.Params = sqlParams - response, err := c.config.CustomSQLExecutor.Execute(ctx, req) + response, err := c.customSQLExecutor.Execute(ctx, req) if err != nil { return nil, err } - ret, err := responseFormatted(c.config.OutputConfig, response) + ret, err := responseFormatted(c.outputTypes, response) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/database/customsql_test.go b/backend/domain/workflow/internal/nodes/database/customsql_test.go index 965dd89d..331b9af6 100644 --- a/backend/domain/workflow/internal/nodes/database/customsql_test.go +++ b/backend/domain/workflow/internal/nodes/database/customsql_test.go @@ -28,6 +28,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type mockCustomSQLer struct { @@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C m.validate(request) r := &database.Response{ Objects: []database.Object{ - database.Object{ + { "v1": "v1_ret", "v2": "v2_ret", }, @@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) { validate: func(req *database.CustomSQLRequest) { assert.Equal(t, int64(111), req.DatabaseInfoID) ps := []database.SQLParam{ - database.SQLParam{Value: "v1_value"}, - database.SQLParam{Value: "v2_value"}, - database.SQLParam{Value: "v3_value"}, + {Value: "v1_value"}, + {Value: "v2_value"}, + {Value: "v3_value"}, } assert.Equal(t, ps, req.Params) assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL) @@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes() + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() + cfg := &CustomSQLConfig{ - DatabaseInfoID: 111, - SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`", - CustomSQLExecutor: mockDatabaseOperator, - OutputConfig: map[string]*vo.TypeInfo{ + DatabaseInfoID: 111, + SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`", + } + + c1, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "v1": {Type: vo.DataTypeString}, "v2": {Type: vo.DataTypeString}, }}}, "rowNum": {Type: vo.DataTypeInteger}, }, - } - cl := &CustomSQL{ - config: cfg, - } + }) + assert.NoError(t, err) - ret, err := cl.Execute(t.Context(), map[string]any{ + ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{ "v1": "v1_value", "v2": "v2_value", "v3": "v3_value", diff --git a/backend/domain/workflow/internal/nodes/database/delete.go b/backend/domain/workflow/internal/nodes/database/delete.go index 92815bed..fdbb9a3e 100644 --- a/backend/domain/workflow/internal/nodes/database/delete.go +++ b/backend/domain/workflow/internal/nodes/database/delete.go @@ -20,61 +20,102 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type DeleteConfig struct { DatabaseInfoID int64 ClauseGroup *database.ClauseGroup - OutputConfig map[string]*vo.TypeInfo - - Deleter database.DatabaseOperator -} -type Delete struct { - config *DeleteConfig } -func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeDatabaseDelete, + Name: n.Data.Meta.Title, + Configs: d, } - if cfg.DatabaseInfoID == 0 { + + dsList := n.Data.Inputs.DatabaseInfoList + if len(dsList) == 0 { + return nil, fmt.Errorf("database info is requird") + } + databaseInfo := dsList[0] + + dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) + if err != nil { + return nil, err + } + d.DatabaseInfoID = dsID + + deleteParam := n.Data.Inputs.DeleteParam + + clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition) + if err != nil { + return nil, err + } + d.ClauseGroup = clauseGroup + + if err = setDatabaseInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if d.DatabaseInfoID == 0 { return nil, errors.New("database info id is required and greater than 0") } - if cfg.ClauseGroup == nil { + if d.ClauseGroup == nil { return nil, errors.New("clauseGroup is required") } - if cfg.Deleter == nil { - return nil, errors.New("deleter is required") - } return &Delete{ - config: cfg, + databaseInfoID: d.DatabaseInfoID, + clauseGroup: d.ClauseGroup, + outputTypes: ns.OutputTypes, + deleter: database.GetDatabaseOperator(), }, nil - } -func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) { - conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in) +type Delete struct { + databaseInfoID int64 + clauseGroup *database.ClauseGroup + outputTypes map[string]*vo.TypeInfo + deleter database.DatabaseOperator +} + +func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { + conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in) if err != nil { return nil, err } request := &database.DeleteRequest{ - DatabaseInfoID: d.config.DatabaseInfoID, + DatabaseInfoID: d.databaseInfoID, ConditionGroup: conditionGroup, IsDebugRun: isDebugExecute(ctx), UserID: getExecUserID(ctx), } - response, err := d.config.Deleter.Delete(ctx, request) + response, err := d.deleter.Delete(ctx, request) if err != nil { return nil, err } - ret, err := responseFormatted(d.config.OutputConfig, response) + ret, err := responseFormatted(d.outputTypes, response) if err != nil { return nil, err } @@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, } func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { - conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in) + conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in) if err != nil { return nil, err } @@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri } func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) { - databaseID := d.config.DatabaseInfoID + databaseID := d.databaseInfoID result := make(map[string]any) result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} diff --git a/backend/domain/workflow/internal/nodes/database/insert.go b/backend/domain/workflow/internal/nodes/database/insert.go index cf4b588f..2f25666e 100644 --- a/backend/domain/workflow/internal/nodes/database/insert.go +++ b/backend/domain/workflow/internal/nodes/database/insert.go @@ -20,54 +20,84 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type InsertConfig struct { DatabaseInfoID int64 - OutputConfig map[string]*vo.TypeInfo - Inserter database.DatabaseOperator } -type Insert struct { - config *InsertConfig -} - -func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeDatabaseInsert, + Name: n.Data.Meta.Title, + Configs: i, } - if cfg.DatabaseInfoID == 0 { + + dsList := n.Data.Inputs.DatabaseInfoList + if len(dsList) == 0 { + return nil, fmt.Errorf("database info is requird") + } + databaseInfo := dsList[0] + + dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) + if err != nil { + return nil, err + } + i.DatabaseInfoID = dsID + + if err = setDatabaseInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if i.DatabaseInfoID == 0 { return nil, errors.New("database info id is required and greater than 0") } - if cfg.Inserter == nil { - return nil, errors.New("inserter is required") - } return &Insert{ - config: cfg, + databaseInfoID: i.DatabaseInfoID, + outputTypes: ns.OutputTypes, + inserter: database.GetDatabaseOperator(), }, nil - } -func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) { +type Insert struct { + databaseInfoID int64 + outputTypes map[string]*vo.TypeInfo + inserter database.DatabaseOperator +} +func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { fields := parseToInput(input) req := &database.InsertRequest{ - DatabaseInfoID: is.config.DatabaseInfoID, + DatabaseInfoID: is.databaseInfoID, Fields: fields, IsDebugRun: isDebugExecute(ctx), UserID: getExecUserID(ctx), } - response, err := is.config.Inserter.Insert(ctx, req) + response, err := is.inserter.Insert(ctx, req) if err != nil { return nil, err } - ret, err := responseFormatted(is.config.OutputConfig, response) + ret, err := responseFormatted(is.outputTypes, response) if err != nil { return nil, err } @@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string] } func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) { - databaseID := is.config.DatabaseInfoID + databaseID := is.databaseInfoID fs := parseToInput(input) result := make(map[string]any) result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} diff --git a/backend/domain/workflow/internal/nodes/database/query.go b/backend/domain/workflow/internal/nodes/database/query.go index 9c20f782..c32f7795 100644 --- a/backend/domain/workflow/internal/nodes/database/query.go +++ b/backend/domain/workflow/internal/nodes/database/query.go @@ -20,68 +20,137 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type QueryConfig struct { DatabaseInfoID int64 QueryFields []string OrderClauses []*database.OrderClause - OutputConfig map[string]*vo.TypeInfo ClauseGroup *database.ClauseGroup Limit int64 - Op database.DatabaseOperator } -type Query struct { - config *QueryConfig -} - -func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeDatabaseQuery, + Name: n.Data.Meta.Title, + Configs: q, } - if cfg.DatabaseInfoID == 0 { + + dsList := n.Data.Inputs.DatabaseInfoList + if len(dsList) == 0 { + return nil, fmt.Errorf("database info is requird") + } + databaseInfo := dsList[0] + + dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) + if err != nil { + return nil, err + } + q.DatabaseInfoID = dsID + + selectParam := n.Data.Inputs.SelectParam + q.Limit = selectParam.Limit + + queryFields := make([]string, 0) + for _, v := range selectParam.FieldList { + queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10)) + } + q.QueryFields = queryFields + + orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList)) + for _, o := range selectParam.OrderByList { + orderClauses = append(orderClauses, &database.OrderClause{ + FieldID: strconv.FormatInt(o.FieldID, 10), + IsAsc: o.IsAsc, + }) + } + q.OrderClauses = orderClauses + + clauseGroup := &database.ClauseGroup{} + + if selectParam.Condition != nil { + clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition) + if err != nil { + return nil, err + } + } + + q.ClauseGroup = clauseGroup + + if err = setDatabaseInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if q.DatabaseInfoID == 0 { return nil, errors.New("database info id is required and greater than 0") } - if cfg.Limit == 0 { + if q.Limit == 0 { return nil, errors.New("limit is required and greater than 0") } - if cfg.Op == nil { - return nil, errors.New("op is required") - } - - return &Query{config: cfg}, nil - + return &Query{ + databaseInfoID: q.DatabaseInfoID, + queryFields: q.QueryFields, + orderClauses: q.OrderClauses, + outputTypes: ns.OutputTypes, + clauseGroup: q.ClauseGroup, + limit: q.Limit, + op: database.GetDatabaseOperator(), + }, nil } -func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) { - conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in) +type Query struct { + databaseInfoID int64 + queryFields []string + orderClauses []*database.OrderClause + outputTypes map[string]*vo.TypeInfo + clauseGroup *database.ClauseGroup + limit int64 + op database.DatabaseOperator +} + +func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { + conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in) if err != nil { return nil, err } req := &database.QueryRequest{ - DatabaseInfoID: ds.config.DatabaseInfoID, - OrderClauses: ds.config.OrderClauses, - SelectFields: ds.config.QueryFields, - Limit: ds.config.Limit, + DatabaseInfoID: ds.databaseInfoID, + OrderClauses: ds.orderClauses, + SelectFields: ds.queryFields, + Limit: ds.limit, IsDebugRun: isDebugExecute(ctx), UserID: getExecUserID(ctx), } req.ConditionGroup = conditionGroup - response, err := ds.config.Op.Query(ctx, req) + response, err := ds.op.Query(ctx, req) if err != nil { return nil, err } - ret, err := responseFormatted(ds.config.OutputConfig, response) + ret, err := responseFormatted(ds.outputTypes, response) if err != nil { return nil, err } @@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool { } func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) { - conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in) + conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in) if err != nil { return nil, err } - return toDatabaseQueryCallbackInput(ds.config, conditionGroup) + return ds.toDatabaseQueryCallbackInput(conditionGroup) } -func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) { +func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) { result := make(map[string]any) - databaseID := config.DatabaseInfoID + databaseID := ds.databaseInfoID result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["selectParam"] = map[string]any{} @@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database. FieldID string `json:"fieldId"` IsDistinct bool `json:"isDistinct"` } - fieldList := make([]Field, 0, len(config.QueryFields)) - for _, f := range config.QueryFields { + fieldList := make([]Field, 0, len(ds.queryFields)) + for _, f := range ds.queryFields { fieldList = append(fieldList, Field{FieldID: f}) } type Order struct { @@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database. } OrderList := make([]Order, 0) - for _, c := range config.OrderClauses { + for _, c := range ds.orderClauses { OrderList = append(OrderList, Order{ FieldID: c.FieldID, IsAsc: c.IsAsc, @@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database. result["selectParam"] = map[string]any{ "condition": condition, "fieldList": fieldList, - "limit": config.Limit, + "limit": ds.limit, "orderByList": OrderList, } return result, nil - } type ConditionItem struct { @@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) { return "AND", nil default: return "", fmt.Errorf("unknown clause relation %v", rel) - } } diff --git a/backend/domain/workflow/internal/nodes/database/query_test.go b/backend/domain/workflow/internal/nodes/database/query_test.go index 010b0051..04678d58 100644 --- a/backend/domain/workflow/internal/nodes/database/query_test.go +++ b/backend/domain/workflow/internal/nodes/database/query_test.go @@ -30,6 +30,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type mockDsSelect struct { @@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) { }, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2"}, - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{ - "v1": {Type: vo.DataTypeString}, - "v2": {Type: vo.DataTypeString}, - }, - }}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) { @@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()) - cfg.Op = mockDatabaseOperator + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() - ds := Query{ - config: cfg, - } + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{ + "v1": {Type: vo.DataTypeString}, + "v2": {Type: vo.DataTypeString}, + }, + }}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]interface{}{ "__condition_right_0": 1, } - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"]) assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"]) @@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) { OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2"}, - - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{ - "v1": {Type: vo.DataTypeString}, - "v2": {Type: vo.DataTypeString}, - }, - }}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } objects := make([]database.Object, 0) @@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() - cfg.Op = mockDatabaseOperator + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() - ds := Query{ - config: cfg, - } + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{ + "v1": {Type: vo.DataTypeString}, + "v2": {Type: vo.DataTypeString}, + }, + }}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]any{ "__condition_right_0": 1, "__condition_right_1": 2, } - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) assert.NoError(t, err) assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"]) @@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) { }, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2"}, - - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{ - "v1": {Type: vo.DataTypeInteger}, - "v2": {Type: vo.DataTypeInteger}, - }, - }}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } objects := make([]database.Object, 0) objects = append(objects, database.Object{ @@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() - cfg.Op = mockDatabaseOperator + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() - ds := Query{ - config: cfg, - } + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{ + "v1": {Type: vo.DataTypeInteger}, + "v2": {Type: vo.DataTypeInteger}, + }, + }}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]any{ "__condition_right_0": 1, } - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) fmt.Println(result) assert.Equal(t, map[string]any{ @@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) { }, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2"}, - - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ - Type: vo.DataTypeObject, - Properties: map[string]*vo.TypeInfo{ - "v1": {Type: vo.DataTypeInteger}, - "v2": {Type: vo.DataTypeInteger}, - "v3": {Type: vo.DataTypeInteger}, - }, - }}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } objects := make([]database.Object, 0) objects = append(objects, database.Object{ @@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() - cfg.Op = mockDatabaseOperator + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() - ds := Query{ - config: cfg, - } + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{ + "v1": {Type: vo.DataTypeInteger}, + "v2": {Type: vo.DataTypeInteger}, + "v3": {Type: vo.DataTypeInteger}, + }, + }}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]any{"__condition_right_0": 1} - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) fmt.Println(result) assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"]) @@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) { }, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"}, - - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, - ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ - "v1": {Type: vo.DataTypeInteger}, - "v2": {Type: vo.DataTypeNumber}, - "v3": {Type: vo.DataTypeBoolean}, - "v4": {Type: vo.DataTypeBoolean}, - "v5": {Type: vo.DataTypeTime}, - "v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}}, - "v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}}, - "v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, - }, - }}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } objects := make([]database.Object, 0) @@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() - cfg.Op = mockDatabaseOperator + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() - ds := Query{ - config: cfg, - } + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, + ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ + "v1": {Type: vo.DataTypeInteger}, + "v2": {Type: vo.DataTypeNumber}, + "v3": {Type: vo.DataTypeBoolean}, + "v4": {Type: vo.DataTypeBoolean}, + "v5": {Type: vo.DataTypeTime}, + "v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}}, + "v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}}, + "v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, + }, + }}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]any{ "__condition_right_0": 1, } - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) object := result["outputList"].([]any)[0].(database.Object) @@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) { }, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"}, - OutputConfig: map[string]*vo.TypeInfo{ - "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}}, - "rowNum": {Type: vo.DataTypeInteger}, - }, + Limit: 10, } objects := make([]database.Object, 0) @@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) { mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() - cfg.Op = mockDatabaseOperator - ds := Query{ - config: cfg, - } + defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch() + + ds, err := cfg.Build(context.Background(), &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}}, + "rowNum": {Type: vo.DataTypeInteger}, + }, + }) + assert.NoError(t, err) in := map[string]any{ "__condition_right_0": 1, } - result, err := ds.Query(t.Context(), in) + result, err := ds.(*Query).Invoke(t.Context(), in) assert.NoError(t, err) assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{ "v1": "1", diff --git a/backend/domain/workflow/internal/nodes/database/update.go b/backend/domain/workflow/internal/nodes/database/update.go index 0e32a9a0..8e3e602b 100644 --- a/backend/domain/workflow/internal/nodes/database/update.go +++ b/backend/domain/workflow/internal/nodes/database/update.go @@ -20,47 +20,93 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type UpdateConfig struct { DatabaseInfoID int64 ClauseGroup *database.ClauseGroup - OutputConfig map[string]*vo.TypeInfo - Updater database.DatabaseOperator +} + +func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeDatabaseUpdate, + Name: n.Data.Meta.Title, + Configs: u, + } + + dsList := n.Data.Inputs.DatabaseInfoList + if len(dsList) == 0 { + return nil, fmt.Errorf("database info is requird") + } + databaseInfo := dsList[0] + + dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64) + if err != nil { + return nil, err + } + u.DatabaseInfoID = dsID + + updateParam := n.Data.Inputs.UpdateParam + if updateParam == nil { + return nil, fmt.Errorf("update param is requird") + } + clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition) + if err != nil { + return nil, err + } + u.ClauseGroup = clauseGroup + + if err = setDatabaseInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if u.DatabaseInfoID == 0 { + return nil, errors.New("database info id is required and greater than 0") + } + + if u.ClauseGroup == nil { + return nil, errors.New("clause group is required and greater than 0") + } + + return &Update{ + databaseInfoID: u.DatabaseInfoID, + clauseGroup: u.ClauseGroup, + outputTypes: ns.OutputTypes, + updater: database.GetDatabaseOperator(), + }, nil } type Update struct { - config *UpdateConfig + databaseInfoID int64 + clauseGroup *database.ClauseGroup + outputTypes map[string]*vo.TypeInfo + updater database.DatabaseOperator } -type UpdateInventory struct { + +type updateInventory struct { ConditionGroup *database.ConditionGroup Fields map[string]any } -func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) { - if cfg == nil { - return nil, errors.New("config is required") - } - if cfg.DatabaseInfoID == 0 { - return nil, errors.New("database info id is required and greater than 0") - } - - if cfg.ClauseGroup == nil { - return nil, errors.New("clause group is required and greater than 0") - } - - if cfg.Updater == nil { - return nil, errors.New("updater is required") - } - - return &Update{config: cfg}, nil -} - -func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) { - inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in) +func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { + inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in) if err != nil { return nil, err } @@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, } req := &database.UpdateRequest{ - DatabaseInfoID: u.config.DatabaseInfoID, + DatabaseInfoID: u.databaseInfoID, ConditionGroup: inventory.ConditionGroup, Fields: fields, IsDebugRun: isDebugExecute(ctx), UserID: getExecUserID(ctx), } - response, err := u.config.Updater.Update(ctx, req) + response, err := u.updater.Update(ctx, req) if err != nil { return nil, err } - ret, err := responseFormatted(u.config.OutputConfig, response) + ret, err := responseFormatted(u.outputTypes, response) if err != nil { return nil, err } @@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, } func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { - inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in) + inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in) if err != nil { return nil, err } return u.toDatabaseUpdateCallbackInput(inventory) } -func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) { - databaseID := u.config.DatabaseInfoID +func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) { + databaseID := u.databaseInfoID result := make(map[string]any) result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["updateParam"] = map[string]any{} @@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[ "condition": condition, "fieldInfo": fieldInfo, } - return result, nil + return result, nil } diff --git a/backend/domain/workflow/internal/nodes/emitter/emitter.go b/backend/domain/workflow/internal/nodes/emitter/emitter.go index 6ed1743c..9453202a 100644 --- a/backend/domain/workflow/internal/nodes/emitter/emitter.go +++ b/backend/domain/workflow/internal/nodes/emitter/emitter.go @@ -18,7 +18,6 @@ package emitter import ( "context" - "errors" "fmt" "io" "strings" @@ -26,28 +25,77 @@ import ( "github.com/bytedance/sonic" "github.com/cloudwego/eino/schema" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" ) type OutputEmitter struct { - cfg *Config + Template string + FullSources map[string]*schema2.SourceInfo } type Config struct { - Template string - FullSources map[string]*nodes.SourceInfo + Template string } -func New(_ context.Context, cfg *Config) (*OutputEmitter, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) { + ns := &schema2.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeOutputEmitter, + Name: n.Data.Meta.Title, + Configs: c, } + content := n.Data.Inputs.Content + streamingOutput := n.Data.Inputs.StreamingOutput + + if streamingOutput { + ns.StreamConfigs = &schema2.StreamConfig{ + RequireStreamingInput: true, + } + } else { + ns.StreamConfigs = &schema2.StreamConfig{ + RequireStreamingInput: false, + } + } + + if content != nil { + if content.Type != vo.VariableTypeString { + return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type) + } + + if content.Value.Type != vo.BlockInputValueTypeLiteral { + return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type) + } + + if content.Value.Content == nil { + c.Template = "" + } else { + template, ok := content.Value.Content.(string) + if !ok { + return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content) + } + + c.Template = template + } + } + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) { return &OutputEmitter{ - cfg: cfg, + Template: c.Template, + FullSources: ns.FullSources, }, nil } @@ -59,10 +107,10 @@ type cachedVal struct { type cacheStore struct { store map[string]*cachedVal - infos map[string]*nodes.SourceInfo + infos map[string]*schema2.SourceInfo } -func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore { +func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore { return &cacheStore{ store: make(map[string]*cachedVal), infos: infos, @@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) { } if !sInfo.IsIntermediate { // this is not an intermediate object container - isStream := sInfo.FieldType == nodes.FieldIsStream + isStream := sInfo.FieldType == schema2.FieldIsStream if !isStream { _, ok := c.store[k] if !ok { @@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) { func (c *cacheStore) finished(k string) bool { cached, ok := c.store[k] if !ok { - return c.infos[k].FieldType == nodes.FieldSkipped + return c.infos[k].FieldType == schema2.FieldSkipped } if cached.finished { @@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool { return true } -func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo, +func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo, actualPath []string, ) { rootCached, ok := c.store[part.Root] @@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit hasErr bool, partFinished bool) { cachedRoot, subCache, sourceInfo, _ := c.find(part) if cachedRoot != nil && subCache != nil { - if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream { + if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream { hasErr = renderAndSend(part, part.Root, cachedRoot, sw) if hasErr { return true, false @@ -315,14 +363,14 @@ func merge(a, b any) any { const outputKey = "output" -func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) { - resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources) +func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) { + resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources) if err != nil { return nil, err } sr, sw := schema.Pipe[map[string]any](0) - parts := nodes.ParseTemplate(e.cfg.Template) + parts := nodes.ParseTemplate(e.Template) safego.Go(ctx, func() { hasErr := false defer func() { @@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[ shouldChangePart = true } } else { - if sourceInfo.FieldType == nodes.FieldIsStream { + if sourceInfo.FieldType == schema2.FieldIsStream { currentV := v for i := 0; i < len(actualPath)-1; i++ { currentM, ok := currentV.(map[string]any) @@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[ return sr, nil } -func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) { - s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources) +func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) { + s, err := nodes.Render(ctx, e.Template, in, e.FullSources) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/entry/entry.go b/backend/domain/workflow/internal/nodes/entry/entry.go index de78b5bb..13f4392a 100644 --- a/backend/domain/workflow/internal/nodes/entry/entry.go +++ b/backend/domain/workflow/internal/nodes/entry/entry.go @@ -20,41 +20,74 @@ import ( "context" "fmt" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type Config struct { DefaultValues map[string]any - OutputTypes map[string]*vo.TypeInfo } -type Entry struct { - cfg *Config - defaultValues map[string]any -} - -func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) { - if cfg == nil { - return nil, fmt.Errorf("config is requried") +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + if n.Parent() != nil { + return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID) } - defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck()) + + if n.ID != entity.EntryNodeKey { + return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID) + } + + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Name: n.Data.Meta.Title, + Type: entity.NodeTypeEntry, + } + + defaultValues := make(map[string]any, len(n.Data.Outputs)) + for _, v := range n.Data.Outputs { + variable, err := vo.ParseVariable(v) + if err != nil { + return nil, err + } + if variable.DefaultValue != nil { + defaultValues[variable.Name] = variable.DefaultValue + } + } + + c.DefaultValues = defaultValues + ns.Configs = c + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck()) if err != nil { return nil, err } return &Entry{ - cfg: cfg, defaultValues: defaultValues, + outputTypes: ns.OutputTypes, }, nil +} +type Entry struct { + defaultValues map[string]any + outputTypes map[string]*vo.TypeInfo } func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) { - for k, v := range e.defaultValues { if val, ok := in[k]; ok { - tInfo := e.cfg.OutputTypes[k] + tInfo := e.outputTypes[k] switch tInfo.Type { case vo.DataTypeString: if len(val.(string)) == 0 { diff --git a/backend/domain/workflow/internal/nodes/exit/exit.go b/backend/domain/workflow/internal/nodes/exit/exit.go new file mode 100644 index 00000000..38adf2a2 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/exit/exit.go @@ -0,0 +1,113 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package exit + +import ( + "context" + "fmt" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" +) + +type Config struct { + Template string + TerminatePlan vo.TerminatePlan +} + +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + if n.Parent() != nil { + return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID) + } + + if n.ID != entity.ExitNodeKey { + return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID) + } + + ns := &schema.NodeSchema{ + Key: entity.ExitNodeKey, + Type: entity.NodeTypeExit, + Name: n.Data.Meta.Title, + Configs: c, + } + + var ( + content *vo.BlockInput + streamingOutput bool + ) + if n.Data.Inputs.OutputEmitter != nil { + content = n.Data.Inputs.Content + streamingOutput = n.Data.Inputs.StreamingOutput + } + + if streamingOutput { + ns.StreamConfigs = &schema.StreamConfig{ + RequireStreamingInput: true, + } + } else { + ns.StreamConfigs = &schema.StreamConfig{ + RequireStreamingInput: false, + } + } + + if content != nil { + if content.Type != vo.VariableTypeString { + return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type) + } + + if content.Value.Type != vo.BlockInputValueTypeLiteral { + return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type) + } + + c.Template = content.Value.Content.(string) + } + + if n.Data.Inputs.TerminatePlan == nil { + return nil, fmt.Errorf("exit node requires a TerminatePlan") + } + c.TerminatePlan = *n.Data.Inputs.TerminatePlan + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if c.TerminatePlan == vo.ReturnVariables { + return &Exit{}, nil + } + + return &emitter.OutputEmitter{ + Template: c.Template, + FullSources: ns.FullSources, + }, nil +} + +type Exit struct{} + +func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) { + if in == nil { + return map[string]any{}, nil + } + return in, nil +} diff --git a/backend/domain/workflow/internal/nodes/httprequester/adapt.go b/backend/domain/workflow/internal/nodes/httprequester/adapt.go new file mode 100644 index 00000000..060df502 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/httprequester/adapt.go @@ -0,0 +1,340 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package httprequester + +import ( + "fmt" + "regexp" + "strings" + + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" +) + +var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`) + +func extractBracesContent(s string) []string { + matches := extractBracesRegexp.FindAllStringSubmatch(s, -1) + var result []string + for _, match := range matches { + if len(match) >= 2 { + result = append(result, match[1]) + } + } + return result +} + +type ImplicitNodeDependency struct { + NodeID string + FieldPath compose.FieldPath + TypeInfo *vo.TypeInfo +} + +func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) { + dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes)) + url := node.Data.Inputs.APIInfo.URL + urlVars := extractBracesContent(url) + hasReferred := make(map[string]bool) + extractDependenciesFromVars := func(vars []string) error { + for _, v := range vars { + if strings.HasPrefix(v, "block_output_") { + paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".") + if len(paths) < 2 { + return fmt.Errorf("invalid block_output_ variable: %s", v) + } + if hasReferred[v] { + continue + } + hasReferred[v] = true + dependencies = append(dependencies, &ImplicitNodeDependency{ + NodeID: paths[0], + FieldPath: paths[1:], + }) + } + } + return nil + } + + err := extractDependenciesFromVars(urlVars) + if err != nil { + return nil, err + } + if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) { + jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json) + err = extractDependenciesFromVars(jsonVars) + if err != nil { + return nil, err + } + } + if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) { + rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json) + err = extractDependenciesFromVars(rawTextVars) + if err != nil { + return nil, err + } + } + + var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node + nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node { + for i := range nodes { + if nodes[i].ID == nodeID { + return nodes[i] + } + if len(nodes[i].Blocks) > 0 { + if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil { + return n + } + } + } + return nil + } + for _, ds := range dependencies { + fNode := nodeFinder(canvas.Nodes, ds.NodeID) + if fNode == nil { + continue + } + tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs)) + for _, vAny := range fNode.Data.Outputs { + v, err := vo.ParseVariable(vAny) + if err != nil { + return nil, err + } + tInfo, err := convert.CanvasVariableToTypeInfo(v) + if err != nil { + return nil, err + } + tInfoMap[v.Name] = tInfo + } + tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap) + if !ok { + return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath) + } + ds.TypeInfo = tInfo + } + + return dependencies, nil +} + +func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) { + if len(properties) == 0 { + if tInfo, ok := tInfoMap[root]; ok { + return tInfo, true + } + return nil, false + } + tInfo, ok := tInfoMap[root] + if !ok { + return nil, false + } + return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties) +} + +var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`) + +func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) { + inputs := n.Data.Inputs + implicitPathVars := make(map[string]bool) + addImplicitVarsSources := func(prefix string, vars []string) error { + for _, v := range vars { + if strings.HasPrefix(v, "block_output_") { + paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".") + if len(paths) < 2 { + return fmt.Errorf("invalid implicit var : %s", v) + } + for _, dep := range implicitNodeDependencies { + if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") { + pathValue := prefix + crypto.MD5HexValue(v) + if _, visited := implicitPathVars[pathValue]; visited { + continue + } + implicitPathVars[pathValue] = true + ns.SetInputType(pathValue, dep.TypeInfo) + ns.AddInputSource(&vo.FieldInfo{ + Path: []string{pathValue}, + Source: vo.FieldSource{ + Ref: &vo.Reference{ + FromNodeKey: vo.NodeKey(dep.NodeID), + FromPath: dep.FieldPath, + }, + }, + }) + } + } + } + if strings.HasPrefix(v, "global_variable_") { + matches := globalVariableRegex.FindStringSubmatch(v) + if len(matches) < 2 { + continue + } + + var varType vo.GlobalVarType + if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) { + varType = vo.GlobalAPP + } else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) { + varType = vo.GlobalUser + } else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) { + varType = vo.GlobalSystem + } else { + return fmt.Errorf("invalid global variable type: %s", v) + } + + source := vo.FieldSource{ + Ref: &vo.Reference{ + VariableType: &varType, + FromPath: []string{matches[1]}, + }, + } + + ns.AddInputSource(&vo.FieldInfo{ + Path: []string{prefix + crypto.MD5HexValue(v)}, + Source: source, + }) + + } + } + + return nil + } + + urlVars := extractBracesContent(inputs.APIInfo.URL) + err = addImplicitVarsSources("__apiInfo_url_", urlVars) + if err != nil { + return err + } + + err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent()) + if err != nil { + return err + } + + err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent()) + if err != nil { + return err + } + + if inputs.Auth != nil && inputs.Auth.AuthOpen { + authData := inputs.Auth.AuthData + const bearerTokenKey = "__auth_authData_bearerTokenData_token" + if inputs.Auth.AuthType == "BEARER_AUTH" { + bearTokenParam := authData.BearerTokenData[0] + tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input) + if err != nil { + return err + } + ns.SetInputType(bearerTokenKey, tInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent()) + if err != nil { + return err + } + ns.AddInputSource(sources...) + + } + + if inputs.Auth.AuthType == "CUSTOM_AUTH" { + const ( + customDataDataKey = "__auth_authData_customData_data_Key" + customDataDataValue = "__auth_authData_customData_data_Value" + ) + dataParams := authData.CustomData.Data + keyParam := dataParams[0] + keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input) + if err != nil { + return err + } + ns.SetInputType(customDataDataKey, keyTypeInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent()) + if err != nil { + return err + } + ns.AddInputSource(sources...) + + valueParam := dataParams[1] + valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input) + if err != nil { + return err + } + ns.SetInputType(customDataDataValue, valueTypeInfo) + sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent()) + if err != nil { + return err + } + ns.AddInputSource(sources...) + } + } + + switch BodyType(inputs.Body.BodyType) { + case BodyTypeFormData: + err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent()) + if err != nil { + return err + } + case BodyTypeFormURLEncoded: + err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent()) + if err != nil { + return err + } + case BodyTypeBinary: + const fileURLName = "__body_bodyData_binary_fileURL" + fileURLInput := inputs.Body.BodyData.Binary.FileURL + ns.SetInputType(fileURLName, &vo.TypeInfo{ + Type: vo.DataTypeString, + }) + sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent()) + if err != nil { + return err + } + ns.AddInputSource(sources...) + case BodyTypeJSON: + jsonVars := extractBracesContent(inputs.Body.BodyData.Json) + err = addImplicitVarsSources("__body_bodyData_json_", jsonVars) + if err != nil { + return err + } + case BodyTypeRawText: + rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText) + err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars) + if err != nil { + return err + } + } + + return nil +} + +func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error { + for i := range params { + param := params[i] + name := param.Name + tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input) + if err != nil { + return err + } + + fieldName := prefix + crypto.MD5HexValue(name) + ns.SetInputType(fieldName, tInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{fieldName}, parentNode) + if err != nil { + return err + } + ns.AddInputSource(sources...) + } + return nil +} diff --git a/backend/domain/workflow/internal/nodes/httprequester/http_requester.go b/backend/domain/workflow/internal/nodes/httprequester/http_requester.go index 2d4c306f..9deac5c6 100644 --- a/backend/domain/workflow/internal/nodes/httprequester/http_requester.go +++ b/backend/domain/workflow/internal/nodes/httprequester/http_requester.go @@ -31,9 +31,14 @@ import ( "strings" "time" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -129,7 +134,7 @@ type Request struct { FileURL *string } -var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`) +var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`) type MD5FieldMapping struct { HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key @@ -184,49 +189,188 @@ type Config struct { Timeout time.Duration RetryTimes uint64 - IgnoreException bool - DefaultOutput map[string]any - MD5FieldMapping } -type HTTPRequester struct { - client *http.Client - config *Config -} - -func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) { - if cfg == nil { - return nil, fmt.Errorf("config is requried") +func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) { + options := nodes.GetAdaptOptions(opts...) + if options.Canvas == nil { + return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node") } - if len(cfg.Method) == 0 { + implicitDeps, err := extractImplicitDependency(n, options.Canvas) + if err != nil { + return nil, err + } + + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeHTTPRequester, + Name: n.Data.Meta.Title, + Configs: c, + } + + inputs := n.Data.Inputs + + md5FieldMapping := &MD5FieldMapping{} + + method := inputs.APIInfo.Method + c.Method = method + reqURL := inputs.APIInfo.URL + c.URLConfig = URLConfig{ + Tpl: strings.TrimSpace(reqURL), + } + + urlVars := extractBracesContent(reqURL) + md5FieldMapping.SetURLFields(urlVars...) + + md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string { + return a.Name + })...) + + md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string { + return a.Name + })...) + + if inputs.Auth != nil && inputs.Auth.AuthOpen { + auth := &AuthenticationConfig{} + ty, err := convertAuthType(inputs.Auth.AuthType) + if err != nil { + return nil, err + } + auth.Type = ty + location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo) + if err != nil { + return nil, err + } + auth.Location = location + + c.AuthConfig = auth + } + + bodyConfig := BodyConfig{} + + bodyConfig.BodyType = BodyType(inputs.Body.BodyType) + switch BodyType(inputs.Body.BodyType) { + case BodyTypeJSON: + jsonTpl := inputs.Body.BodyData.Json + bodyConfig.TextJsonConfig = &TextJsonConfig{ + Tpl: jsonTpl, + } + jsonVars := extractBracesContent(jsonTpl) + md5FieldMapping.SetBodyFields(jsonVars...) + case BodyTypeFormData: + bodyConfig.FormDataConfig = &FormDataConfig{ + FileTypeMapping: map[string]bool{}, + } + formDataVars := make([]string, 0) + for i := range inputs.Body.BodyData.FormData.Data { + p := inputs.Body.BodyData.FormData.Data[i] + formDataVars = append(formDataVars, p.Name) + if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime { + bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true + } + } + + md5FieldMapping.SetBodyFields(formDataVars...) + case BodyTypeRawText: + TextTpl := inputs.Body.BodyData.RawText + bodyConfig.TextPlainConfig = &TextPlainConfig{ + Tpl: TextTpl, + } + textPlainVars := extractBracesContent(TextTpl) + md5FieldMapping.SetBodyFields(textPlainVars...) + case BodyTypeFormURLEncoded: + formURLEncodedVars := make([]string, 0) + for _, p := range inputs.Body.BodyData.FormURLEncoded { + formURLEncodedVars = append(formURLEncodedVars, p.Name) + } + md5FieldMapping.SetBodyFields(formURLEncodedVars...) + } + c.BodyConfig = bodyConfig + c.MD5FieldMapping = *md5FieldMapping + + if inputs.Setting != nil { + c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second + c.RetryTimes = uint64(inputs.Setting.RetryTimes) + } + + if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil { + return nil, err + } + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + return ns, nil +} + +func convertAuthType(auth string) (AuthType, error) { + switch auth { + case "CUSTOM_AUTH": + return Custom, nil + case "BEARER_AUTH": + return BearToken, nil + default: + return AuthType(0), fmt.Errorf("invalid auth type") + } +} + +func convertLocation(l string) (Location, error) { + switch l { + case "header": + return Header, nil + case "query": + return QueryParam, nil + default: + return 0, fmt.Errorf("invalid location") + } + +} + +func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if len(c.Method) == 0 { return nil, fmt.Errorf("method is requried") } - hg := &HTTPRequester{} + hg := &HTTPRequester{ + urlConfig: c.URLConfig, + method: c.Method, + retryTimes: c.RetryTimes, + authConfig: c.AuthConfig, + bodyConfig: c.BodyConfig, + md5FieldMapping: c.MD5FieldMapping, + } client := http.DefaultClient - if cfg.Timeout > 0 { - client.Timeout = cfg.Timeout + if c.Timeout > 0 { + client.Timeout = c.Timeout } hg.client = client - hg.config = cfg return hg, nil } +type HTTPRequester struct { + client *http.Client + urlConfig URLConfig + authConfig *AuthenticationConfig + bodyConfig BodyConfig + method string + retryTimes uint64 + md5FieldMapping MD5FieldMapping +} + func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) { var ( req = &Request{} - method = hg.config.Method - retryTimes = hg.config.RetryTimes + method = hg.method + retryTimes = hg.retryTimes body io.ReadCloser contentType string response *http.Response ) - req, err = hg.config.parserToRequest(input) + req, err = hg.parserToRequest(input) if err != nil { return nil, err } @@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp Header: http.Header{}, } - httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars) + httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars) if err != nil { return nil, err } @@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp params.Set(key, value) } - if hg.config.AuthConfig != nil { - httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params) + if hg.authConfig != nil { + httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params) if err != nil { return nil, err } @@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp u.RawQuery = params.Encode() httpRequest.URL = u - body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req) + body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req) if err != nil { return nil, err } @@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) { } func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) { - var ( - request = &Request{} - config = hg.config - ) - request, err := hg.config.parserToRequest(input) + var request = &Request{} + + request, err := hg.parserToRequest(input) if err != nil { return nil, err } result := make(map[string]any) - result["method"] = config.Method + result["method"] = hg.method - u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars) + u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars) if err != nil { return nil, err } @@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any } result["header"] = headers result["auth"] = nil - if config.AuthConfig != nil { - if config.AuthConfig.Type == Custom { + if hg.authConfig != nil { + if hg.authConfig.Type == Custom { result["auth"] = map[string]interface{}{ "Key": request.Authentication.Key, "Value": request.Authentication.Value, } - } else if config.AuthConfig.Type == BearToken { + } else if hg.authConfig.Type == BearToken { result["auth"] = map[string]interface{}{ "token": request.Authentication.Token, } @@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any } result["body"] = nil - switch config.BodyConfig.BodyType { + switch hg.bodyConfig.BodyType { case BodyTypeJSON: - js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars) + js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars) if err != nil { return nil, err } @@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any } result["body"] = ret case BodyTypeRawText: - tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars) + tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars) if err != nil { return nil, err @@ -569,7 +711,7 @@ const ( bodyBinaryFileURLPrefix = "binary_fileURL" ) -func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { +func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) { request := &Request{ URLVars: make(map[string]any), Headers: make(map[string]string), @@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { for key, value := range input { if strings.HasPrefix(key, apiInfoURLPrefix) { urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix) - if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok { + if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok { if strings.HasPrefix(urlKey, "global_variable_") { urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2") } @@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { } if strings.HasPrefix(key, headersPrefix) { headerKeyMD5 := strings.TrimPrefix(key, headersPrefix) - if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok { + if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok { request.Headers[headerKey] = value.(string) } } if strings.HasPrefix(key, paramsPrefix) { paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix) - if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok { + if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok { request.Params[paramKey] = value.(string) } } @@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { bodyKey := strings.TrimPrefix(key, bodyDataPrefix) if strings.HasPrefix(bodyKey, bodyJsonPrefix) { jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix) - if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok { + if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok { if strings.HasPrefix(jsonKey, "global_variable_") { jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2") } @@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { } if strings.HasPrefix(bodyKey, bodyFormDataPrefix) { formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix) - if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok { + if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok { request.FormDataVars[formDataKey] = value.(string) } @@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) { formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix) - if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok { + if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok { request.FormURLEncodedVars[formURLEncodeKey] = value.(string) } } if strings.HasPrefix(bodyKey, bodyRawTextPrefix) { rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix) - if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok { + if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok { if strings.HasPrefix(rawTextKey, "global_variable_") { rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2") } diff --git a/backend/domain/workflow/internal/nodes/httprequester/http_requester_test.go b/backend/domain/workflow/internal/nodes/httprequester/http_requester_test.go index 37a42844..faa01825 100644 --- a/backend/domain/workflow/internal/nodes/httprequester/http_requester_test.go +++ b/backend/domain/workflow/internal/nodes/httprequester/http_requester_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" ) @@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) { }, }, } - hg, err := NewHTTPRequester(context.Background(), cfg) + hg, err := cfg.Build(context.Background(), &schema.NodeSchema{}) assert.NoError(t, err) m := map[string]any{ "__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1", @@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) { "__params_" + crypto.MD5HexValue("p2"): "v2", } - result, err := hg.Invoke(context.Background(), m) + result, err := hg.(*HTTPRequester).Invoke(context.Background(), m) assert.NoError(t, err) assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, int64(200), result["statusCode"]) @@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) { } // Create an HTTPRequest instance - hg, err := NewHTTPRequester(context.Background(), cfg) + hg, err := cfg.Build(context.Background(), &schema.NodeSchema{}) assert.NoError(t, err) m := map[string]any{ @@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) { "__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL, } - result, err := hg.Invoke(context.Background(), m) + result, err := hg.(*HTTPRequester).Invoke(context.Background(), m) assert.NoError(t, err) assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, int64(200), result["statusCode"]) @@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) { }, }, } - hg, err := NewHTTPRequester(context.Background(), cfg) + hg, err := cfg.Build(context.Background(), &schema.NodeSchema{}) assert.NoError(t, err) m := map[string]any{ @@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) { "__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2", } - result, err := hg.Invoke(context.Background(), m) + result, err := hg.(*HTTPRequester).Invoke(context.Background(), m) assert.NoError(t, err) assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, int64(200), result["statusCode"]) @@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) { } // Create an HTTPRequest instance - hg, err := NewHTTPRequester(context.Background(), cfg) + hg, err := cfg.Build(context.Background(), &schema.NodeSchema{}) assert.NoError(t, err) m := map[string]any{ @@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) { "__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2", } - result, err := hg.Invoke(context.Background(), m) + result, err := hg.(*HTTPRequester).Invoke(context.Background(), m) assert.NoError(t, err) assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, int64(200), result["statusCode"]) @@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) { } // Create an HTTPRequest instance - hg, err := NewHTTPRequester(context.Background(), cfg) + hg, err := cfg.Build(context.Background(), &schema.NodeSchema{}) assert.NoError(t, err) m := map[string]any{ @@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) { "__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL, } - result, err := hg.Invoke(context.Background(), m) + result, err := hg.(*HTTPRequester).Invoke(context.Background(), m) assert.NoError(t, err) assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, int64(200), result["statusCode"]) diff --git a/backend/domain/workflow/internal/nodes/intentdetector/intent_detector.go b/backend/domain/workflow/internal/nodes/intentdetector/intent_detector.go index 1eaba0c2..458cc08a 100644 --- a/backend/domain/workflow/internal/nodes/intentdetector/intent_detector.go +++ b/backend/domain/workflow/internal/nodes/intentdetector/intent_detector.go @@ -18,26 +18,167 @@ package intentdetector import ( "context" - "encoding/json" "errors" + "fmt" "strconv" "strings" - "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" "github.com/spf13/cast" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" + "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) type Config struct { Intents []string SystemPrompt string IsFastMode bool - ChatModel model.BaseChatModel + LLMParams *model.LLMParams +} + +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) { + ns := &schema2.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeIntentDetector, + Name: n.Data.Meta.Title, + Configs: c, + } + + param := n.Data.Inputs.LLMParam + if param == nil { + return nil, fmt.Errorf("intent detector node's llmParam is nil") + } + + llmParam, ok := param.(vo.IntentDetectorLLMParam) + if !ok { + return nil, fmt.Errorf("llm node's llmParam must be LLMParam, got %v", llmParam) + } + + paramBytes, err := sonic.Marshal(param) + if err != nil { + return nil, err + } + var intentDetectorConfig = &vo.IntentDetectorLLMConfig{} + + err = sonic.Unmarshal(paramBytes, &intentDetectorConfig) + if err != nil { + return nil, err + } + + modelLLMParams := &model.LLMParams{} + modelLLMParams.ModelType = int64(intentDetectorConfig.ModelType) + modelLLMParams.ModelName = intentDetectorConfig.ModelName + modelLLMParams.TopP = intentDetectorConfig.TopP + modelLLMParams.Temperature = intentDetectorConfig.Temperature + modelLLMParams.MaxTokens = intentDetectorConfig.MaxTokens + modelLLMParams.ResponseFormat = model.ResponseFormat(intentDetectorConfig.ResponseFormat) + modelLLMParams.SystemPrompt = intentDetectorConfig.SystemPrompt.Value.Content.(string) + + c.LLMParams = modelLLMParams + c.SystemPrompt = modelLLMParams.SystemPrompt + + var intents = make([]string, 0, len(n.Data.Inputs.Intents)) + for _, it := range n.Data.Inputs.Intents { + intents = append(intents, it.Name) + } + c.Intents = intents + + if n.Data.Inputs.Mode == "top_speed" { + c.IsFastMode = true + } + + if err = convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) { + if !c.IsFastMode && c.LLMParams == nil { + return nil, errors.New("config chat model is required") + } + + if len(c.Intents) == 0 { + return nil, errors.New("config intents is required") + } + + m, _, err := model.GetManager().GetModel(ctx, c.LLMParams) + if err != nil { + return nil, err + } + + chain := compose.NewChain[map[string]any, *schema.Message]() + + spt := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt) + + intents, err := toIntentString(c.Intents) + if err != nil { + return nil, err + } + + sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{ + "intents": intents, + }) + if err != nil { + return nil, err + } + prompts := prompt.FromMessages(schema.Jinja2, + &schema.Message{Content: sptTemplate, Role: schema.System}, + &schema.Message{Content: "{{query}}", Role: schema.User}) + + r, err := chain.AppendChatTemplate(prompts).AppendChatModel(m).Compile(ctx) + if err != nil { + return nil, err + } + return &IntentDetector{ + isFastMode: c.IsFastMode, + systemPrompt: c.SystemPrompt, + runner: r, + }, nil +} + +func (c *Config) BuildBranch(_ context.Context) ( + func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) { + return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) { + classificationId, ok := nodeOutput[classificationID] + if !ok { + return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput) + } + + cID64, ok := classificationId.(int64) + if !ok { + return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", classificationId) + } + + if cID64 == 0 { + return -1, true, nil + } + + return cID64 - 1, false, nil + }, true +} + +func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string { + expects := make([]string, len(n.Data.Inputs.Intents)+1) + expects[0] = schema2.PortDefault + for i := 0; i < len(n.Data.Inputs.Intents); i++ { + expects[i+1] = fmt.Sprintf(schema2.PortBranchFormat, i) + } + return expects } const SystemIntentPrompt = ` @@ -95,71 +236,39 @@ Note: ##Limit - Please do not reply in text.` +const classificationID = "classificationId" + type IntentDetector struct { - config *Config - runner compose.Runnable[map[string]any, *schema.Message] -} - -func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) { - if cfg == nil { - return nil, errors.New("cfg is required") - } - if !cfg.IsFastMode && cfg.ChatModel == nil { - return nil, errors.New("config chat model is required") - } - - if len(cfg.Intents) == 0 { - return nil, errors.New("config intents is required") - } - chain := compose.NewChain[map[string]any, *schema.Message]() - - spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt) - - sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{ - "intents": toIntentString(cfg.Intents), - }) - if err != nil { - return nil, err - } - prompts := prompt.FromMessages(schema.Jinja2, - &schema.Message{Content: sptTemplate, Role: schema.System}, - &schema.Message{Content: "{{query}}", Role: schema.User}) - - r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx) - if err != nil { - return nil, err - } - return &IntentDetector{ - config: cfg, - runner: r, - }, nil + isFastMode bool + systemPrompt string + runner compose.Runnable[map[string]any, *schema.Message] } func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) { - nodeOutput := make(map[string]any) - nodeOutput["classificationId"] = 0 if content == "" { - return nodeOutput, errors.New("content is empty") + return nil, errors.New("intent detector's LLM output content is empty") } - if id.config.IsFastMode { + if id.isFastMode { cid, err := strconv.ParseInt(content, 10, 64) if err != nil { - return nodeOutput, err + return nil, err } - nodeOutput["classificationId"] = cid - return nodeOutput, nil + return map[string]any{ + classificationID: cid, + }, nil } leftIndex := strings.Index(content, "{") rightIndex := strings.Index(content, "}") if leftIndex == -1 || rightIndex == -1 { - return nodeOutput, errors.New("content is invalid") + return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content) } - err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput) + var nodeOutput map[string]any + err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput) if err != nil { - return nodeOutput, err + return nil, err } return nodeOutput, nil @@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map vars := make(map[string]any) vars["query"] = queryStr - if !id.config.IsFastMode { - ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query}) + if !id.isFastMode { + ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query}) if err != nil { return nil, err } @@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map return id.parseToNodeOut(o.Content) } -func toIntentString(its []string) string { +func toIntentString(its []string) (string, error) { type IntentVariableItem struct { ClassificationID int64 `json:"classificationId"` Content string `json:"content"` @@ -207,6 +316,6 @@ func toIntentString(its []string) string { Content: it, }) } - itsBytes, _ := json.Marshal(vs) - return string(itsBytes) + + return sonic.MarshalString(vs) } diff --git a/backend/domain/workflow/internal/nodes/intentdetector/intent_detector_test.go b/backend/domain/workflow/internal/nodes/intentdetector/intent_detector_test.go deleted file mode 100644 index 334d7fb1..00000000 --- a/backend/domain/workflow/internal/nodes/intentdetector/intent_detector_test.go +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package intentdetector - -import ( - "context" - "fmt" - "testing" - - "github.com/cloudwego/eino/components/model" - "github.com/cloudwego/eino/schema" - "github.com/stretchr/testify/assert" -) - -type mockChatModel struct { - topSeed bool -} - -func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { - if m.topSeed { - return &schema.Message{ - Content: "1", - }, nil - } - return &schema.Message{ - Content: `{"classificationId":1,"reason":"高兴"}`, - }, nil -} - -func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { - return nil, nil -} - -func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error { - return nil -} - -func TestNewIntentDetector(t *testing.T) { - ctx := context.Background() - t.Run("fast mode", func(t *testing.T) { - dt, err := NewIntentDetector(ctx, &Config{ - Intents: []string{"高兴", "悲伤"}, - IsFastMode: true, - ChatModel: &mockChatModel{topSeed: true}, - }) - assert.Nil(t, err) - - ret, err := dt.Invoke(ctx, map[string]any{ - "query": "我考了100分", - }) - assert.Nil(t, err) - assert.Equal(t, ret["classificationId"], int64(1)) - }) - - t.Run("full mode", func(t *testing.T) { - - dt, err := NewIntentDetector(ctx, &Config{ - Intents: []string{"高兴", "悲伤"}, - IsFastMode: false, - ChatModel: &mockChatModel{}, - }) - assert.Nil(t, err) - - ret, err := dt.Invoke(ctx, map[string]any{ - "query": "我考了100分", - }) - fmt.Println(err) - assert.Nil(t, err) - fmt.Println(ret) - assert.Equal(t, ret["classificationId"], float64(1)) - assert.Equal(t, ret["reason"], "高兴") - }) - -} diff --git a/backend/domain/workflow/internal/nodes/json/json_deserialization.go b/backend/domain/workflow/internal/nodes/json/json_deserialization.go index a8236aca..269cd6e6 100644 --- a/backend/domain/workflow/internal/nodes/json/json_deserialization.go +++ b/backend/domain/workflow/internal/nodes/json/json_deserialization.go @@ -20,8 +20,11 @@ import ( "context" "fmt" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/sonic" @@ -34,32 +37,42 @@ const ( warningsKey = "deserialization_warnings" ) -type DeserializationConfig struct { - OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"` +type DeserializationConfig struct{} + +func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) ( + *schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeJsonDeserialization, + Name: n.Data.Meta.Title, + Configs: d, + } + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -type Deserializer struct { - config *DeserializationConfig - typeInfo *vo.TypeInfo -} - -func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) { - if cfg == nil { - return nil, fmt.Errorf("config required") - } - if cfg.OutputFields == nil { - return nil, fmt.Errorf("OutputFields is required for deserialization") - } - typeInfo := cfg.OutputFields[OutputKeyDeserialization] - if typeInfo == nil { +func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization] + if !ok { return nil, fmt.Errorf("no output field specified in deserialization config") } return &Deserializer{ - config: cfg, typeInfo: typeInfo, }, nil } +type Deserializer struct { + typeInfo *vo.TypeInfo +} + func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { jsonStrValue := input[InputKeyDeserialization] diff --git a/backend/domain/workflow/internal/nodes/json/json_deserialization_test.go b/backend/domain/workflow/internal/nodes/json/json_deserialization_test.go index 4e1d94e2..7ba4c9ec 100644 --- a/backend/domain/workflow/internal/nodes/json/json_deserialization_test.go +++ b/backend/domain/workflow/internal/nodes/json/json_deserialization_test.go @@ -24,6 +24,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -31,19 +32,9 @@ import ( func TestNewJsonDeserializer(t *testing.T) { ctx := context.Background() - // Test with nil config - _, err := NewJsonDeserializer(ctx, nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config required") - - // Test with missing OutputFields config - _, err = NewJsonDeserializer(ctx, &DeserializationConfig{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "OutputFields is required") - // Test with missing output key in OutputFields - _, err = NewJsonDeserializer(ctx, &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ + _, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ "testKey": {Type: vo.DataTypeString}, }, }) @@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) { assert.Contains(t, err.Error(), "no output field specified in deserialization config") // Test with valid config - validConfig := &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ + validConfig := &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ OutputKeyDeserialization: {Type: vo.DataTypeString}, }, } - processor, err := NewJsonDeserializer(ctx, validConfig) + processor, err := (&DeserializationConfig{}).Build(ctx, validConfig) assert.NoError(t, err) assert.NotNil(t, processor) } @@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) { ctx := context.Background() // Base type test config - baseTypeConfig := &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeString}, + baseTypeConfig := &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeString}, }, } // Object type test config - objectTypeConfig := &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + objectTypeConfig := &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "name": {Type: vo.DataTypeString, Required: true}, @@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { } // Array type test config - arrayTypeConfig := &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + arrayTypeConfig := &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, }, @@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { } // Nested array object test config - nestedArrayConfig := &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + nestedArrayConfig := &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{ Type: vo.DataTypeObject, @@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) { // Test cases tests := []struct { name string - config *DeserializationConfig + config *schema.NodeSchema inputJSON string expectedOutput any expectErr bool @@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test integer deserialization", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `123`, @@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test boolean deserialization", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeBoolean}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeBoolean}, }, }, inputJSON: `true`, @@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test type mismatch warning", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `"not a number"`, @@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test string to integer conversion", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `"123"`, @@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test float to integer conversion (integer part)", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `123.0`, @@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test float to integer conversion (non-integer part)", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `123.5`, @@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test boolean to integer conversion", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `true`, @@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 1, }, { name: "Test string to boolean conversion", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeBoolean}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeBoolean}, }, }, inputJSON: `"true"`, @@ -252,10 +243,11 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectErr: false, expectWarnings: 0, }, { - name: "Test string to integer conversion in nested object", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + name: "Test string to integer conversion in nested object", + inputJSON: `{"age":"456"}`, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "age": {Type: vo.DataTypeInteger}, @@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) { }, }, }, - inputJSON: `{"age":"456"}`, expectedOutput: map[string]any{"age": 456}, expectErr: false, expectWarnings: 0, }, { name: "Test string to integer conversion for array elements", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, }, @@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 0, }, { name: "Test string with non-numeric characters to integer conversion", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": {Type: vo.DataTypeInteger}, + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: {Type: vo.DataTypeInteger}, }, }, inputJSON: `"123abc"`, @@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 1, }, { name: "Test type mismatch in nested object field", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "score": {Type: vo.DataTypeInteger}, @@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) { expectWarnings: 1, }, { name: "Test partial conversion failure in array elements", - config: &DeserializationConfig{ - OutputFields: map[string]*vo.TypeInfo{ - "output": { + config: &schema.NodeSchema{ + OutputTypes: map[string]*vo.TypeInfo{ + OutputKeyDeserialization: { Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, }, @@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - processor, err := NewJsonDeserializer(ctx, tt.config) + processor, err := (&DeserializationConfig{}).Build(ctx, tt.config) assert.NoError(t, err) ctxWithCache := ctxcache.Init(ctx) input := map[string]any{"input": tt.inputJSON} - result, err := processor.Invoke(ctxWithCache, input) + result, err := processor.(*Deserializer).Invoke(ctxWithCache, input) if tt.expectErr { assert.Error(t, err) diff --git a/backend/domain/workflow/internal/nodes/json/json_serialization.go b/backend/domain/workflow/internal/nodes/json/json_serialization.go index bba1a4b5..35a4a002 100644 --- a/backend/domain/workflow/internal/nodes/json/json_serialization.go +++ b/backend/domain/workflow/internal/nodes/json/json_serialization.go @@ -20,7 +20,11 @@ import ( "context" "fmt" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -29,28 +33,57 @@ const ( OutputKeySerialization = "output" ) +// SerializationConfig is the Config type for NodeTypeJsonSerialization. +// Each Node Type should have its own designated Config type, +// which should implement NodeAdaptor and NodeBuilder. +// NOTE: we didn't define any fields for this type, +// because this node is simple, we doesn't need to extract any SPECIFIC piece of info +// from frontend Node. In other cases we would need to do it, such as LLM's model configs. type SerializationConfig struct { - InputTypes map[string]*vo.TypeInfo + // you can define ANY number of fields here, + // as long as these fields are SERIALIZABLE and EXPORTED. + // to store specific info extracted from frontend node. + // e.g. + // - LLM model configs + // - conditional expressions + // - fixed input fields such as MaxBatchSize } -type JsonSerializer struct { - config *SerializationConfig -} - -func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) { - if cfg == nil { - return nil, fmt.Errorf("config required") - } - if cfg.InputTypes == nil { - return nil, fmt.Errorf("InputTypes is required for serialization") +// Adapt provides conversion from Node to NodeSchema. +// NOTE: in this specific case, we don't need AdaptOption. +func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeJsonSerialization, + Name: n.Data.Meta.Title, + Configs: s, // remember to set the Node's Config Type to NodeSchema as well } - return &JsonSerializer{ - config: cfg, - }, nil + // this sets input fields' type and mapping info + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + // this set output fields' type info + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) { +func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) ( + any, error) { + return &Serializer{}, nil +} + +// Serializer is the actual node implementation. +type Serializer struct { + // here can holds ANY data required for node execution +} + +// Invoke implements the InvokableNode interface. +func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) { // Directly use the input map for serialization if input == nil { return nil, fmt.Errorf("input data for serialization cannot be nil") diff --git a/backend/domain/workflow/internal/nodes/json/json_serialization_test.go b/backend/domain/workflow/internal/nodes/json/json_serialization_test.go index d49ca055..323627ee 100644 --- a/backend/domain/workflow/internal/nodes/json/json_serialization_test.go +++ b/backend/domain/workflow/internal/nodes/json/json_serialization_test.go @@ -23,44 +23,34 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) func TestNewJsonSerialize(t *testing.T) { ctx := context.Background() - // Test with nil config - _, err := NewJsonSerializer(ctx, nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config required") - // Test with missing InputTypes config - _, err = NewJsonSerializer(ctx, &SerializationConfig{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "InputTypes is required") - - // Test with valid config - validConfig := &SerializationConfig{ + s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{ InputTypes: map[string]*vo.TypeInfo{ "testKey": {Type: "string"}, }, - } - processor, err := NewJsonSerializer(ctx, validConfig) + }) + assert.NoError(t, err) - assert.NotNil(t, processor) + assert.NotNil(t, s) } func TestJsonSerialize_Invoke(t *testing.T) { ctx := context.Background() - config := &SerializationConfig{ + + processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{ InputTypes: map[string]*vo.TypeInfo{ "stringKey": {Type: "string"}, "intKey": {Type: "integer"}, "boolKey": {Type: "boolean"}, "objKey": {Type: "object"}, }, - } - - processor, err := NewJsonSerializer(ctx, config) + }) assert.NoError(t, err) // Test cases @@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := processor.Invoke(ctx, tt.input) + result, err := processor.(*Serializer).Invoke(ctx, tt.input) if tt.expectErr { assert.Error(t, err) diff --git a/backend/domain/workflow/internal/nodes/knowledge/adaptor.go b/backend/domain/workflow/internal/nodes/knowledge/adaptor.go new file mode 100644 index 00000000..14e14074 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/knowledge/adaptor.go @@ -0,0 +1,57 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package knowledge + +import ( + "fmt" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" +) + +func convertParsingType(p string) (knowledge.ParseMode, error) { + switch p { + case "fast": + return knowledge.FastParseMode, nil + case "accurate": + return knowledge.AccurateParseMode, nil + default: + return "", fmt.Errorf("invalid parsingType: %s", p) + } +} + +func convertChunkType(p string) (knowledge.ChunkType, error) { + switch p { + case "custom": + return knowledge.ChunkTypeCustom, nil + case "default": + return knowledge.ChunkTypeDefault, nil + default: + return "", fmt.Errorf("invalid ChunkType: %s", p) + } +} +func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) { + switch s { + case 0: + return knowledge.SearchTypeSemantic, nil + case 1: + return knowledge.SearchTypeHybrid, nil + case 20: + return knowledge.SearchTypeFullText, nil + default: + return "", fmt.Errorf("invalid RetrievalSearchType %v", s) + } +} diff --git a/backend/domain/workflow/internal/nodes/knowledge/knowledge_deleter.go b/backend/domain/workflow/internal/nodes/knowledge/knowledge_deleter.go index 759397eb..3030b5e7 100644 --- a/backend/domain/workflow/internal/nodes/knowledge/knowledge_deleter.go +++ b/backend/domain/workflow/internal/nodes/knowledge/knowledge_deleter.go @@ -21,27 +21,45 @@ import ( "errors" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) -type DeleterConfig struct { - KnowledgeID int64 - KnowledgeDeleter knowledge.KnowledgeOperator -} +type DeleterConfig struct{} -type KnowledgeDeleter struct { - config *DeleterConfig -} - -func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) { - if cfg.KnowledgeDeleter == nil { - return nil, errors.New("knowledge deleter is required") +func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeKnowledgeDeleter, + Name: n.Data.Meta.Title, + Configs: d, } - return &KnowledgeDeleter{ - config: cfg, + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + return &Deleter{ + knowledgeDeleter: knowledge.GetKnowledgeOperator(), }, nil } -func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) { +type Deleter struct { + knowledgeDeleter knowledge.KnowledgeOperator +} + +func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { documentID, ok := input["documentID"].(string) if !ok { return nil, errors.New("documentID is required and must be a string") @@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma DocumentID: documentID, } - response, err := k.config.KnowledgeDeleter.Delete(ctx, req) + response, err := k.knowledgeDeleter.Delete(ctx, req) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/knowledge/knowledge_indexer.go b/backend/domain/workflow/internal/nodes/knowledge/knowledge_indexer.go index 74de5cd7..e80a650d 100644 --- a/backend/domain/workflow/internal/nodes/knowledge/knowledge_indexer.go +++ b/backend/domain/workflow/internal/nodes/knowledge/knowledge_indexer.go @@ -24,7 +24,14 @@ import ( "path/filepath" "strings" + "github.com/spf13/cast" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" ) @@ -32,30 +39,88 @@ type IndexerConfig struct { KnowledgeID int64 ParsingStrategy *knowledge.ParsingStrategy ChunkingStrategy *knowledge.ChunkingStrategy - KnowledgeIndexer knowledge.KnowledgeOperator } -type KnowledgeIndexer struct { - config *IndexerConfig +func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeKnowledgeIndexer, + Name: n.Data.Meta.Title, + Configs: i, + } + + inputs := n.Data.Inputs + datasetListInfoParam := inputs.DatasetParam[0] + datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) + if len(datasetIDs) == 0 { + return nil, fmt.Errorf("dataset ids is required") + } + knowledgeID, err := cast.ToInt64E(datasetIDs[0]) + if err != nil { + return nil, err + } + + i.KnowledgeID = knowledgeID + ps := inputs.StrategyParam.ParsingStrategy + parseMode, err := convertParsingType(ps.ParsingType) + if err != nil { + return nil, err + } + parsingStrategy := &knowledge.ParsingStrategy{ + ParseMode: parseMode, + ImageOCR: ps.ImageOcr, + ExtractImage: ps.ImageExtraction, + ExtractTable: ps.TableExtraction, + } + i.ParsingStrategy = parsingStrategy + + cs := inputs.StrategyParam.ChunkStrategy + chunkType, err := convertChunkType(cs.ChunkType) + if err != nil { + return nil, err + } + chunkingStrategy := &knowledge.ChunkingStrategy{ + ChunkType: chunkType, + Separator: cs.Separator, + ChunkSize: cs.MaxToken, + Overlap: int64(cs.Overlap * float64(cs.MaxToken)), + } + i.ChunkingStrategy = chunkingStrategy + + if err = convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) { - if cfg.ParsingStrategy == nil { +func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if i.ParsingStrategy == nil { return nil, errors.New("parsing strategy is required") } - if cfg.ChunkingStrategy == nil { + if i.ChunkingStrategy == nil { return nil, errors.New("chunking strategy is required") } - if cfg.KnowledgeIndexer == nil { - return nil, errors.New("knowledge indexer is required") - } - return &KnowledgeIndexer{ - config: cfg, + return &Indexer{ + knowledgeID: i.KnowledgeID, + parsingStrategy: i.ParsingStrategy, + chunkingStrategy: i.ChunkingStrategy, + knowledgeIndexer: knowledge.GetKnowledgeOperator(), }, nil } -func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) { +type Indexer struct { + knowledgeID int64 + parsingStrategy *knowledge.ParsingStrategy + chunkingStrategy *knowledge.ChunkingStrategy + knowledgeIndexer knowledge.KnowledgeOperator +} +func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { fileURL, ok := input["knowledge"].(string) if !ok { return nil, errors.New("knowledge is required") @@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map } req := &knowledge.CreateDocumentRequest{ - KnowledgeID: k.config.KnowledgeID, - ParsingStrategy: k.config.ParsingStrategy, - ChunkingStrategy: k.config.ChunkingStrategy, + KnowledgeID: k.knowledgeID, + ParsingStrategy: k.parsingStrategy, + ChunkingStrategy: k.chunkingStrategy, FileURL: fileURL, FileName: fileName, FileExtension: ext, } - response, err := k.config.KnowledgeIndexer.Store(ctx, req) + response, err := k.knowledgeIndexer.Store(ctx, req) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/knowledge/knowledge_retrieve.go b/backend/domain/workflow/internal/nodes/knowledge/knowledge_retrieve.go index fedf7a30..fe245bd6 100644 --- a/backend/domain/workflow/internal/nodes/knowledge/knowledge_retrieve.go +++ b/backend/domain/workflow/internal/nodes/knowledge/knowledge_retrieve.go @@ -20,7 +20,14 @@ import ( "context" "errors" + "github.com/spf13/cast" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" ) @@ -29,37 +36,136 @@ const outputList = "outputList" type RetrieveConfig struct { KnowledgeIDs []int64 RetrievalStrategy *knowledge.RetrievalStrategy - Retriever knowledge.KnowledgeOperator } -type KnowledgeRetrieve struct { - config *RetrieveConfig +func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeKnowledgeRetriever, + Name: n.Data.Meta.Title, + Configs: r, + } + + inputs := n.Data.Inputs + datasetListInfoParam := inputs.DatasetParam[0] + datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) + knowledgeIDs := make([]int64, 0, len(datasetIDs)) + for _, id := range datasetIDs { + k, err := cast.ToInt64E(id) + if err != nil { + return nil, err + } + knowledgeIDs = append(knowledgeIDs, k) + } + r.KnowledgeIDs = knowledgeIDs + + retrievalStrategy := &knowledge.RetrievalStrategy{} + + var getDesignatedParamContent = func(name string) (any, bool) { + for _, param := range inputs.DatasetParam { + if param.Name == name { + return param.Input.Value.Content, true + } + } + return nil, false + } + + if content, ok := getDesignatedParamContent("topK"); ok { + topK, err := cast.ToInt64E(content) + if err != nil { + return nil, err + } + retrievalStrategy.TopK = &topK + } + + if content, ok := getDesignatedParamContent("useRerank"); ok { + useRerank, err := cast.ToBoolE(content) + if err != nil { + return nil, err + } + retrievalStrategy.EnableRerank = useRerank + } + + if content, ok := getDesignatedParamContent("useRewrite"); ok { + useRewrite, err := cast.ToBoolE(content) + if err != nil { + return nil, err + } + retrievalStrategy.EnableQueryRewrite = useRewrite + } + + if content, ok := getDesignatedParamContent("isPersonalOnly"); ok { + isPersonalOnly, err := cast.ToBoolE(content) + if err != nil { + return nil, err + } + retrievalStrategy.IsPersonalOnly = isPersonalOnly + } + + if content, ok := getDesignatedParamContent("useNl2sql"); ok { + useNl2sql, err := cast.ToBoolE(content) + if err != nil { + return nil, err + } + retrievalStrategy.EnableNL2SQL = useNl2sql + } + + if content, ok := getDesignatedParamContent("minScore"); ok { + minScore, err := cast.ToFloat64E(content) + if err != nil { + return nil, err + } + retrievalStrategy.MinScore = &minScore + } + + if content, ok := getDesignatedParamContent("strategy"); ok { + strategy, err := cast.ToInt64E(content) + if err != nil { + return nil, err + } + searchType, err := convertRetrievalSearchType(strategy) + if err != nil { + return nil, err + } + retrievalStrategy.SearchType = searchType + } + + r.RetrievalStrategy = retrievalStrategy + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) { - if cfg == nil { - return nil, errors.New("cfg is required") +func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if len(r.KnowledgeIDs) == 0 { + return nil, errors.New("knowledge ids are required") } - if cfg.Retriever == nil { - return nil, errors.New("retriever is required") - } - - if len(cfg.KnowledgeIDs) == 0 { - return nil, errors.New("knowledgeI ids is required") - } - - if cfg.RetrievalStrategy == nil { + if r.RetrievalStrategy == nil { return nil, errors.New("retrieval strategy is required") } - return &KnowledgeRetrieve{ - config: cfg, + return &Retrieve{ + knowledgeIDs: r.KnowledgeIDs, + retrievalStrategy: r.RetrievalStrategy, + retriever: knowledge.GetKnowledgeOperator(), }, nil } -func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) { +type Retrieve struct { + knowledgeIDs []int64 + retrievalStrategy *knowledge.RetrievalStrategy + retriever knowledge.KnowledgeOperator +} +func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { query, ok := input["Query"].(string) if !ok { return nil, errors.New("capital query key is required") @@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) req := &knowledge.RetrieveRequest{ Query: query, - KnowledgeIDs: kr.config.KnowledgeIDs, - RetrievalStrategy: kr.config.RetrievalStrategy, + KnowledgeIDs: kr.knowledgeIDs, + RetrievalStrategy: kr.retrievalStrategy, } - response, err := kr.config.Retriever.Retrieve(ctx, req) + response, err := kr.retriever.Retrieve(ctx, req) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/nodes/llm/llm.go b/backend/domain/workflow/internal/nodes/llm/llm.go index a3c73cf1..7238cfd4 100644 --- a/backend/domain/workflow/internal/nodes/llm/llm.go +++ b/backend/domain/workflow/internal/nodes/llm/llm.go @@ -34,13 +34,20 @@ import ( callbacks2 "github.com/cloudwego/eino/utils/callbacks" "golang.org/x/exp/maps" + workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow" - crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" + crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" + "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" + "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" @@ -143,126 +150,408 @@ const ( ) type RetrievalStrategy struct { - RetrievalStrategy *crossknowledge.RetrievalStrategy + RetrievalStrategy *knowledge.RetrievalStrategy NoReCallReplyMode NoReCallReplyMode NoReCallReplyCustomizePrompt string } type KnowledgeRecallConfig struct { ChatModel model.BaseChatModel - Retriever crossknowledge.KnowledgeOperator + Retriever knowledge.KnowledgeOperator RetrievalStrategy *RetrievalStrategy - SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail + SelectedKnowledgeDetails []*knowledge.KnowledgeDetail } type Config struct { - ChatModel ModelWithInfo - Tools []tool.BaseTool - SystemPrompt string - UserPrompt string - OutputFormat Format - InputFields map[string]*vo.TypeInfo - OutputFields map[string]*vo.TypeInfo - ToolsReturnDirectly map[string]bool - KnowledgeRecallConfig *KnowledgeRecallConfig - FullSources map[string]*nodes.SourceInfo + SystemPrompt string + UserPrompt string + OutputFormat Format + LLMParams *crossmodel.LLMParams + FCParam *vo.FCParam + BackupLLMParams *crossmodel.LLMParams } -type LLM struct { - r compose.Runnable[map[string]any, map[string]any] - outputFormat Format - outputFields map[string]*vo.TypeInfo - canStream bool - requireCheckpoint bool - fullSources map[string]*nodes.SourceInfo -} +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) { + ns := &schema2.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeLLM, + Name: n.Data.Meta.Title, + Configs: c, + } -const ( - rawOutputKey = "llm_raw_output_%s" - warningKey = "llm_warning_%s" -) + param := n.Data.Inputs.LLMParam + if param == nil { + return nil, fmt.Errorf("llm node's llmParam is nil") + } -func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) { - data = nodes.ExtractJSONString(data) - - var result map[string]any - - err := sonic.UnmarshalString(data, &result) + bs, _ := sonic.Marshal(param) + llmParam := make(vo.LLMParam, 0) + if err := sonic.Unmarshal(bs, &llmParam); err != nil { + return nil, err + } + convertedLLMParam, err := llmParamsToLLMParam(llmParam) if err != nil { - c := execute.GetExeCtx(ctx) - if c != nil { - logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data) - rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey) - warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey) - ctxcache.Store(ctx, rawOutputK, data) - ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err)) - return map[string]any{}, nil - } - return nil, err } - r, ws, err := nodes.ConvertInputs(ctx, result, schema_) - if err != nil { - return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err) + c.LLMParams = convertedLLMParam + c.SystemPrompt = convertedLLMParam.SystemPrompt + c.UserPrompt = convertedLLMParam.Prompt + + var resFormat Format + switch convertedLLMParam.ResponseFormat { + case crossmodel.ResponseFormatText: + resFormat = FormatText + case crossmodel.ResponseFormatMarkdown: + resFormat = FormatMarkdown + case crossmodel.ResponseFormatJSON: + resFormat = FormatJSON + default: + return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat) } - if ws != nil { - logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws) + c.OutputFormat = resFormat + + if err = convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err } - return r, nil + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + if resFormat == FormatJSON { + if len(ns.OutputTypes) == 1 { + for _, v := range ns.OutputTypes { + if v.Type == vo.DataTypeString { + resFormat = FormatText + break + } + } + } else if len(ns.OutputTypes) == 2 { + if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok { + for k, v := range ns.OutputTypes { + if k != ReasoningOutputKey && v.Type == vo.DataTypeString { + resFormat = FormatText + break + } + } + } + } + } + + if resFormat == FormatJSON { + ns.StreamConfigs = &schema2.StreamConfig{ + CanGeneratesStream: false, + } + } else { + ns.StreamConfigs = &schema2.StreamConfig{ + CanGeneratesStream: true, + } + } + + if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil { + c.FCParam = n.Data.Inputs.FCParam + } + + if se := n.Data.Inputs.SettingOnError; se != nil { + if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 { + var backupLLMParam vo.SimpleLLMParam + if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil { + return nil, err + } + + backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam) + if err != nil { + return nil, err + } + c.BackupLLMParams = backupModel + } + } + + return ns, nil +} + +func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) { + p := &crossmodel.LLMParams{} + for _, param := range params { + switch param.Name { + case "temperature": + strVal := param.Input.Value.Content.(string) + floatVal, err := strconv.ParseFloat(strVal, 64) + if err != nil { + return nil, err + } + p.Temperature = &floatVal + case "maxTokens": + strVal := param.Input.Value.Content.(string) + intVal, err := strconv.Atoi(strVal) + if err != nil { + return nil, err + } + p.MaxTokens = intVal + case "responseFormat": + strVal := param.Input.Value.Content.(string) + int64Val, err := strconv.ParseInt(strVal, 10, 64) + if err != nil { + return nil, err + } + p.ResponseFormat = crossmodel.ResponseFormat(int64Val) + case "modleName": + strVal := param.Input.Value.Content.(string) + p.ModelName = strVal + case "modelType": + strVal := param.Input.Value.Content.(string) + int64Val, err := strconv.ParseInt(strVal, 10, 64) + if err != nil { + return nil, err + } + p.ModelType = int64Val + case "prompt": + strVal := param.Input.Value.Content.(string) + p.Prompt = strVal + case "enableChatHistory": + boolVar := param.Input.Value.Content.(bool) + p.EnableChatHistory = boolVar + case "systemPrompt": + strVal := param.Input.Value.Content.(string) + p.SystemPrompt = strVal + case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty": + // do nothing + case "topP": + strVal := param.Input.Value.Content.(string) + floatVar, err := strconv.ParseFloat(strVal, 64) + if err != nil { + return nil, err + } + p.TopP = &floatVar + default: + return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name) + } + } + + return p, nil +} + +func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) { + p := &crossmodel.LLMParams{} + p.ModelName = params.ModelName + p.ModelType = params.ModelType + p.Temperature = ¶ms.Temperature + p.MaxTokens = params.MaxTokens + p.TopP = ¶ms.TopP + p.ResponseFormat = params.ResponseFormat + p.SystemPrompt = params.SystemPrompt + return p, nil } func getReasoningContent(message *schema.Message) string { return message.ReasoningContent } -type Options struct { - nested []nodes.NestedWorkflowOption - toolWorkflowSW *schema.StreamWriter[*entity.Message] -} +func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) { + var ( + err error + chatModel, fallbackM model.BaseChatModel + info, fallbackI *modelmgr.Model + modelWithInfo ModelWithInfo + tools []tool.BaseTool + toolsReturnDirectly map[string]bool + knowledgeRecallConfig *KnowledgeRecallConfig + ) -type Option func(o *Options) - -func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option { - return func(o *Options) { - o.nested = append(o.nested, nested...) + chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams) + if err != nil { + return nil, err } -} -func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option { - return func(o *Options) { - o.toolWorkflowSW = sw + exceptionConf := ns.ExceptionConfigs + if exceptionConf != nil && exceptionConf.MaxRetry > 0 { + backupModelParams := c.BackupLLMParams + if backupModelParams != nil { + fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams) + if err != nil { + return nil, err + } + } } -} -type llmState = map[string]any + if fallbackM == nil { + modelWithInfo = NewModel(chatModel, info) + } else { + modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI) + } -const agentModelName = "agent_model" + fcParams := c.FCParam + if fcParams != nil { + if fcParams.WorkflowFCParam != nil { + for _, wf := range fcParams.WorkflowFCParam.WorkflowList { + wfIDStr := wf.WorkflowID + wfID, err := strconv.ParseInt(wfIDStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr) + } + + workflowToolConfig := vo.WorkflowToolConfig{} + if wf.FCSetting != nil { + workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters + workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters + } + + locator := vo.FromDraft + if wf.WorkflowVersion != "" { + locator = vo.FromSpecificVersion + } + + wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{ + ID: wfID, + QType: locator, + Version: wf.WorkflowVersion, + }, workflowToolConfig) + if err != nil { + return nil, err + } + tools = append(tools, wfTool) + if wfTool.TerminatePlan() == vo.UseAnswerContent { + if toolsReturnDirectly == nil { + toolsReturnDirectly = make(map[string]bool) + } + toolInfo, err := wfTool.Info(ctx) + if err != nil { + return nil, err + } + toolsReturnDirectly[toolInfo.Name] = true + } + } + } + + if fcParams.PluginFCParam != nil { + pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest) + for _, p := range fcParams.PluginFCParam.PluginList { + pid, err := strconv.ParseInt(p.PluginID, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID) + } + toolID, err := strconv.ParseInt(p.ApiId, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID) + } + + var ( + requestParameters []*workflow3.APIParameter + responseParameters []*workflow3.APIParameter + ) + if p.FCSetting != nil { + requestParameters = p.FCSetting.RequestParameters + responseParameters = p.FCSetting.ResponseParameters + } + + if req, ok := pluginToolsInvokableReq[pid]; ok { + req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{ + ToolID: toolID, + RequestAPIParametersConfig: requestParameters, + ResponseAPIParametersConfig: responseParameters, + } + } else { + pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{ + PluginEntity: plugin.Entity{ + PluginID: pid, + PluginVersion: ptr.Of(p.PluginVersion), + }, + ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{ + toolID: { + ToolID: toolID, + RequestAPIParametersConfig: requestParameters, + ResponseAPIParametersConfig: responseParameters, + }, + }, + IsDraft: p.IsDraft, + } + pluginToolsInvokableReq[pid] = pluginToolsInfoRequest + } + } + inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList)) + for _, req := range pluginToolsInvokableReq { + toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req) + if err != nil { + return nil, err + } + for _, t := range toolMap { + inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t)) + } + } + if len(inInvokableTools) > 0 { + tools = append(tools, inInvokableTools...) + } + } + + if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 { + kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel() + if kwChatModel == nil { + return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured") + } + + knowledgeOperator := knowledge.GetKnowledgeOperator() + setting := fcParams.KnowledgeFCParam.GlobalSetting + knowledgeRecallConfig = &KnowledgeRecallConfig{ + ChatModel: kwChatModel, + Retriever: knowledgeOperator, + } + searchType, err := toRetrievalSearchType(setting.SearchMode) + if err != nil { + return nil, err + } + knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{ + RetrievalStrategy: &knowledge.RetrievalStrategy{ + TopK: ptr.Of(setting.TopK), + MinScore: ptr.Of(setting.MinScore), + SearchType: searchType, + EnableNL2SQL: setting.UseNL2SQL, + EnableQueryRewrite: setting.UseRewrite, + EnableRerank: setting.UseRerank, + }, + NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode), + NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt, + } + + knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList)) + for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList { + kid, err := strconv.ParseInt(kw.ID, 10, 64) + if err != nil { + return nil, err + } + knowledgeIDs = append(knowledgeIDs, kid) + } + + detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, + &knowledge.ListKnowledgeDetailRequest{ + KnowledgeIDs: knowledgeIDs, + }) + if err != nil { + return nil, err + } + knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails + } + } -func New(ctx context.Context, cfg *Config) (*LLM, error) { g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) { return llmState{} })) - var ( - hasReasoning bool - canStream = true - ) + var hasReasoning bool - format := cfg.OutputFormat + format := c.OutputFormat if format == FormatJSON { - if len(cfg.OutputFields) == 1 { - for _, v := range cfg.OutputFields { + if len(ns.OutputTypes) == 1 { + for _, v := range ns.OutputTypes { if v.Type == vo.DataTypeString { format = FormatText break } } - } else if len(cfg.OutputFields) == 2 { - if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok { - for k, v := range cfg.OutputFields { + } else if len(ns.OutputTypes) == 2 { + if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok { + for k, v := range ns.OutputTypes { if k != ReasoningOutputKey && v.Type == vo.DataTypeString { format = FormatText break @@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { } } - userPrompt := cfg.UserPrompt + userPrompt := c.UserPrompt switch format { case FormatJSON: - jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil) + jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil) if err != nil { return nil, err } @@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { case FormatText: } - if cfg.KnowledgeRecallConfig != nil { - err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig) + if knowledgeRecallConfig != nil { + err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig) if err != nil { return nil, err } userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt) - inputs := maps.Clone(cfg.InputFields) + inputs := maps.Clone(ns.InputTypes) inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{ Type: vo.DataTypeString, } - sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil) + sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil) up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey}) - template := newPrompts(sp, up, cfg.ChatModel) + template := newPrompts(sp, up, modelWithInfo) _ = g.AddChatTemplateNode(templateNodeKey, template, compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) { @@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { _ = g.AddEdge(knowledgeLambdaKey, templateNodeKey) } else { - sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil) - up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil) - template := newPrompts(sp, up, cfg.ChatModel) + sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil) + up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil) + template := newPrompts(sp, up, modelWithInfo) _ = g.AddChatTemplateNode(templateNodeKey, template) _ = g.AddEdge(compose.START, templateNodeKey) } - if len(cfg.Tools) > 0 { - m, ok := cfg.ChatModel.(model.ToolCallingChatModel) + if len(tools) > 0 { + m, ok := modelWithInfo.(model.ToolCallingChatModel) if !ok { return nil, errors.New("requires a ToolCallingChatModel to use with tools") } reactConfig := react.AgentConfig{ ToolCallingModel: m, - ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools}, + ToolsConfig: compose.ToolsNodeConfig{Tools: tools}, ModelNodeName: agentModelName, } - if len(cfg.ToolsReturnDirectly) > 0 { - reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly)) - for k := range cfg.ToolsReturnDirectly { + if len(toolsReturnDirectly) > 0 { + reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly)) + for k := range toolsReturnDirectly { reactConfig.ToolReturnDirectly[k] = struct{}{} } } @@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { opts = append(opts, compose.WithNodeName("workflow_llm_react_agent")) _ = g.AddGraphNode(llmNodeKey, agentNode, opts...) } else { - _ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel) + _ = g.AddChatModelNode(llmNodeKey, modelWithInfo) } _ = g.AddEdge(templateNodeKey, llmNodeKey) if format == FormatJSON { iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) { - return jsonParse(ctx, msg.Content, cfg.OutputFields) + return jsonParse(ctx, msg.Content, ns.OutputTypes) } convertNode := compose.InvokableLambda(iConvert) _ = g.AddLambdaNode(outputConvertNodeKey, convertNode) - - canStream = false } else { var outputKey string - if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 { + if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 { panic("impossible") } - for k, v := range cfg.OutputFields { + for k, v := range ns.OutputTypes { if v.Type != vo.DataTypeString { panic("impossible") } @@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { _ = g.AddEdge(outputConvertNodeKey, compose.END) requireCheckpoint := false - if len(cfg.Tools) > 0 { + if len(tools) > 0 { requireCheckpoint = true } - var opts []compose.GraphCompileOption + var compileOpts []compose.GraphCompileOption if requireCheckpoint { - opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository())) + compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository())) } - opts = append(opts, compose.WithGraphName("workflow_llm_node_graph")) + compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph")) - r, err := g.Compile(ctx, opts...) + r, err := g.Compile(ctx, compileOpts...) if err != nil { return nil, err } @@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) { llm := &LLM{ r: r, outputFormat: format, - canStream: canStream, requireCheckpoint: requireCheckpoint, - fullSources: cfg.FullSources, + fullSources: ns.FullSources, } return llm, nil } -func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { +func (c *Config) RequireCheckpoint() bool { + if c.FCParam != nil { + if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil { + return true + } + } + + return false +} + +func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema, + sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) { + if !sc.RequireStreaming() { + return schema2.FieldNotStream, nil + } + + if len(path) != 1 { + return schema2.FieldNotStream, nil + } + + outputs := ns.OutputTypes + if len(outputs) != 1 && len(outputs) != 2 { + return schema2.FieldNotStream, nil + } + + var outputKey string + for key, output := range outputs { + if output.Type != vo.DataTypeString { + return schema2.FieldNotStream, nil + } + + if key != ReasoningOutputKey { + if len(outputKey) > 0 { + return schema2.FieldNotStream, nil + } + outputKey = key + } + } + + field := path[0] + if field == ReasoningOutputKey || field == outputKey { + return schema2.FieldIsStream, nil + } + + return schema2.FieldNotStream, nil +} + +func toRetrievalSearchType(s int64) (knowledge.SearchType, error) { + switch s { + case 0: + return knowledge.SearchTypeSemantic, nil + case 1: + return knowledge.SearchTypeHybrid, nil + case 20: + return knowledge.SearchTypeFullText, nil + default: + return "", fmt.Errorf("invalid retrieval search type %v", s) + } +} + +type LLM struct { + r compose.Runnable[map[string]any, map[string]any] + outputFormat Format + requireCheckpoint bool + fullSources map[string]*schema2.SourceInfo +} + +const ( + rawOutputKey = "llm_raw_output_%s" + warningKey = "llm_warning_%s" +) + +func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) { + data = nodes.ExtractJSONString(data) + + var result map[string]any + + err := sonic.UnmarshalString(data, &result) + if err != nil { + c := execute.GetExeCtx(ctx) + if c != nil { + logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data) + rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey) + warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey) + ctxcache.Store(ctx, rawOutputK, data) + ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err)) + return map[string]any{}, nil + } + + return nil, err + } + + r, ws, err := nodes.ConvertInputs(ctx, result, schema_) + if err != nil { + return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err) + } + + if ws != nil { + logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws) + } + + return r, nil +} + +type llmOptions struct { + toolWorkflowSW *schema.StreamWriter[*entity.Message] +} + +func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption { + return nodes.WrapImplSpecificOptFn(func(o *llmOptions) { + o.toolWorkflowSW = sw + }) +} + +type llmState = map[string]any + +const agentModelName = "agent_model" + +func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { c := execute.GetExeCtx(ctx) if c != nil { resumingEvent = c.NodeCtx.ResumingEvent @@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID)) } - llmOpts := &Options{} - for _, opt := range opts { - opt(llmOpts) - } + options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...) - nestedOpts := &nodes.NestedWorkflowOptions{} - for _, opt := range llmOpts.nested { - opt(nestedOpts) - } - - composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...) + composeOpts = append(composeOpts, options.GetOptsForNested()...) if resumingEvent != nil { var ( @@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg)))) } + llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...) if llmOpts.toolWorkflowSW != nil { toolMsgOpt, toolMsgSR := execute.WithMessagePipe() composeOpts = append(composeOpts, toolMsgOpt) @@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter return compose.NewInterruptAndRerunErr(ie) } -func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) { +func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) { composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...) if err != nil { return nil, err @@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out return out, nil } -func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) { +func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) { composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...) if err != nil { return nil, err @@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map _ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) { modelPredictionIDs := strings.Split(input.Content, ",") - selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) { + selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) { return strconv.Itoa(int(e.ID)), e.ID }) recallKnowledgeIDs := make([]int64, 0) @@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map return make(map[string]any), nil } - docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{ + docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{ Query: userPrompt, KnowledgeIDs: recallKnowledgeIDs, RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy, diff --git a/backend/domain/workflow/internal/nodes/llm/prompt.go b/backend/domain/workflow/internal/nodes/llm/prompt.go index 7c275160..2cc5d1be 100644 --- a/backend/domain/workflow/internal/nodes/llm/prompt.go +++ b/backend/domain/workflow/internal/nodes/llm/prompt.go @@ -26,6 +26,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/sonic" @@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts { } func (pl *promptTpl) render(ctx context.Context, vs map[string]any, - sources map[string]*nodes.SourceInfo, + sources map[string]*schema2.SourceInfo, supportedModals map[modelmgr.Modal]bool, ) (*schema.Message, error) { if !pl.hasMultiModal || len(supportedModals) == 0 { @@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt } sk := fmt.Sprintf(sourceKey, nodeKey) - sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk) + sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk) if !ok { return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk) } diff --git a/backend/domain/workflow/internal/nodes/loop/break.go b/backend/domain/workflow/internal/nodes/loop/break/break.go similarity index 54% rename from backend/domain/workflow/internal/nodes/loop/break.go rename to backend/domain/workflow/internal/nodes/loop/break/break.go index 30fada39..9233670a 100644 --- a/backend/domain/workflow/internal/nodes/loop/break.go +++ b/backend/domain/workflow/internal/nodes/loop/break/break.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package loop +package _break import ( "context" @@ -22,21 +22,36 @@ import ( "github.com/cloudwego/eino/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type Break struct { parentIntermediateStore variable.Store } -func NewBreak(_ context.Context, store variable.Store) (*Break, error) { +type Config struct{} + +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + return &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeBreak, + Name: n.Data.Meta.Title, + Configs: c, + }, nil +} + +func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { return &Break{ - parentIntermediateStore: store, + parentIntermediateStore: &nodes.ParentIntermediateStore{}, }, nil } const BreakKey = "$break" -func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) { +func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) { err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true) if err != nil { return nil, err diff --git a/backend/domain/workflow/internal/nodes/loop/continue/continue.go b/backend/domain/workflow/internal/nodes/loop/continue/continue.go new file mode 100644 index 00000000..cd7762a3 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/loop/continue/continue.go @@ -0,0 +1,47 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package _continue + +import ( + "context" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" +) + +type Continue struct{} + +type Config struct{} + +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + return &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeContinue, + Name: n.Data.Meta.Title, + Configs: c, + }, nil +} + +func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + return &Continue{}, nil +} + +func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) { + return in, nil +} diff --git a/backend/domain/workflow/internal/nodes/loop/loop.go b/backend/domain/workflow/internal/nodes/loop/loop.go index ab570a46..fa85c4aa 100644 --- a/backend/domain/workflow/internal/nodes/loop/loop.go +++ b/backend/domain/workflow/internal/nodes/loop/loop.go @@ -27,53 +27,150 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + _break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" ) type Loop struct { - config *Config outputs map[string]*vo.FieldSource outputVars map[string]string + inner compose.Runnable[map[string]any, map[string]any] + nodeKey vo.NodeKey + + loopType Type + inputArrays []string + intermediateVars map[string]*vo.TypeInfo } type Config struct { - LoopNodeKey vo.NodeKey LoopType Type InputArrays []string IntermediateVars map[string]*vo.TypeInfo - Outputs []*vo.FieldInfo - - Inner compose.Runnable[map[string]any, map[string]any] } -type Type string - -const ( - ByArray Type = "by_array" - ByIteration Type = "by_iteration" - Infinite Type = "infinite" -) - -func NewLoop(_ context.Context, conf *Config) (*Loop, error) { - if conf == nil { - return nil, errors.New("config is nil") +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + if n.Parent() != nil { + return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID) } - if conf.LoopType == ByArray { - if len(conf.InputArrays) == 0 { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeLoop, + Name: n.Data.Meta.Title, + Configs: c, + } + + loopType, err := toLoopType(n.Data.Inputs.LoopType) + if err != nil { + return nil, err + } + c.LoopType = loopType + + intermediateVars := make(map[string]*vo.TypeInfo) + for _, param := range n.Data.Inputs.VariableParameters { + tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input) + if err != nil { + return nil, err + } + intermediateVars[param.Name] = tInfo + + ns.SetInputType(param.Name, tInfo) + sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil) + if err != nil { + return nil, err + } + ns.AddInputSource(sources...) + } + c.IntermediateVars = intermediateVars + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + for _, fieldInfo := range ns.OutputSources { + if fieldInfo.Source.Ref != nil { + if len(fieldInfo.Source.Ref.FromPath) == 1 { + if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok { + fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate) + } + } + } + } + + loopCount := n.Data.Inputs.LoopCount + if loopCount != nil { + typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount) + if err != nil { + return nil, err + } + ns.SetInputType(Count, typeInfo) + + sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil) + if err != nil { + return nil, err + } + ns.AddInputSource(sources...) + } + + for key, tInfo := range ns.InputTypes { + if tInfo.Type != vo.DataTypeArray { + continue + } + + if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars + continue + } + + c.InputArrays = append(c.InputArrays, key) + } + + return ns, nil +} + +func toLoopType(l vo.LoopType) (Type, error) { + switch l { + case vo.LoopTypeArray: + return ByArray, nil + case vo.LoopTypeCount: + return ByIteration, nil + case vo.LoopTypeInfinite: + return Infinite, nil + default: + return "", fmt.Errorf("unsupported loop type: %s", l) + } +} + +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) { + if c.LoopType == ByArray { + if len(c.InputArrays) == 0 { return nil, errors.New("input arrays is empty when loop type is ByArray") } } - loop := &Loop{ - config: conf, - outputs: make(map[string]*vo.FieldSource), - outputVars: make(map[string]string), + options := schema.GetBuildOptions(opts...) + if options.Inner == nil { + return nil, errors.New("inner workflow is required for Loop Node") } - for _, info := range conf.Outputs { + loop := &Loop{ + outputs: make(map[string]*vo.FieldSource), + outputVars: make(map[string]string), + inputArrays: c.InputArrays, + nodeKey: ns.Key, + intermediateVars: c.IntermediateVars, + inner: options.Inner, + loopType: c.LoopType, + } + + for _, info := range ns.OutputSources { if len(info.Path) != 1 { return nil, fmt.Errorf("invalid output path: %s", info.Path) } @@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) { return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath) } - if _, ok := conf.IntermediateVars[fromPath[0]]; !ok { + if _, ok := c.IntermediateVars[fromPath[0]]; !ok { return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath) } @@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) { return loop, nil } +type Type string + +const ( + ByArray Type = "by_array" + ByIteration Type = "by_iteration" + Infinite Type = "infinite" +) + const ( Count = "loopCount" ) -func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) { +func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) ( + out map[string]any, err error) { maxIter, err := l.getMaxIter(in) if err != nil { return nil, err } - arrays := make(map[string][]any, len(l.config.InputArrays)) - for _, arrayKey := range l.config.InputArrays { + arrays := make(map[string][]any, len(l.inputArrays)) + for _, arrayKey := range l.inputArrays { a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) if !ok { return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey) @@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes arrays[arrayKey] = a.([]any) } - options := &nodes.NestedWorkflowOptions{} - for _, opt := range opts { - opt(options) - } + options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...) var ( existingCState *nodes.NestedWorkflowState @@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes ) err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error { var e error - existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey) + existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey) if e != nil { return e } @@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes for k := range existingCState.IntermediateVars { intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k]) } - intermediateVars[BreakKey] = &hasBreak + intermediateVars[_break.BreakKey] = &hasBreak } else { output = make(map[string]any, len(l.outputs)) for k := range l.outputs { output[k] = make([]any, 0) } - intermediateVars = make(map[string]*any, len(l.config.IntermediateVars)) - for varKey := range l.config.IntermediateVars { + intermediateVars = make(map[string]*any, len(l.intermediateVars)) + for varKey := range l.intermediateVars { v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey}) if !ok { return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey) @@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes intermediateVars[varKey] = &v } - intermediateVars[BreakKey] = &hasBreak + intermediateVars[_break.BreakKey] = &hasBreak } - ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars) + ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars) getIthInput := func(i int) (map[string]any, map[string]any, error) { input := make(map[string]any) @@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes input[k] = v } - input[string(l.config.LoopNodeKey)+"#index"] = int64(i) + input[string(l.nodeKey)+"#index"] = int64(i) items := make(map[string]any) for arrayKey := range arrays { ele := arrays[arrayKey][i] items[arrayKey] = ele - currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey + currentKey := string(l.nodeKey) + "#" + arrayKey // Recursively expand map[string]any elements var expand func(prefix string, val interface{}) @@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes } } - taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...) + taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...) if err != nil { info, ok := compose.ExtractInterruptInfo(err) if !ok { @@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions iEvent := &entity.InterruptEvent{ - NodeKey: l.config.LoopNodeKey, + NodeKey: l.nodeKey, NodeType: entity.NodeTypeLoop, NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo } err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { - if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil { + if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil { return e } - return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent) + return setter.SetInterruptEvent(l.nodeKey, iEvent) }) if err != nil { return nil, err } - fmt.Println("save interruptEvent in state within loop: ", iEvent) - fmt.Println("save composite info in state within loop: ", compState) - return nil, compose.InterruptAndRerun } else { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { - return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState) + return setter.SaveNestedWorkflowState(l.nodeKey, compState) }) if err != nil { return nil, err @@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes } if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 { - fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo) - panic("impossible") + panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo)) } for outputVarKey, intermediateVarKey := range l.outputVars { @@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes func (l *Loop) getMaxIter(in map[string]any) (int, error) { maxIter := math.MaxInt - switch l.config.LoopType { + switch l.loopType { case ByArray: - for _, arrayKey := range l.config.InputArrays { + for _, arrayKey := range l.inputArrays { a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) if !ok { return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey) @@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) { maxIter = int(iter.(int64)) case Infinite: default: - return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType) + return 0, fmt.Errorf("loop type not supported: %v", l.loopType) } return maxIter, nil @@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any { } func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { - trimmed := make(map[string]any, len(l.config.InputArrays)) - for _, arrayKey := range l.config.InputArrays { + trimmed := make(map[string]any, len(l.inputArrays)) + for _, arrayKey := range l.inputArrays { if v, ok := in[arrayKey]; ok { trimmed[arrayKey] = v } diff --git a/backend/domain/workflow/internal/nodes/nested.go b/backend/domain/workflow/internal/nodes/nested.go deleted file mode 100644 index b7af02d0..00000000 --- a/backend/domain/workflow/internal/nodes/nested.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2025 coze-dev Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package nodes - -import ( - "github.com/bytedance/sonic" - "github.com/cloudwego/eino/compose" - - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" -) - -type NestedWorkflowOptions struct { - optsForNested []compose.Option - toResumeIndexes map[int]compose.StateModifier - optsForIndexed map[int][]compose.Option -} - -type NestedWorkflowOption func(*NestedWorkflowOptions) - -func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption { - return func(o *NestedWorkflowOptions) { - o.optsForNested = append(o.optsForNested, opts...) - } -} - -func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option { - return c.optsForNested -} - -func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption { - return func(o *NestedWorkflowOptions) { - if o.toResumeIndexes == nil { - o.toResumeIndexes = map[int]compose.StateModifier{} - } - - o.toResumeIndexes[i] = m - } -} - -func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier { - return c.toResumeIndexes -} - -func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption { - return func(o *NestedWorkflowOptions) { - if o.optsForIndexed == nil { - o.optsForIndexed = map[int][]compose.Option{} - } - o.optsForIndexed[index] = opts - } -} - -func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option { - if c.optsForIndexed == nil { - return nil - } - return c.optsForIndexed[index] -} - -type NestedWorkflowState struct { - Index2Done map[int]bool `json:"index_2_done,omitempty"` - Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"` - FullOutput map[string]any `json:"full_output,omitempty"` - IntermediateVars map[string]any `json:"intermediate_vars,omitempty"` -} - -func (c *NestedWorkflowState) String() string { - s, _ := sonic.MarshalIndent(c, "", " ") - return string(s) -} - -type NestedWorkflowAware interface { - SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error - GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error) - InterruptEventStore -} diff --git a/backend/domain/workflow/internal/nodes/node.go b/backend/domain/workflow/internal/nodes/node.go new file mode 100644 index 00000000..d24691d6 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/node.go @@ -0,0 +1,194 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nodes + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + einoschema "github.com/cloudwego/eino/schema" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" +) + +// InvokableNode is a basic workflow node that can Invoke. +// Invoke accepts non-streaming input and returns non-streaming output. +// It does not accept any options. +// Most nodes implement this, such as NodeTypePlugin. +type InvokableNode interface { + Invoke(ctx context.Context, input map[string]any) ( + output map[string]any, err error) +} + +// InvokableNodeWOpt is a workflow node that can Invoke. +// Invoke accepts non-streaming input and returns non-streaming output. +// It can accept NodeOption. +// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this. +type InvokableNodeWOpt interface { + Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) ( + map[string]any, error) +} + +// StreamableNode is a workflow node that can Stream. +// Stream accepts non-streaming input and returns streaming output. +// It does not accept and options +// Currently NO Node implement this. +// A potential example would be streamable plugin for NodeTypePlugin. +type StreamableNode interface { + Stream(ctx context.Context, in map[string]any) ( + *einoschema.StreamReader[map[string]any], error) +} + +// StreamableNodeWOpt is a workflow node that can Stream. +// Stream accepts non-streaming input and returns streaming output. +// It can accept NodeOption. +// e.g. NodeTypeLLM implement this. +type StreamableNodeWOpt interface { + Stream(ctx context.Context, in map[string]any, opts ...NodeOption) ( + *einoschema.StreamReader[map[string]any], error) +} + +// CollectableNode is a workflow node that can Collect. +// Collect accepts streaming input and returns non-streaming output. +// It does not accept and options +// Currently NO Node implement this. +// A potential example would be a new condition node that makes decisions +// based on streaming input. +type CollectableNode interface { + Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) ( + map[string]any, error) +} + +// CollectableNodeWOpt is a workflow node that can Collect. +// Collect accepts streaming input and returns non-streaming output. +// It accepts NodeOption. +// Currently NO Node implement this. +// A potential example would be a new batch node that accepts streaming input, +// process them, and finally returns non-stream aggregation of results. +type CollectableNodeWOpt interface { + Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) ( + map[string]any, error) +} + +// TransformableNode is a workflow node that can Transform. +// Transform accepts streaming input and returns streaming output. +// It does not accept and options +// e.g. +// NodeTypeVariableAggregator implements TransformableNode. +type TransformableNode interface { + Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) ( + *einoschema.StreamReader[map[string]any], error) +} + +// TransformableNodeWOpt is a workflow node that can Transform. +// Transform accepts streaming input and returns streaming output. +// It accepts NodeOption. +// Currently NO Node implement this. +// A potential example would be an audio processing node that +// transforms input audio clips, but within the node is a graph +// composed by Eino, and the audio processing node needs to carry +// options for this inner graph. +type TransformableNodeWOpt interface { + Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) ( + *einoschema.StreamReader[map[string]any], error) +} + +// CallbackInputConverted converts node input to a form better suited for UI. +// The converted input will be displayed on canvas when test run, +// and will be returned when querying the node's input through OpenAPI. +type CallbackInputConverted interface { + ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) +} + +// CallbackOutputConverted converts node input to a form better suited for UI. +// The converted output will be displayed on canvas when test run, +// and will be returned when querying the node's output through OpenAPI. +type CallbackOutputConverted interface { + ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error) +} + +type Initializer interface { + Init(ctx context.Context) (context.Context, error) +} + +type AdaptOptions struct { + Canvas *vo.Canvas +} + +type AdaptOption func(*AdaptOptions) + +func WithCanvas(canvas *vo.Canvas) AdaptOption { + return func(opts *AdaptOptions) { + opts.Canvas = canvas + } +} + +func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions { + options := &AdaptOptions{} + for _, opt := range opts { + opt(options) + } + return options +} + +// NodeAdaptor provides conversion from frontend Node to backend NodeSchema. +type NodeAdaptor interface { + Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) ( + *schema.NodeSchema, error) +} + +// BranchAdaptor provides validation and conversion from frontend port to backend port. +type BranchAdaptor interface { + ExpectPorts(ctx context.Context, n *vo.Node) []string +} + +var ( + nodeAdaptors = map[entity.NodeType]func() NodeAdaptor{} + branchAdaptors = map[entity.NodeType]func() BranchAdaptor{} +) + +func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) { + nodeAdaptors[et] = f +} + +func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) { + na, ok := nodeAdaptors[et] + if !ok { + panic(fmt.Sprintf("node type %s not registered", et)) + } + return na(), ok +} + +func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) { + branchAdaptors[et] = f +} + +func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) { + na, ok := branchAdaptors[et] + if !ok { + return nil, false + } + return na(), ok +} + +type StreamGenerator interface { + FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema, + sc *schema.WorkflowSchema) (schema.FieldStreamType, error) +} diff --git a/backend/domain/workflow/internal/nodes/option.go b/backend/domain/workflow/internal/nodes/option.go new file mode 100644 index 00000000..9298f883 --- /dev/null +++ b/backend/domain/workflow/internal/nodes/option.go @@ -0,0 +1,170 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nodes + +import ( + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" +) + +type NodeOptions struct { + Nested *NestedWorkflowOptions +} + +type NestedWorkflowOptions struct { + optsForNested []compose.Option + toResumeIndexes map[int]compose.StateModifier + optsForIndexed map[int][]compose.Option +} + +type NodeOption struct { + apply func(opts *NodeOptions) + + implSpecificOptFn any +} + +type NestedWorkflowOption func(*NestedWorkflowOptions) + +func WithOptsForNested(opts ...compose.Option) NodeOption { + return NodeOption{ + apply: func(options *NodeOptions) { + if options.Nested == nil { + options.Nested = &NestedWorkflowOptions{} + } + options.Nested.optsForNested = append(options.Nested.optsForNested, opts...) + }, + } +} + +func (c *NodeOptions) GetOptsForNested() []compose.Option { + if c.Nested == nil { + return nil + } + return c.Nested.optsForNested +} + +func WithResumeIndex(i int, m compose.StateModifier) NodeOption { + return NodeOption{ + apply: func(options *NodeOptions) { + if options.Nested == nil { + options.Nested = &NestedWorkflowOptions{} + } + if options.Nested.toResumeIndexes == nil { + options.Nested.toResumeIndexes = map[int]compose.StateModifier{} + } + + options.Nested.toResumeIndexes[i] = m + }, + } +} + +func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier { + if c.Nested == nil { + return nil + } + return c.Nested.toResumeIndexes +} + +func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption { + return NodeOption{ + apply: func(options *NodeOptions) { + if options.Nested == nil { + options.Nested = &NestedWorkflowOptions{} + } + if options.Nested.optsForIndexed == nil { + options.Nested.optsForIndexed = map[int][]compose.Option{} + } + options.Nested.optsForIndexed[index] = opts + }, + } +} + +func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option { + if c.Nested == nil { + return nil + } + return c.Nested.optsForIndexed[index] +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption { + return NodeOption{ + implSpecificOptFn: optFn, + } +} + +// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions { + if base == nil { + base = &NodeOptions{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. +// e.g. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption := model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} + +type NestedWorkflowState struct { + Index2Done map[int]bool `json:"index_2_done,omitempty"` + Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"` + FullOutput map[string]any `json:"full_output,omitempty"` + IntermediateVars map[string]any `json:"intermediate_vars,omitempty"` +} + +func (c *NestedWorkflowState) String() string { + s, _ := sonic.MarshalIndent(c, "", " ") + return string(s) +} + +type NestedWorkflowAware interface { + SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error + GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error) + InterruptEventStore +} diff --git a/backend/domain/workflow/internal/nodes/plugin/plugin.go b/backend/domain/workflow/internal/nodes/plugin/plugin.go index 08220076..e59e8f6b 100644 --- a/backend/domain/workflow/internal/nodes/plugin/plugin.go +++ b/backend/domain/workflow/internal/nodes/plugin/plugin.go @@ -18,16 +18,21 @@ package plugin import ( "context" - "errors" + "fmt" + "strconv" "github.com/cloudwego/eino/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -35,29 +40,76 @@ type Config struct { PluginID int64 ToolID int64 PluginVersion string +} - PluginService plugin.Service +func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypePlugin, + Name: n.Data.Meta.Title, + Configs: c, + } + inputs := n.Data.Inputs + + apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { + return e.Name, e + }) + + ps, ok := apiParams["pluginID"] + if !ok { + return nil, fmt.Errorf("plugin id param is not found") + } + + pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64) + + c.PluginID = pID + + ps, ok = apiParams["apiID"] + if !ok { + return nil, fmt.Errorf("plugin id param is not found") + } + + tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64) + if err != nil { + return nil, err + } + + c.ToolID = tID + + ps, ok = apiParams["pluginVersion"] + if !ok { + return nil, fmt.Errorf("plugin version param is not found") + } + version := ps.Input.Value.Content.(string) + + c.PluginVersion = version + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + return &Plugin{ + pluginID: c.PluginID, + toolID: c.ToolID, + pluginVersion: c.PluginVersion, + pluginService: plugin.GetPluginService(), + }, nil } type Plugin struct { - config *Config -} + pluginID int64 + toolID int64 + pluginVersion string -func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) { - if cfg == nil { - return nil, errors.New("config is nil") - } - if cfg.PluginID == 0 { - return nil, errors.New("plugin id is required") - } - if cfg.ToolID == 0 { - return nil, errors.New("tool id is required") - } - if cfg.PluginService == nil { - return nil, errors.New("tool service is required") - } - - return &Plugin{config: cfg}, nil + pluginService plugin.Service } func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) { @@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil { exeCfg = ctxExeCfg.ExeCfg } - result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{ - PluginID: p.config.PluginID, - PluginVersion: ptr.Of(p.config.PluginVersion), - }, p.config.ToolID, exeCfg) + result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{ + PluginID: p.pluginID, + PluginVersion: ptr.Of(p.pluginVersion), + }, p.toolID, exeCfg) if err != nil { if extra, ok := compose.IsInterruptRerunError(err); ok { // TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now diff --git a/backend/domain/workflow/internal/nodes/qa/question_answer.go b/backend/domain/workflow/internal/nodes/qa/question_answer.go index 3919bc89..a6e9952f 100644 --- a/backend/domain/workflow/internal/nodes/qa/question_answer.go +++ b/backend/domain/workflow/internal/nodes/qa/question_answer.go @@ -29,9 +29,12 @@ import ( "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/domain/workflow" + crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/sonic" @@ -39,8 +42,21 @@ import ( ) type QuestionAnswer struct { - config *Config + model model.BaseChatModel nodeMeta entity.NodeTypeMeta + + questionTpl string + answerType AnswerType + + choiceType ChoiceType + fixedChoices []string + + needExtractFromAnswer bool + additionalSystemPromptTpl string + maxAnswerCount int + + nodeKey vo.NodeKey + outputFields map[string]*vo.TypeInfo } type Config struct { @@ -51,15 +67,249 @@ type Config struct { FixedChoices []string // used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response - Model model.BaseChatModel + LLMParams *crossmodel.LLMParams // the following are required if AnswerType is AnswerDirectly and needs to extract from answer ExtractFromAnswer bool AdditionalSystemPromptTpl string MaxAnswerCount int - OutputFields map[string]*vo.TypeInfo +} - NodeKey vo.NodeKey +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) { + ns := &schema2.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeQuestionAnswer, + Name: n.Data.Meta.Title, + Configs: c, + } + + qaConf := n.Data.Inputs.QA + if qaConf == nil { + return nil, fmt.Errorf("qa config is nil") + } + c.QuestionTpl = qaConf.Question + + var llmParams *crossmodel.LLMParams + if n.Data.Inputs.LLMParam != nil { + llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam) + if err != nil { + return nil, err + } + var qaLLMParams vo.SimpleLLMParam + err = sonic.Unmarshal(llmParamBytes, &qaLLMParams) + if err != nil { + return nil, err + } + + llmParams, err = convertLLMParams(qaLLMParams) + if err != nil { + return nil, err + } + + c.LLMParams = llmParams + } + + answerType, err := convertAnswerType(qaConf.AnswerType) + if err != nil { + return nil, err + } + c.AnswerType = answerType + + var choiceType ChoiceType + if len(qaConf.OptionType) > 0 { + choiceType, err = convertChoiceType(qaConf.OptionType) + if err != nil { + return nil, err + } + c.ChoiceType = choiceType + } + + if answerType == AnswerByChoices { + switch choiceType { + case FixedChoices: + var options []string + for _, option := range qaConf.Options { + options = append(options, option.Name) + } + c.FixedChoices = options + case DynamicChoices: + inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent()) + if err != nil { + return nil, err + } + ns.AddInputSource(inputSources...) + + inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption) + if err != nil { + return nil, err + } + ns.SetInputType(DynamicChoicesKey, inputTypes) + default: + return nil, fmt.Errorf("qa node is answer by options, but option type not provided") + } + } else if answerType == AnswerDirectly { + c.ExtractFromAnswer = qaConf.ExtractOutput + if qaConf.ExtractOutput { + if llmParams == nil { + return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided") + } + c.AdditionalSystemPromptTpl = llmParams.SystemPrompt + c.MaxAnswerCount = qaConf.Limit + if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + } + } + + if err = convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) { + p := &crossmodel.LLMParams{} + p.ModelName = params.ModelName + p.ModelType = params.ModelType + p.Temperature = ¶ms.Temperature + p.MaxTokens = params.MaxTokens + p.TopP = ¶ms.TopP + p.ResponseFormat = params.ResponseFormat + p.SystemPrompt = params.SystemPrompt + return p, nil +} + +func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) { + switch t { + case vo.QAAnswerTypeOption: + return AnswerByChoices, nil + case vo.QAAnswerTypeText: + return AnswerDirectly, nil + default: + return "", fmt.Errorf("invalid QAAnswerType: %s", t) + } +} + +func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) { + switch t { + case vo.QAOptionTypeStatic: + return FixedChoices, nil + case vo.QAOptionTypeDynamic: + return DynamicChoices, nil + default: + return "", fmt.Errorf("invalid QAOptionType: %s", t) + } +} + +func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) { + if c.AnswerType == AnswerDirectly { + if c.ExtractFromAnswer { + if c.LLMParams == nil { + return nil, errors.New("model is required when extract from answer") + } + if len(ns.OutputTypes) == 0 { + return nil, errors.New("output fields is required when extract from answer") + } + } + } else if c.AnswerType == AnswerByChoices { + if c.ChoiceType == FixedChoices { + if len(c.FixedChoices) == 0 { + return nil, errors.New("fixed choices is required when extract from answer") + } + } + } else { + return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType) + } + + nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer) + if nodeMeta == nil { + return nil, errors.New("node meta not found for question answer") + } + + var ( + m model.BaseChatModel + err error + ) + if c.LLMParams != nil { + m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams) + if err != nil { + return nil, err + } + } + + return &QuestionAnswer{ + model: m, + nodeMeta: *nodeMeta, + questionTpl: c.QuestionTpl, + answerType: c.AnswerType, + choiceType: c.ChoiceType, + fixedChoices: c.FixedChoices, + needExtractFromAnswer: c.ExtractFromAnswer, + additionalSystemPromptTpl: c.AdditionalSystemPromptTpl, + maxAnswerCount: c.MaxAnswerCount, + nodeKey: ns.Key, + outputFields: ns.OutputTypes, + }, nil +} + +func (c *Config) BuildBranch(_ context.Context) ( + func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) { + if c.AnswerType != AnswerByChoices { + return nil, false + } + + return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) { + optionID, ok := nodeOutput[OptionIDKey] + if !ok { + return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput) + } + + if c.ChoiceType == DynamicChoices { + if optionID.(string) == "other" { + return -1, true, nil + } else { + return 0, false, nil + } + } + + if optionID.(string) == "other" { + return -1, true, nil + } + + optionIDInt, ok := AlphabetToInt(optionID.(string)) + if !ok { + return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID) + } + + return optionIDInt, false, nil + }, true +} + +func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) { + if n.Data.Inputs.QA.AnswerType != vo.QAAnswerTypeOption { + return expects + } + + if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic { + for index := range n.Data.Inputs.QA.Options { + expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index)) + } + + expects = append(expects, schema2.PortDefault) + return expects + } + + if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic { + expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0)) + expects = append(expects, schema2.PortDefault) + } + + return expects +} + +func (c *Config) RequireCheckpoint() bool { + return true } type AnswerType string @@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!` ) -func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) { - if conf == nil { - return nil, errors.New("config is nil") - } - - if conf.AnswerType == AnswerDirectly { - if conf.ExtractFromAnswer { - if conf.Model == nil { - return nil, errors.New("model is required when extract from answer") - } - if len(conf.OutputFields) == 0 { - return nil, errors.New("output fields is required when extract from answer") - } - } - } else if conf.AnswerType == AnswerByChoices { - if conf.ChoiceType == FixedChoices { - if len(conf.FixedChoices) == 0 { - return nil, errors.New("fixed choices is required when extract from answer") - } - } - } else { - return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType) - } - - nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer) - if nodeMeta == nil { - return nil, errors.New("node meta not found for question answer") - } - - return &QuestionAnswer{ - config: conf, - nodeMeta: *nodeMeta, - }, nil -} - type Question struct { Question string Choices []string @@ -182,10 +397,10 @@ type message struct { ID string `json:"id,omitempty"` } -// Execute formats the question (optionally with choices), interrupts, then extracts the answer. +// Invoke formats the question (optionally with choices), interrupts, then extracts the answer. // input: the references by input fields, as well as the dynamic choices array if needed. // output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices. -func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) { +func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) { var ( questions []*Question answers []string @@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma out[QuestionsKey] = questions out[AnswersKey] = answers - switch q.config.AnswerType { + switch q.answerType { case AnswerDirectly: if isFirst { // first execution, ask the question // format the question. Which is common to all use cases - firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in) + firstQuestion, err := nodes.TemplateRender(q.questionTpl, in) if err != nil { return nil, err } @@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil) } - if q.config.ExtractFromAnswer { + if q.needExtractFromAnswer { return q.extractFromAnswer(ctx, in, questions, answers) } @@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma } // format the question. Which is common to all use cases - firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in) + firstQuestion, err := nodes.TemplateRender(q.questionTpl, in) if err != nil { return nil, err } var formattedChoices []string - switch q.config.ChoiceType { + switch q.choiceType { case FixedChoices: - for _, choice := range q.config.FixedChoices { + for _, choice := range q.fixedChoices { formattedChoice, err := nodes.TemplateRender(choice, in) if err != nil { return nil, err @@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma formattedChoices = append(formattedChoices, c) } default: - return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType) + return nil, fmt.Errorf("unknown choice type: %s", q.choiceType) } return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil) default: - return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType) + return nil, fmt.Errorf("unknown answer type: %s", q.answerType) } } func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) { fieldInfo := "FieldInfo" - s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo) + s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo) if err != nil { return nil, err } @@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an sysPrompt := fmt.Sprintf(extractSystemPrompt, s) var requiredFields []string - for fName, tInfo := range q.config.OutputFields { + for fName, tInfo := range q.outputFields { if tInfo.Required { requiredFields = append(requiredFields, fName) } } var formattedAdditionalPrompt string - if len(q.config.AdditionalSystemPromptTpl) > 0 { - additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in) + if len(q.additionalSystemPromptTpl) > 0 { + additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in) if err != nil { return nil, err } @@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an messages = append(messages, schema.UserMessage(answer)) } - out, err := q.config.Model.Generate(ctx, messages) + out, err := q.model.Generate(ctx, messages) if err != nil { return nil, err } @@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an if ok { nextQuestionStr, ok := nextQuestion.(string) if ok && len(nextQuestionStr) > 0 { - if len(answers) >= q.config.MaxAnswerCount { - return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount) + if len(answers) >= q.maxAnswerCount { + return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount) } return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers) @@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an return nil, fmt.Errorf("field %s not found", fieldInfo) } - realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck()) + realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck()) if err != nil { return nil, err } @@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice schema.UserMessage(answer), } - out, err := q.config.Model.Generate(ctx, messages) + out, err := q.model.Generate(ctx, messages) if err != nil { return -1, err } @@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi event := &entity.InterruptEvent{ ID: eventID, - NodeKey: q.config.NodeKey, + NodeKey: q.nodeKey, NodeType: entity.NodeTypeQuestionAnswer, NodeTitle: q.nodeMeta.Name, NodeIcon: q.nodeMeta.IconURL, @@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi } _ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error { - setter.AddQuestion(q.config.NodeKey, &Question{ + setter.AddQuestion(q.nodeKey, &Question{ Question: newQuestion, Choices: choices, }) @@ -495,14 +710,14 @@ func intToAlphabet(num int) string { return "" } -func AlphabetToInt(str string) (int, bool) { +func AlphabetToInt(str string) (int64, bool) { if len(str) != 1 { return 0, false } char := rune(str[0]) char = unicode.ToUpper(char) if char >= 'A' && char <= 'Z' { - return int(char - 'A'), true + return int64(char - 'A'), true } return 0, false } @@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers [] for i := 0; i < len(oldQuestions); i++ { oldQuestion := oldQuestions[i] oldAnswer := oldAnswers[i] - contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text") + contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text") questionMsg := &message{ Type: "question", ContentType: contentType, - ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2), + ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2), } - if q.config.AnswerType == AnswerByChoices { + if q.answerType == AnswerByChoices { questionMsg.Content = optionContent{ Options: conv(oldQuestion.Choices), Question: oldQuestion.Question, @@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers [] Type: "answer", ContentType: contentType, Content: oldAnswer, - ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1), + ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1), } history = append(history, questionMsg, answerMsg) } if newQuestion != nil { - if q.config.AnswerType == AnswerByChoices { + if q.answerType == AnswerByChoices { history = append(history, &message{ Type: "question", ContentType: "option", @@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers [] Options: conv(choices), Question: *newQuestion, }, - ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2), + ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2), }) } else { history = append(history, &message{ Type: "question", ContentType: "text", Content: *newQuestion, - ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2), + ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2), }) } } diff --git a/backend/domain/workflow/internal/nodes/receiver/input_receiver.go b/backend/domain/workflow/internal/nodes/receiver/input_receiver.go index 4c7a0299..6d1b55a2 100644 --- a/backend/domain/workflow/internal/nodes/receiver/input_receiver.go +++ b/backend/domain/workflow/internal/nodes/receiver/input_receiver.go @@ -27,8 +27,10 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/logs" @@ -37,19 +39,27 @@ import ( ) type Config struct { - OutputTypes map[string]*vo.TypeInfo - NodeKey vo.NodeKey OutputSchema string } -type InputReceiver struct { - outputTypes map[string]*vo.TypeInfo - interruptData string - nodeKey vo.NodeKey - nodeMeta entity.NodeTypeMeta +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + c.OutputSchema = n.Data.Inputs.OutputSchema + + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeInputReceiver, + Name: n.Data.Meta.Title, + Configs: c, + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil } -func New(_ context.Context, cfg *Config) (*InputReceiver, error) { +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver) if nodeMeta == nil { return nil, errors.New("node meta not found for input receiver") @@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) { interruptData := map[string]string{ "content_type": "form_schema", - "content": cfg.OutputSchema, + "content": c.OutputSchema, } interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys @@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) { } return &InputReceiver{ - outputTypes: cfg.OutputTypes, + outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution nodeMeta: *nodeMeta, - nodeKey: cfg.NodeKey, + nodeKey: ns.Key, interruptData: interruptDataStr, }, nil } +func (c *Config) RequireCheckpoint() bool { + return true +} + +type InputReceiver struct { + outputTypes map[string]*vo.TypeInfo + interruptData string + nodeKey vo.NodeKey + nodeMeta entity.NodeTypeMeta +} + const ( ReceivedDataKey = "$received_data" receiverWarningKey = "receiver_warning_%d_%s" diff --git a/backend/domain/workflow/internal/nodes/selector/callbacks.go b/backend/domain/workflow/internal/nodes/selector/callbacks.go new file mode 100644 index 00000000..3e390a5f --- /dev/null +++ b/backend/domain/workflow/internal/nodes/selector/callbacks.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package selector + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" +) + +type selectorCallbackField struct { + Key string `json:"key"` + Type vo.DataType `json:"type"` + Value any `json:"value"` +} + +type selectorCondition struct { + Left selectorCallbackField `json:"left"` + Operator vo.OperatorType `json:"operator"` + Right *selectorCallbackField `json:"right,omitempty"` +} + +type selectorBranch struct { + Conditions []*selectorCondition `json:"conditions"` + Logic vo.LogicType `json:"logic"` + Name string `json:"name"` +} + +func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { + count := len(s.clauses) + + output := make([]*selectorBranch, count) + + for _, source := range s.ns.InputSources { + targetPath := source.Path + if len(targetPath) == 2 { + indexStr := targetPath[0] + index, err := strconv.Atoi(indexStr) + if err != nil { + return nil, err + } + + branch := output[index] + if branch == nil { + output[index] = &selectorBranch{ + Conditions: []*selectorCondition{ + { + Operator: s.clauses[index].Single.ToCanvasOperatorType(), + }, + }, + Logic: ClauseRelationAND.ToVOLogicType(), + } + } + + if targetPath[1] == LeftKey { + leftV, ok := nodes.TakeMapValue(in, targetPath) + if !ok { + return nil, fmt.Errorf("failed to take left value of %s", targetPath) + } + if source.Source.Ref.VariableType != nil { + if *source.Source.Ref.VariableType == vo.ParentIntermediate { + parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key] + if !ok { + return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key) + } + parentNode := s.ws.GetNode(parentNodeKey) + output[index].Conditions[0].Left = selectorCallbackField{ + Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."), + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, + Value: leftV, + } + } else { + output[index].Conditions[0].Left = selectorCallbackField{ + Key: "", + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, + Value: leftV, + } + } + } else { + output[index].Conditions[0].Left = selectorCallbackField{ + Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."), + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, + Value: leftV, + } + } + } else if targetPath[1] == RightKey { + rightV, ok := nodes.TakeMapValue(in, targetPath) + if !ok { + return nil, fmt.Errorf("failed to take right value of %s", targetPath) + } + output[index].Conditions[0].Right = &selectorCallbackField{ + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type, + Value: rightV, + } + } + } else if len(targetPath) == 3 { + indexStr := targetPath[0] + index, err := strconv.Atoi(indexStr) + if err != nil { + return nil, err + } + + multi := s.clauses[index].Multi + + branch := output[index] + if branch == nil { + output[index] = &selectorBranch{ + Conditions: make([]*selectorCondition, len(multi.Clauses)), + Logic: multi.Relation.ToVOLogicType(), + } + } + + clauseIndexStr := targetPath[1] + clauseIndex, err := strconv.Atoi(clauseIndexStr) + if err != nil { + return nil, err + } + + clause := multi.Clauses[clauseIndex] + + if output[index].Conditions[clauseIndex] == nil { + output[index].Conditions[clauseIndex] = &selectorCondition{ + Operator: clause.ToCanvasOperatorType(), + } + } + + if targetPath[2] == LeftKey { + leftV, ok := nodes.TakeMapValue(in, targetPath) + if !ok { + return nil, fmt.Errorf("failed to take left value of %s", targetPath) + } + if source.Source.Ref.VariableType != nil { + if *source.Source.Ref.VariableType == vo.ParentIntermediate { + parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key] + if !ok { + return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key) + } + parentNode := s.ws.GetNode(parentNodeKey) + output[index].Conditions[clauseIndex].Left = selectorCallbackField{ + Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."), + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, + Value: leftV, + } + } else { + output[index].Conditions[clauseIndex].Left = selectorCallbackField{ + Key: "", + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, + Value: leftV, + } + } + } else { + output[index].Conditions[clauseIndex].Left = selectorCallbackField{ + Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."), + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, + Value: leftV, + } + } + } else if targetPath[2] == RightKey { + rightV, ok := nodes.TakeMapValue(in, targetPath) + if !ok { + return nil, fmt.Errorf("failed to take right value of %s", targetPath) + } + output[index].Conditions[clauseIndex].Right = &selectorCallbackField{ + Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type, + Value: rightV, + } + } + } + } + + return map[string]any{"branches": output}, nil +} diff --git a/backend/domain/workflow/internal/nodes/selector/operator.go b/backend/domain/workflow/internal/nodes/selector/operator.go index afdbdcad..ae0e1bfc 100644 --- a/backend/domain/workflow/internal/nodes/selector/operator.go +++ b/backend/domain/workflow/internal/nodes/selector/operator.go @@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType { panic(fmt.Sprintf("unknown operator: %+v", o)) } } + +func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) { + switch o { + case vo.Equal: + return OperatorEqual, nil + case vo.NotEqual: + return OperatorNotEqual, nil + case vo.LengthGreaterThan: + return OperatorLengthGreater, nil + case vo.LengthGreaterThanEqual: + return OperatorLengthGreaterOrEqual, nil + case vo.LengthLessThan: + return OperatorLengthLesser, nil + case vo.LengthLessThanEqual: + return OperatorLengthLesserOrEqual, nil + case vo.Contain: + if leftType.Type == vo.DataTypeObject { + return OperatorContainKey, nil + } + return OperatorContain, nil + case vo.NotContain: + if leftType.Type == vo.DataTypeObject { + return OperatorNotContainKey, nil + } + return OperatorNotContain, nil + case vo.Empty: + return OperatorEmpty, nil + case vo.NotEmpty: + return OperatorNotEmpty, nil + case vo.True: + return OperatorIsTrue, nil + case vo.False: + return OperatorIsFalse, nil + case vo.GreaterThan: + return OperatorGreater, nil + case vo.GreaterThanEqual: + return OperatorGreaterOrEqual, nil + case vo.LessThan: + return OperatorLesser, nil + case vo.LessThanEqual: + return OperatorLesserOrEqual, nil + default: + return "", fmt.Errorf("unsupported operator type: %d", o) + } +} diff --git a/backend/domain/workflow/internal/nodes/selector/schema.go b/backend/domain/workflow/internal/nodes/selector/schema.go index e536589b..82c12142 100644 --- a/backend/domain/workflow/internal/nodes/selector/schema.go +++ b/backend/domain/workflow/internal/nodes/selector/schema.go @@ -17,9 +17,16 @@ package selector import ( + "context" "fmt" + einoCompose "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type ClauseRelation string @@ -29,10 +36,6 @@ const ( ClauseRelationOR ClauseRelation = "or" ) -type Config struct { - Clauses []*OneClauseSchema `json:"clauses"` -} - type OneClauseSchema struct { Single *Operator `json:"single,omitempty"` Multi *MultiClauseSchema `json:"multi,omitempty"` @@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType { panic(fmt.Sprintf("unknown clause relation: %s", c)) } + +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) { + clauses := make([]*OneClauseSchema, 0) + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Name: n.Data.Meta.Title, + Type: entity.NodeTypeSelector, + Configs: c, + } + + for i, branchCond := range n.Data.Inputs.Branches { + inputType := &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{}, + } + + if len(branchCond.Condition.Conditions) == 1 { // single condition + cond := branchCond.Condition.Conditions[0] + + left := cond.Left + if left == nil { + return nil, fmt.Errorf("operator left is nil") + } + + leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input) + if err != nil { + return nil, err + } + + leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent()) + if err != nil { + return nil, err + } + + inputType.Properties[LeftKey] = leftType + + ns.AddInputSource(leftSources...) + + op, err := ToSelectorOperator(cond.Operator, leftType) + if err != nil { + return nil, err + } + + if cond.Right != nil { + rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input) + if err != nil { + return nil, err + } + + rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent()) + if err != nil { + return nil, err + } + + inputType.Properties[RightKey] = rightType + ns.AddInputSource(rightSources...) + } + + ns.SetInputType(fmt.Sprintf("%d", i), inputType) + + clauses = append(clauses, &OneClauseSchema{ + Single: &op, + }) + + continue + } + + var relation ClauseRelation + logic := branchCond.Condition.Logic + if logic == vo.OR { + relation = ClauseRelationOR + } else if logic == vo.AND { + relation = ClauseRelationAND + } + + var ops []*Operator + for j, cond := range branchCond.Condition.Conditions { + left := cond.Left + if left == nil { + return nil, fmt.Errorf("operator left is nil") + } + + leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input) + if err != nil { + return nil, err + } + + leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent()) + if err != nil { + return nil, err + } + + inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: map[string]*vo.TypeInfo{ + LeftKey: leftType, + }, + } + + ns.AddInputSource(leftSources...) + + op, err := ToSelectorOperator(cond.Operator, leftType) + if err != nil { + return nil, err + } + ops = append(ops, &op) + + if cond.Right != nil { + rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input) + if err != nil { + return nil, err + } + + rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent()) + if err != nil { + return nil, err + } + + inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType + ns.AddInputSource(rightSources...) + } + } + + ns.SetInputType(fmt.Sprintf("%d", i), inputType) + + clauses = append(clauses, &OneClauseSchema{ + Multi: &MultiClauseSchema{ + Clauses: ops, + Relation: relation, + }, + }) + } + + c.Clauses = clauses + + return ns, nil +} diff --git a/backend/domain/workflow/internal/nodes/selector/selector.go b/backend/domain/workflow/internal/nodes/selector/selector.go index 6b0e936c..38c0a30b 100644 --- a/backend/domain/workflow/internal/nodes/selector/selector.go +++ b/backend/domain/workflow/internal/nodes/selector/selector.go @@ -23,23 +23,32 @@ import ( "github.com/cloudwego/eino/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type Selector struct { - config *Config + clauses []*OneClauseSchema + ns *schema.NodeSchema + ws *schema.WorkflowSchema } -func NewSelector(_ context.Context, config *Config) (*Selector, error) { - if config == nil { - return nil, fmt.Errorf("config is nil") +type Config struct { + Clauses []*OneClauseSchema `json:"clauses"` +} + +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) { + ws := schema.GetBuildOptions(opts...).WS + if ws == nil { + return nil, fmt.Errorf("workflow schema is required") } - if len(config.Clauses) == 0 { + if len(c.Clauses) == 0 { return nil, fmt.Errorf("config clauses are empty") } - for _, clause := range config.Clauses { + for _, clause := range c.Clauses { if clause.Single == nil && clause.Multi == nil { return nil, fmt.Errorf("single clause and multi clause are both nil") } @@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) { } return &Selector{ - config: config, + clauses: c.Clauses, + ns: ns, + ws: ws, }, nil } +func (c *Config) BuildBranch(_ context.Context) ( + func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) { + return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) { + choice := nodeOutput[SelectKey].(int64) + if choice < 0 || choice > int64(len(c.Clauses)+1) { + return -1, false, fmt.Errorf("selector choice out of range: %d", choice) + } + + if choice == int64(len(c.Clauses)) { // default + return -1, true, nil + } + + return choice, false, nil + }, true +} + +func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string { + expects := make([]string, len(n.Data.Inputs.Branches)+1) + expects[0] = "false" // default branch + if len(n.Data.Inputs.Branches) > 0 { + expects[1] = "true" // first condition + } + + for i := 1; i < len(n.Data.Inputs.Branches); i++ { // other conditions + expects[i+1] = "true_" + strconv.Itoa(i) + } + + return expects +} + type Operants struct { Left any Right any @@ -76,14 +117,14 @@ const ( SelectKey = "selected" ) -func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) { - in, err := s.SelectorInputConverter(input) +func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) { + in, err := s.selectorInputConverter(input) if err != nil { return nil, err } - predicates := make([]Predicate, 0, len(s.config.Clauses)) - for i, oneConf := range s.config.Clauses { + predicates := make([]Predicate, 0, len(s.clauses)) + for i, oneConf := range s.clauses { if oneConf.Single != nil { left := in[i].Left right := in[i].Right @@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri } if isTrue { - return map[string]any{SelectKey: i}, nil + return map[string]any{SelectKey: int64(i)}, nil } } - return map[string]any{SelectKey: len(in)}, nil // default choice + return map[string]any{SelectKey: int64(len(in))}, nil // default choice } -func (s *Selector) GetType() string { - return "Selector" -} - -func (s *Selector) ConditionCount() int { - return len(s.config.Clauses) -} - -func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) { - conf := s.config.Clauses +func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) { + conf := s.clauses for i, oneConf := range conf { if oneConf.Single != nil { @@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er } func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { - count := len(s.config.Clauses) - out := output[SelectKey].(int) + count := int64(len(s.clauses)) + out := output[SelectKey].(int64) if out == count { cOutput := map[string]any{"result": "pass to else branch"} return &nodes.StructuredCallbackOutput{ diff --git a/backend/domain/workflow/internal/nodes/stream.go b/backend/domain/workflow/internal/nodes/stream.go index fa219493..bd7e9754 100644 --- a/backend/domain/workflow/internal/nodes/stream.go +++ b/backend/domain/workflow/internal/nodes/stream.go @@ -22,57 +22,27 @@ import ( "github.com/cloudwego/eino/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) var KeyIsFinished = "\x1FKey is finished\x1F" -type Mode string - -const ( - Streaming Mode = "streaming" - NonStreaming Mode = "non-streaming" -) - -type FieldStreamType string - -const ( - FieldIsStream FieldStreamType = "yes" // absolutely a stream - FieldNotStream FieldStreamType = "no" // absolutely not a stream - FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution - FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped -) - -// SourceInfo contains stream type for a input field source of a node. -type SourceInfo struct { - // IsIntermediate means this field is itself not a field source, but a map containing one or more field sources. - IsIntermediate bool - // FieldType the stream type of the field. May require request-time resolution in addition to compile-time. - FieldType FieldStreamType - // FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable. - FromNodeKey vo.NodeKey - // FromPath is the path of this field source within the source node. empty if the field is a static value or variable. - FromPath compose.FieldPath - TypeInfo *vo.TypeInfo - // SubSources are SourceInfo for keys within this intermediate Map(Object) field. - SubSources map[string]*SourceInfo -} - type DynamicStreamContainer interface { SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int - GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error) - GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error) + GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error) + GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error) } // ResolveStreamSources resolves incoming field sources for a node, deciding their stream type. -func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) { - resolved := make(map[string]*SourceInfo, len(sources)) +func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) { + resolved := make(map[string]*schema.SourceInfo, len(sources)) nodeKey2Skipped := make(map[vo.NodeKey]bool) - var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error) - resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) { - resolvedNode := &SourceInfo{ + var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) + resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) { + resolvedNode := &schema.SourceInfo{ IsIntermediate: sInfo.IsIntermediate, FieldType: sInfo.FieldType, FromNodeKey: sInfo.FromNodeKey, @@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) ( } if len(sInfo.SubSources) > 0 { - resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources)) + resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources)) for k, subInfo := range sInfo.SubSources { resolvedSub, err := resolver(k, subInfo) @@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) ( } if skipped { - resolvedNode.FieldType = FieldSkipped + resolvedNode.FieldType = schema.FieldSkipped return resolvedNode, nil } - if sInfo.FieldType == FieldMaybeStream { + if sInfo.FieldType == schema.FieldMaybeStream { if len(sInfo.SubSources) > 0 { panic("a maybe stream field should not have sub sources") } - var streamType FieldStreamType + var streamType schema.FieldStreamType err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error { var e error streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0]) @@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) ( return nil, err } - return &SourceInfo{ + return &schema.SourceInfo{ IsIntermediate: sInfo.IsIntermediate, FieldType: streamType, FromNodeKey: sInfo.FromNodeKey, @@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface { NodeExecuted(key vo.NodeKey) bool } -func (s *SourceInfo) Skipped() bool { - if !s.IsIntermediate { - return s.FieldType == FieldSkipped +func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath, + sc *schema.WorkflowSchema) (schema.FieldStreamType, error) { + sg, ok := s.Configs.(StreamGenerator) + if !ok { + return schema.FieldNotStream, nil } - for _, sub := range s.SubSources { - if !sub.Skipped() { - return false - } - } - - return true -} - -func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool { - if !s.IsIntermediate { - return s.FromNodeKey == nodeKey - } - - for _, sub := range s.SubSources { - if sub.FromNode(nodeKey) { - return true - } - } - - return false + return sg.FieldStreamType(path, s, sc) } diff --git a/backend/domain/workflow/internal/nodes/subworkflow/sub_workflow.go b/backend/domain/workflow/internal/nodes/subworkflow/sub_workflow.go index 8ca78057..f761e0d7 100644 --- a/backend/domain/workflow/internal/nodes/subworkflow/sub_workflow.go +++ b/backend/domain/workflow/internal/nodes/subworkflow/sub_workflow.go @@ -18,7 +18,6 @@ package subworkflow import ( "context" - "errors" "fmt" "strconv" @@ -29,35 +28,56 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) type Config struct { - Runner compose.Runnable[map[string]any, map[string]any] + WorkflowID int64 + WorkflowVersion string +} + +func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema, + sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) { + if !sc.RequireStreaming() { + return schema2.FieldNotStream, nil + } + + innerWF := ns.SubWorkflowSchema + + if !innerWF.RequireStreaming() { + return schema2.FieldNotStream, nil + } + + innerExit := innerWF.GetNode(entity.ExitNodeKey) + + if innerExit.Configs.(*exit.Config).TerminatePlan == vo.ReturnVariables { + return schema2.FieldNotStream, nil + } + + if !innerExit.StreamConfigs.RequireStreamingInput { + return schema2.FieldNotStream, nil + } + + if len(path) > 1 || path[0] != "output" { + return schema2.FieldNotStream, fmt.Errorf( + "streaming answering sub-workflow node can only have out field 'output'") + } + + return schema2.FieldIsStream, nil } type SubWorkflow struct { - cfg *Config + Runner compose.Runnable[map[string]any, map[string]any] } -func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) { - if cfg == nil { - return nil, errors.New("config is nil") - } - - if cfg.Runner == nil { - return nil, errors.New("runnable is nil") - } - - return &SubWorkflow{cfg: cfg}, nil -} - -func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) { +func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) { nestedOpts, nodeKey, err := prepareOptions(ctx, opts...) if err != nil { return nil, err } - out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...) + out, err := s.Runner.Invoke(ctx, in, nestedOpts...) if err != nil { interruptInfo, ok := compose.ExtractInterruptInfo(err) if !ok { @@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod return out, nil } -func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) { +func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) { nestedOpts, nodeKey, err := prepareOptions(ctx, opts...) if err != nil { return nil, err } - out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...) + out, err := s.Runner.Stream(ctx, in, nestedOpts...) if err != nil { interruptInfo, ok := compose.ExtractInterruptInfo(err) if !ok { @@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod return out, nil } -func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) { - options := &nodes.NestedWorkflowOptions{} - for _, opt := range opts { - opt(options) - } +func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) { + options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...) nestedOpts := options.GetOptsForNested() diff --git a/backend/domain/workflow/internal/nodes/template.go b/backend/domain/workflow/internal/nodes/template.go index e8e6e2da..3e8027bb 100644 --- a/backend/domain/workflow/internal/nodes/template.go +++ b/backend/domain/workflow/internal/nodes/template.go @@ -30,6 +30,7 @@ import ( "github.com/bytedance/sonic/ast" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -156,7 +157,7 @@ func removeSlice(s string) string { type renderOptions struct { type2CustomRenderer map[reflect.Type]func(any) (string, error) - reservedKey map[string]struct{} + reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping nilRenderer func() (string, error) } @@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) { } } -func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) { +func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) { if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow return false, false } @@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped } if !matchingSource.IsIntermediate { - return matchingSource.FieldType == FieldSkipped, false + return matchingSource.FieldType == schema.FieldSkipped, false } for _, subPath := range tp.SubPathsBeforeSlice { @@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it return false, true } - return matchingSource.FieldType == FieldSkipped, false + return matchingSource.FieldType == schema.FieldSkipped, false } matchingSource = subSource } if !matchingSource.IsIntermediate { - return matchingSource.FieldType == FieldSkipped, false + return matchingSource.FieldType == schema.FieldSkipped, false } - var checkSourceSkipped func(sInfo *SourceInfo) bool - checkSourceSkipped = func(sInfo *SourceInfo) bool { + var checkSourceSkipped func(sInfo *schema.SourceInfo) bool + checkSourceSkipped = func(sInfo *schema.SourceInfo) bool { if !sInfo.IsIntermediate { - return sInfo.FieldType == FieldSkipped + return sInfo.FieldType == schema.FieldSkipped } for _, subSource := range sInfo.SubSources { if !checkSourceSkipped(subSource) { @@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo { return currentType } -func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) { +func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) { mi, err := sonic.Marshal(input) if err != nil { return "", err diff --git a/backend/domain/workflow/internal/nodes/textprocessor/text_processor.go b/backend/domain/workflow/internal/nodes/textprocessor/text_processor.go index e2a7170c..caaf761e 100644 --- a/backend/domain/workflow/internal/nodes/textprocessor/text_processor.go +++ b/backend/domain/workflow/internal/nodes/textprocessor/text_processor.go @@ -22,7 +22,11 @@ import ( "reflect" "strings" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) @@ -34,42 +38,92 @@ const ( ) type Config struct { - Type Type `json:"type"` - Tpl string `json:"tpl"` - ConcatChar string `json:"concatChar"` - Separators []string `json:"separator"` - FullSources map[string]*nodes.SourceInfo `json:"fullSources"` + Type Type `json:"type"` + Tpl string `json:"tpl"` + ConcatChar string `json:"concatChar"` + Separators []string `json:"separator"` } -type TextProcessor struct { - config *Config -} - -func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) { - if cfg == nil { - return nil, fmt.Errorf("config requried") +func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeTextProcessor, + Name: n.Data.Meta.Title, + Configs: c, } - if cfg.Type == ConcatText && len(cfg.Tpl) == 0 { + + if n.Data.Inputs.Method == vo.Concat { + c.Type = ConcatText + params := n.Data.Inputs.ConcatParams + for _, param := range params { + if param.Name == "concatResult" { + c.Tpl = param.Input.Value.Content.(string) + } else if param.Name == "arrayItemConcatChar" { + c.ConcatChar = param.Input.Value.Content.(string) + } + } + } else if n.Data.Inputs.Method == vo.Split { + c.Type = SplitText + params := n.Data.Inputs.SplitParams + separators := make([]string, 0, len(params)) + for _, param := range params { + if param.Name == "delimiters" { + delimiters := param.Input.Value.Content.([]any) + for _, d := range delimiters { + separators = append(separators, d.(string)) + } + } + } + c.Separators = separators + + } else { + return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method) + } + + if err := convert.SetInputsForNodeSchema(n, ns); err != nil { + return nil, err + } + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + + return ns, nil +} + +func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + if c.Type == ConcatText && len(c.Tpl) == 0 { return nil, fmt.Errorf("config tpl requried") } return &TextProcessor{ - config: cfg, + typ: c.Type, + tpl: c.Tpl, + concatChar: c.ConcatChar, + separators: c.Separators, + fullSources: ns.FullSources, }, nil +} +type TextProcessor struct { + typ Type + tpl string + concatChar string + separators []string + fullSources map[string]*schema.SourceInfo } const OutputKey = "output" func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { - switch t.config.Type { + switch t.typ { case ConcatText: arrayRenderer := func(i any) (string, error) { vs := i.([]any) - return join(vs, t.config.ConcatChar) + return join(vs, t.concatChar) } - result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources, + result, err := nodes.Render(ctx, t.tpl, input, t.fullSources, nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer)) if err != nil { return nil, err @@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s if !ok { return nil, fmt.Errorf("input string field must string type but got %T", valueString) } - values := strings.Split(valueString, t.config.Separators[0]) + values := strings.Split(valueString, t.separators[0]) // Iterate over each delimiter - for _, sep := range t.config.Separators[1:] { + for _, sep := range t.separators[1:] { var tempParts []string for _, part := range values { tempParts = append(tempParts, strings.Split(part, sep)...) @@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s return map[string]any{OutputKey: anyValues}, nil default: - return nil, fmt.Errorf("not support type %s", t.config.Type) + return nil, fmt.Errorf("not support type %s", t.typ) } } diff --git a/backend/domain/workflow/internal/nodes/textprocessor/text_processor_test.go b/backend/domain/workflow/internal/nodes/textprocessor/text_processor_test.go index d7f94f58..ea79458d 100644 --- a/backend/domain/workflow/internal/nodes/textprocessor/text_processor_test.go +++ b/backend/domain/workflow/internal/nodes/textprocessor/text_processor_test.go @@ -21,6 +21,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) func TestNewTextProcessorNodeGenerator(t *testing.T) { @@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) { Type: SplitText, Separators: []string{",", "|", "."}, } - p, err := NewTextProcessor(ctx, cfg) + p, err := cfg.Build(ctx, &schema2.NodeSchema{}) assert.NoError(t, err) - result, err := p.Invoke(ctx, map[string]any{ + result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{ "String": "a,b|c.d,e|f|g", }) @@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) { ConcatChar: `\t`, Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}", } - p, err := NewTextProcessor(context.Background(), cfg) + p, err := cfg.Build(context.Background(), &schema2.NodeSchema{}) - result, err := p.Invoke(ctx, in) + result, err := p.(*TextProcessor).Invoke(ctx, in) assert.NoError(t, err) assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`) }) diff --git a/backend/domain/workflow/internal/nodes/variableaggregator/variable_aggregator.go b/backend/domain/workflow/internal/nodes/variableaggregator/variable_aggregator.go index 874f2638..964cd568 100644 --- a/backend/domain/workflow/internal/nodes/variableaggregator/variable_aggregator.go +++ b/backend/domain/workflow/internal/nodes/variableaggregator/variable_aggregator.go @@ -32,8 +32,11 @@ import ( "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/safego" @@ -48,24 +51,147 @@ const ( type Config struct { MergeStrategy MergeStrategy GroupLen map[string]int - FullSources map[string]*nodes.SourceInfo - NodeKey vo.NodeKey - InputSources []*vo.FieldInfo GroupOrder []string // the order the groups are declared in frontend canvas } -type VariableAggregator struct { - config *Config +func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) { + ns := &schema2.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeVariableAggregator, + Name: n.Data.Meta.Title, + Configs: c, + } + + c.MergeStrategy = FirstNotNullValue + inputs := n.Data.Inputs + + groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups)) + for i := range inputs.VariableAggregator.MergeGroups { + group := inputs.VariableAggregator.MergeGroups[i] + tInfo := &vo.TypeInfo{ + Type: vo.DataTypeObject, + Properties: make(map[string]*vo.TypeInfo), + } + ns.SetInputType(group.Name, tInfo) + for ii, v := range group.Variables { + name := strconv.Itoa(ii) + valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v) + if err != nil { + return nil, err + } + tInfo.Properties[name] = valueTypeInfo + sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent()) + if err != nil { + return nil, err + } + ns.AddInputSource(sources...) + } + + length := len(group.Variables) + groupToLen[group.Name] = length + } + + groupOrder := make([]string, 0, len(groupToLen)) + for i := range inputs.VariableAggregator.MergeGroups { + group := inputs.VariableAggregator.MergeGroups[i] + groupOrder = append(groupOrder, group.Name) + } + + c.GroupLen = groupToLen + c.GroupOrder = groupOrder + + if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil { + return nil, err + } + return ns, nil } -func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) { - if cfg == nil { - return nil, errors.New("config is required") +func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) { + if c.MergeStrategy != FirstNotNullValue { + return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy) } - if cfg.MergeStrategy != FirstNotNullValue { - return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy) + + return &VariableAggregator{ + groupLen: c.GroupLen, + fullSources: ns.FullSources, + nodeKey: ns.Key, + }, nil +} + +func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema, + sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) { + if !sc.RequireStreaming() { + return schema2.FieldNotStream, nil } - return &VariableAggregator{config: cfg}, nil + + if len(path) == 2 { // asking about a specific index within a group + for _, fInfo := range ns.InputSources { + if len(fInfo.Path) == len(path) { + equal := true + for i := range fInfo.Path { + if fInfo.Path[i] != path[i] { + equal = false + break + } + } + + if equal { + if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" { + return schema2.FieldNotStream, nil // variables or static values + } + fromNodeKey := fInfo.Source.Ref.FromNodeKey + fromNode := sc.GetNode(fromNodeKey) + if fromNode == nil { + return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey) + } + return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc) + } + } + } + } else if len(path) == 1 { // asking about the entire group + var streamCount, notStreamCount int + for _, fInfo := range ns.InputSources { + if fInfo.Path[0] == path[0] { // belong to the group + if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 { + fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey) + if fromNode == nil { + return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey) + } + subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc) + if err != nil { + return schema2.FieldNotStream, err + } + + if subStreamType == schema2.FieldMaybeStream { + return schema2.FieldMaybeStream, nil + } else if subStreamType == schema2.FieldIsStream { + streamCount++ + } else { + notStreamCount++ + } + } + } + } + + if streamCount > 0 && notStreamCount == 0 { + return schema2.FieldIsStream, nil + } + + if streamCount == 0 && notStreamCount > 0 { + return schema2.FieldNotStream, nil + } + + return schema2.FieldMaybeStream, nil + } + + return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path) +} + +type VariableAggregator struct { + groupLen map[string]int + fullSources map[string]*schema2.SourceInfo + nodeKey vo.NodeKey + groupOrder []string // the order the groups are declared in frontend canvas } func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) { @@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) ( result := make(map[string]any) groupToChoice := make(map[string]int) - for group, length := range v.config.GroupLen { + for group, length := range v.groupLen { for i := 0; i < length; i++ { if value, ok := in[group][i]; ok { if value != nil { @@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) ( } _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { - state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) + state.SaveDynamicChoice(v.nodeKey, groupToChoice) return nil }) - ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream + ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream - groupChoices := make([]any, 0, len(v.config.GroupOrder)) - for _, group := range v.config.GroupOrder { + groupChoices := make([]any, 0, len(v.groupOrder)) + for _, group := range v.groupOrder { choice := groupToChoice[group] if choice == -1 { groupChoices = append(groupChoices, nil) @@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream _ *schema.StreamReader[map[string]any], err error) { inStream := streamInputConverter(input) - resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey) + resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey) if !ok { panic("unable to get resolvesSources from ctx cache.") } @@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream defer func() { if err == nil { - groupChoiceToStreamType := map[string]nodes.FieldStreamType{} + groupChoiceToStreamType := map[string]schema2.FieldStreamType{} for group, choice := range groupToChoice { if choice != -1 { item := groupToItems[group][choice] if _, ok := item.(stream); ok { - groupChoiceToStreamType[group] = nodes.FieldIsStream + groupChoiceToStreamType[group] = schema2.FieldIsStream } } } - groupChoices := make([]any, 0, len(v.config.GroupOrder)) - for _, group := range v.config.GroupOrder { + groupChoices := make([]any, 0, len(v.groupOrder)) + for _, group := range v.groupOrder { choice := groupToChoice[group] if choice == -1 { groupChoices = append(groupChoices, nil) @@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream // - if an element is not stream, actually receive from the stream to check if it's non-nil groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group - for group, length := range v.config.GroupLen { + for group, length := range v.groupLen { groupToItems[group] = make([]any, length) groupToCurrentIndex[group] = math.MaxInt for i := 0; i < length; i++ { fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType - if fType == nodes.FieldSkipped { + if fType == schema2.FieldSkipped { groupToItems[group][i] = skipped{} continue } - if fType == nodes.FieldIsStream { + if fType == schema2.FieldIsStream { groupToItems[group][i] = stream{} if ci, _ := groupToCurrentIndex[group]; i < ci { groupToCurrentIndex[group] = i @@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream } allDone := func() bool { - for group := range v.config.GroupLen { + for group := range v.groupLen { _, ok := groupToChoice[group] if !ok { return false @@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream alreadyDone := allDone() if alreadyDone { // all groups have made their choices, no need to actually read input streams - result := make(map[string]any, len(v.config.GroupLen)) + result := make(map[string]any, len(v.groupLen)) allSkip := true for group := range groupToChoice { choice := groupToChoice[group] @@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream if allSkip { // no need to convert input streams for the output, because all groups are skipped _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { - state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) + state.SaveDynamicChoice(v.nodeKey, groupToChoice) return nil }) return schema.StreamReaderFromArray([]map[string]any{result}), nil @@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream } _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { - state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) + state.SaveDynamicChoice(v.nodeKey, groupToChoice) return nil }) @@ -416,26 +542,12 @@ type vaCallbackInput struct { Variables []any `json:"variables"` } -func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) { - ctx = ctxcache.Init(ctx) - - resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources) - if err != nil { - return nil, err - } - - // need this info for callbacks.OnStart, so we put it in cache within Init() - ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources) - - return ctx, nil -} - type streamMarkerType string const streamMarker streamMarkerType = "" func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) { - resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey) + resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey) if !ok { panic("unable to get resolved_sources from ctx cache") } @@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri merged := make([]vaCallbackInput, 0, len(in)) - groupLen := v.config.GroupLen + groupLen := v.groupLen for groupName, vars := range in { orderedVars := make([]any, groupLen[groupName]) for index := range vars { orderedVars[index] = vars[index] if len(resolvedSources) > 0 { - if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream { + if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream { // replace the streams with streamMarker, // because we won't read, save to execution history, or display these streams to user orderedVars[index] = streamMarker @@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri } func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { - dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey) + dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey) if !ok { panic("unable to get dynamic stream types from ctx cache") } @@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st newOut := maps.Clone(output) for k := range output { - if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream { + if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream { newOut[k] = streamMarker } } @@ -594,3 +706,15 @@ func init() { nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs) nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers) } + +func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) { + resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources) + if err != nil { + return nil, err + } + + // need this info for callbacks.OnStart, so we put it in cache within Init() + ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources) + + return ctx, nil +} diff --git a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go index 06b00412..058ad6c9 100644 --- a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go +++ b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign.go @@ -25,29 +25,75 @@ import ( "github.com/cloudwego/eino/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/types/errno" ) type VariableAssigner struct { - config *Config + pairs []*Pair + handler *variable.Handler } type Config struct { - Pairs []*Pair - Handler *variable.Handler + Pairs []*Pair } -type Pair struct { - Left vo.Reference - Right compose.FieldPath +func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) { + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeVariableAssigner, + Name: n.Data.Meta.Title, + Configs: c, + } + + var pairs = make([]*Pair, 0, len(n.Data.Inputs.InputParameters)) + for i, param := range n.Data.Inputs.InputParameters { + if param.Left == nil || param.Input == nil { + return nil, fmt.Errorf("variable assigner node's param left or input is nil") + } + + leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent()) + if err != nil { + return nil, err + } + + if leftSources[0].Source.Ref == nil { + return nil, fmt.Errorf("variable assigner node's param left source ref is nil") + } + + if leftSources[0].Source.Ref.VariableType == nil { + return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil") + } + + if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem { + return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem") + } + + inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent()) + if err != nil { + return nil, err + } + ns.AddInputSource(inputSource...) + pair := &Pair{ + Left: *leftSources[0].Source.Ref, + Right: inputSource[0].Path, + } + pairs = append(pairs, pair) + } + + c.Pairs = pairs + + return ns, nil } -func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) { - for _, pair := range conf.Pairs { +func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { + for _, pair := range c.Pairs { if pair.Left.VariableType == nil { return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left) } @@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er } return &VariableAssigner{ - config: conf, + pairs: c.Pairs, + handler: variable.GetVariableHandler(), }, nil } -func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) { - for _, pair := range v.config.Pairs { +type Pair struct { + Left vo.Reference + Right compose.FieldPath +} + +func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) { + for _, pair := range v.pairs { right, ok := nodes.TakeMapValue(in, pair.Right) if !ok { return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, "."))) @@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s ConnectorUID: exeCfg.ConnectorUID, })) } - err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...) + err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...) if err != nil { return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err) } diff --git a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_in_loop.go b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_in_loop.go index df7a480d..3b09e031 100644 --- a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_in_loop.go +++ b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_in_loop.go @@ -20,25 +20,93 @@ import ( "context" "fmt" + einoCompose "github.com/cloudwego/eino/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" ) -type InLoop struct { - config *Config - intermediateVarStore variable.Store +type InLoopConfig struct { + Pairs []*Pair } -func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) { +func (i *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) { + if n.Parent() == nil { + return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID) + } + + ns := &schema.NodeSchema{ + Key: vo.NodeKey(n.ID), + Type: entity.NodeTypeVariableAssignerWithinLoop, + Name: n.Data.Meta.Title, + Configs: i, + } + + var pairs []*Pair + for i, param := range n.Data.Inputs.InputParameters { + if param.Left == nil || param.Right == nil { + return nil, fmt.Errorf("loop set variable node's param left or right is nil") + } + + leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent()) + if err != nil { + return nil, err + } + + if len(leftSources) != 1 { + return nil, fmt.Errorf("loop set variable node's param left is not a single source") + } + + if leftSources[0].Source.Ref == nil { + return nil, fmt.Errorf("loop set variable node's param left's ref is nil") + } + + if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate { + return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate") + } + + rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent()) + if err != nil { + return nil, err + } + + ns.AddInputSource(rightSources...) + + if len(rightSources) != 1 { + return nil, fmt.Errorf("loop set variable node's param right is not a single source") + } + + pair := &Pair{ + Left: *leftSources[0].Source.Ref, + Right: rightSources[0].Path, + } + + pairs = append(pairs, pair) + } + + i.Pairs = pairs + + return ns, nil +} + +func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) { return &InLoop{ - config: conf, + pairs: i.Pairs, intermediateVarStore: &nodes.ParentIntermediateStore{}, }, nil } -func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) { - for _, pair := range v.config.Pairs { +type InLoop struct { + pairs []*Pair + intermediateVarStore variable.Store +} + +func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) { + for _, pair := range v.pairs { if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate { panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left)) } diff --git a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_test.go b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_test.go index 1555ee96..2b385ee9 100644 --- a/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_test.go +++ b/backend/domain/workflow/internal/nodes/variableassigner/variable_assign_test.go @@ -37,36 +37,34 @@ func TestVariableAssigner(t *testing.T) { arrVar := any([]any{1, "2"}) va := &InLoop{ - config: &Config{ - Pairs: []*Pair{ - { - Left: vo.Reference{ - FromPath: compose.FieldPath{"int_var_s"}, - VariableType: ptr.Of(vo.ParentIntermediate), - }, - Right: compose.FieldPath{"int_var_t"}, + pairs: []*Pair{ + { + Left: vo.Reference{ + FromPath: compose.FieldPath{"int_var_s"}, + VariableType: ptr.Of(vo.ParentIntermediate), }, - { - Left: vo.Reference{ - FromPath: compose.FieldPath{"str_var_s"}, - VariableType: ptr.Of(vo.ParentIntermediate), - }, - Right: compose.FieldPath{"str_var_t"}, + Right: compose.FieldPath{"int_var_t"}, + }, + { + Left: vo.Reference{ + FromPath: compose.FieldPath{"str_var_s"}, + VariableType: ptr.Of(vo.ParentIntermediate), }, - { - Left: vo.Reference{ - FromPath: compose.FieldPath{"obj_var_s"}, - VariableType: ptr.Of(vo.ParentIntermediate), - }, - Right: compose.FieldPath{"obj_var_t"}, + Right: compose.FieldPath{"str_var_t"}, + }, + { + Left: vo.Reference{ + FromPath: compose.FieldPath{"obj_var_s"}, + VariableType: ptr.Of(vo.ParentIntermediate), }, - { - Left: vo.Reference{ - FromPath: compose.FieldPath{"arr_var_s"}, - VariableType: ptr.Of(vo.ParentIntermediate), - }, - Right: compose.FieldPath{"arr_var_t"}, + Right: compose.FieldPath{"obj_var_t"}, + }, + { + Left: vo.Reference{ + FromPath: compose.FieldPath{"arr_var_s"}, + VariableType: ptr.Of(vo.ParentIntermediate), }, + Right: compose.FieldPath{"arr_var_t"}, }, }, intermediateVarStore: &nodes.ParentIntermediateStore{}, @@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) { "arr_var_s": &arrVar, }, nil) - _, err := va.Assign(ctx, map[string]any{ + _, err := va.Invoke(ctx, map[string]any{ "int_var_t": 2, "str_var_t": "str2", "obj_var_t": map[string]any{ diff --git a/backend/domain/workflow/internal/schema/branch_schema.go b/backend/domain/workflow/internal/schema/branch_schema.go new file mode 100644 index 00000000..f568c3be --- /dev/null +++ b/backend/domain/workflow/internal/schema/branch_schema.go @@ -0,0 +1,196 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" +) + +// Port type constants +const ( + PortDefault = "default" + PortBranchError = "branch_error" + PortBranchFormat = "branch_%d" +) + +// BranchSchema defines the schema for workflow branches. +type BranchSchema struct { + From vo.NodeKey `json:"from_node"` + DefaultMapping map[string]bool `json:"default_mapping,omitempty"` + ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"` + Mappings map[int64]map[string]bool `json:"mappings,omitempty"` +} + +// BuildBranches builds branch schemas from connections. +func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) { + var branchMap map[vo.NodeKey]*BranchSchema + + for _, conn := range connections { + if conn.FromPort == nil || len(*conn.FromPort) == 0 { + continue + } + + port := *conn.FromPort + sourceNodeKey := conn.FromNode + + if branchMap == nil { + branchMap = map[vo.NodeKey]*BranchSchema{} + } + + // Get or create branch schema for source node + branch, exists := branchMap[sourceNodeKey] + if !exists { + branch = &BranchSchema{ + From: sourceNodeKey, + } + branchMap[sourceNodeKey] = branch + } + + // Classify port type and add to appropriate mapping + switch { + case port == PortDefault: + if branch.DefaultMapping == nil { + branch.DefaultMapping = map[string]bool{} + } + branch.DefaultMapping[string(conn.ToNode)] = true + case port == PortBranchError: + if branch.ExceptionMapping == nil { + branch.ExceptionMapping = map[string]bool{} + } + branch.ExceptionMapping[string(conn.ToNode)] = true + default: + var branchNum int64 + _, err := fmt.Sscanf(port, PortBranchFormat, &branchNum) + if err != nil || branchNum < 0 { + return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn) + } + if branch.Mappings == nil { + branch.Mappings = map[int64]map[string]bool{} + } + if _, exists := branch.Mappings[branchNum]; !exists { + branch.Mappings[branchNum] = make(map[string]bool) + } + branch.Mappings[branchNum][string(conn.ToNode)] = true + } + } + + return branchMap, nil +} + +func (bs *BranchSchema) OnlyException() bool { + return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0 +} + +func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch { + condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { + isSuccess, ok := in["isSuccess"] + if ok && isSuccess != nil && !isSuccess.(bool) { + return bs.ExceptionMapping, nil + } + + return bs.DefaultMapping, nil + } + + // Combine ExceptionMapping and DefaultMapping into a new map + endNodes := make(map[string]bool) + for node := range bs.ExceptionMapping { + endNodes[node] = true + } + for node := range bs.DefaultMapping { + endNodes[node] = true + } + + return compose.NewGraphMultiBranch(condition, endNodes) +} + +func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) { + extractor, hasBranch := bb.BuildBranch(ctx) + if !hasBranch { + return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs) + } + + if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch + condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { + index, isDefault, err := extractor(ctx, in) + if err != nil { + return nil, err + } + + if isDefault { + return bs.DefaultMapping, nil + } + + if _, ok := bs.Mappings[index]; !ok { + return nil, fmt.Errorf("chosen index= %d, out of range", index) + } + + return bs.Mappings[index], nil + } + + // Combine DefaultMapping and normal mappings into a new map + endNodes := make(map[string]bool) + for node := range bs.DefaultMapping { + endNodes[node] = true + } + for _, ms := range bs.Mappings { + for node := range ms { + endNodes[node] = true + } + } + + return compose.NewGraphMultiBranch(condition, endNodes), nil + } + + condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) { + isSuccess, ok := in["isSuccess"] + if ok && isSuccess != nil && !isSuccess.(bool) { + return bs.ExceptionMapping, nil + } + + index, isDefault, err := extractor(ctx, in) + if err != nil { + return nil, err + } + + if isDefault { + return bs.DefaultMapping, nil + } + + return bs.Mappings[index], nil + } + + // Combine ALL mappings into a new map + endNodes := make(map[string]bool) + for node := range bs.ExceptionMapping { + endNodes[node] = true + } + for node := range bs.DefaultMapping { + endNodes[node] = true + } + for _, ms := range bs.Mappings { + for node := range ms { + endNodes[node] = true + } + } + + return compose.NewGraphMultiBranch(condition, endNodes), nil +} diff --git a/backend/domain/workflow/internal/schema/node_builder.go b/backend/domain/workflow/internal/schema/node_builder.go new file mode 100644 index 00000000..ea60c815 --- /dev/null +++ b/backend/domain/workflow/internal/schema/node_builder.go @@ -0,0 +1,73 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + + "github.com/cloudwego/eino/compose" +) + +type BuildOptions struct { + WS *WorkflowSchema + Inner compose.Runnable[map[string]any, map[string]any] +} + +func GetBuildOptions(opts ...BuildOption) *BuildOptions { + bo := &BuildOptions{} + for _, o := range opts { + o(bo) + } + return bo +} + +type BuildOption func(options *BuildOptions) + +func WithWorkflowSchema(ws *WorkflowSchema) BuildOption { + return func(options *BuildOptions) { + options.WS = ws + } +} + +func WithInnerWorkflow(inner compose.Runnable[map[string]any, map[string]any]) BuildOption { + return func(options *BuildOptions) { + options.Inner = inner + } +} + +// NodeBuilder takes a NodeSchema and several BuildOption to build an executable node instance. +// The result 'executable' MUST implement at least one of the execute interfaces: +// - nodes.InvokableNode +// - nodes.StreamableNode +// - nodes.CollectableNode +// - nodes.TransformableNode +// - nodes.InvokableNodeWOpt +// - nodes.StreamableNodeWOpt +// - nodes.CollectableNodeWOpt +// - nodes.TransformableNodeWOpt +// NOTE: the 'normal' version does not take NodeOption, while the 'WOpt' versions take NodeOption. +// NOTE: a node should either implement the 'normal' versions, or the 'WOpt' versions, not mix them up. +type NodeBuilder interface { + Build(ctx context.Context, ns *NodeSchema, opts ...BuildOption) ( + executable any, err error) +} + +// BranchBuilder builds the extractor function that maps node output to port index. +type BranchBuilder interface { + BuildBranch(ctx context.Context) (extractor func(ctx context.Context, + nodeOutput map[string]any) (int64, bool /*if is default branch*/, error), hasBranch bool) +} diff --git a/backend/domain/workflow/internal/schema/node_schema.go b/backend/domain/workflow/internal/schema/node_schema.go new file mode 100644 index 00000000..61b85f0f --- /dev/null +++ b/backend/domain/workflow/internal/schema/node_schema.go @@ -0,0 +1,131 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" +) + +// NodeSchema is the universal description and configuration for a workflow Node. +// It should contain EVERYTHING a node needs to instantiate. +type NodeSchema struct { + // Key is the node key within the Eino graph. + // A node may need this information during execution, + // e.g. + // - using this Key to query workflow State for data belonging to current node. + Key vo.NodeKey `json:"key"` + + // Name is the name for this node as specified on Canvas. + // A node may show this name on Canvas as part of this node's input/output. + Name string `json:"name"` + + // Type is the NodeType for the node. + Type entity.NodeType `json:"type"` + + // Configs are node specific configurations, with actual struct type defined by each Node Type. + // Will not hold information relating to field mappings, nor as node's static values. + // In a word, these Configs are INTERNAL to node's implementation, NOT related to workflow orchestration. + // Actual type of these Configs should implement two interfaces: + // - NodeAdaptor: to provide conversion from vo.Node to NodeSchema + // - NodeBuilder: to provide instantiation from NodeSchema to actual node instance. + Configs any `json:"configs,omitempty"` + + // InputTypes are type information about the node's input fields. + InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"` + // InputSources are field mapping information about the node's input fields. + InputSources []*vo.FieldInfo `json:"input_sources,omitempty"` + + // OutputTypes are type information about the node's output fields. + OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"` + // OutputSources are field mapping information about the node's output fields. + // NOTE: only applicable to composite nodes such as NodeTypeBatch or NodeTypeLoop. + OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` + + // ExceptionConfigs are about exception handling strategy of the node. + ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` + // StreamConfigs are streaming characteristics of the node. + StreamConfigs *StreamConfig `json:"stream_configs,omitempty"` + + // SubWorkflowBasic is basic information of the sub workflow if this node is NodeTypeSubWorkflow. + SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"` + // SubWorkflowSchema is WorkflowSchema of the sub workflow if this node is NodeTypeSubWorkflow. + SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"` + + // FullSources contains more complete information about a node's input fields' mapping sources, + // such as whether a field's source is a 'streaming field', + // or whether the field is an object that contains sub-fields with real mappings. + // Used for those nodes that need to process streaming input. + // Set InputSourceAware = true in NodeMeta to enable. + FullSources map[string]*SourceInfo + + // Lambda directly sets the node to be an Eino Lambda. + // NOTE: not serializable, used ONLY for internal test. + Lambda *compose.Lambda +} + +type RequireCheckpoint interface { + RequireCheckpoint() bool +} + +type ExceptionConfig struct { + TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout + MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry + ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error + DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs +} + +type StreamConfig struct { + // whether this node has the ability to produce genuine streaming output. + // not include nodes that only passes stream down as they receives them + CanGeneratesStream bool `json:"can_generates_stream,omitempty"` + // whether this node prioritize streaming input over none-streaming input. + // not include nodes that can accept both and does not have preference. + RequireStreamingInput bool `json:"can_process_stream,omitempty"` +} + +func (s *NodeSchema) SetConfigKV(key string, value any) { + if s.Configs == nil { + s.Configs = make(map[string]any) + } + + s.Configs.(map[string]any)[key] = value +} + +func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) { + if s.InputTypes == nil { + s.InputTypes = make(map[string]*vo.TypeInfo) + } + s.InputTypes[key] = t +} + +func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) { + s.InputSources = append(s.InputSources, info...) +} + +func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) { + if s.OutputTypes == nil { + s.OutputTypes = make(map[string]*vo.TypeInfo) + } + s.OutputTypes[key] = t +} + +func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) { + s.OutputSources = append(s.OutputSources, info...) +} diff --git a/backend/domain/workflow/internal/schema/stream.go b/backend/domain/workflow/internal/schema/stream.go new file mode 100644 index 00000000..c3c362a6 --- /dev/null +++ b/backend/domain/workflow/internal/schema/stream.go @@ -0,0 +1,77 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "github.com/cloudwego/eino/compose" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" +) + +type FieldStreamType string + +const ( + FieldIsStream FieldStreamType = "yes" // absolutely a stream + FieldNotStream FieldStreamType = "no" // absolutely not a stream + FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution + FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped +) + +type FieldSkipStatus string + +// SourceInfo contains stream type for a input field source of a node. +type SourceInfo struct { + // IsIntermediate means this field is itself not a field source, but a map containing one or more field sources. + IsIntermediate bool + // FieldType the stream type of the field. May require request-time resolution in addition to compile-time. + FieldType FieldStreamType + // FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable. + FromNodeKey vo.NodeKey + // FromPath is the path of this field source within the source node. empty if the field is a static value or variable. + FromPath compose.FieldPath + TypeInfo *vo.TypeInfo + // SubSources are SourceInfo for keys within this intermediate Map(Object) field. + SubSources map[string]*SourceInfo +} + +func (s *SourceInfo) Skipped() bool { + if !s.IsIntermediate { + return s.FieldType == FieldSkipped + } + + for _, sub := range s.SubSources { + if !sub.Skipped() { + return false + } + } + + return true +} + +func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool { + if !s.IsIntermediate { + return s.FromNodeKey == nodeKey + } + + for _, sub := range s.SubSources { + if sub.FromNode(nodeKey) { + return true + } + } + + return false +} diff --git a/backend/domain/workflow/internal/compose/workflow_schema.go b/backend/domain/workflow/internal/schema/workflow_schema.go similarity index 81% rename from backend/domain/workflow/internal/compose/workflow_schema.go rename to backend/domain/workflow/internal/schema/workflow_schema.go index 04df3cad..13682239 100644 --- a/backend/domain/workflow/internal/compose/workflow_schema.go +++ b/backend/domain/workflow/internal/schema/workflow_schema.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package compose +package schema import ( "fmt" @@ -29,9 +29,10 @@ import ( ) type WorkflowSchema struct { - Nodes []*NodeSchema `json:"nodes"` - Connections []*Connection `json:"connections"` - Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key + Nodes []*NodeSchema `json:"nodes"` + Connections []*Connection `json:"connections"` + Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key + Branches map[vo.NodeKey]*BranchSchema `json:"branches,omitempty"` GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode @@ -71,9 +72,19 @@ func (w *WorkflowSchema) Init() { w.doGetCompositeNodes() for _, node := range w.Nodes { - if node.requireCheckpoint() { - w.requireCheckPoint = true - break + if node.Type == entity.NodeTypeSubWorkflow { + node.SubWorkflowSchema.Init() + if node.SubWorkflowSchema.requireCheckPoint { + w.requireCheckPoint = true + break + } + } + + if rc, ok := node.Configs.(RequireCheckpoint); ok { + if rc.RequireCheckpoint() { + w.requireCheckPoint = true + break + } } } @@ -97,6 +108,22 @@ func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode { return w.compositeNodes } +func (w *WorkflowSchema) GetBranch(key vo.NodeKey) *BranchSchema { + if w.Branches == nil { + return nil + } + + return w.Branches[key] +} + +func (w *WorkflowSchema) RequireCheckpoint() bool { + return w.requireCheckPoint +} + +func (w *WorkflowSchema) RequireStreaming() bool { + return w.requireStreaming +} + func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) { if w.Hierarchy == nil { return nil @@ -125,7 +152,7 @@ func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) { return cNodes } -func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { +func IsInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { if n == nil { return true } @@ -144,7 +171,7 @@ func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.Node return myParents == theirParents } -func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { +func IsBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { if n == nil { return false } @@ -154,7 +181,7 @@ func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeK return myParentExists && !theirParentExists } -func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { +func IsParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { if n == nil { return false } @@ -230,21 +257,14 @@ func (w *WorkflowSchema) doRequireStreaming() bool { consumers := make(map[vo.NodeKey]bool) for _, node := range w.Nodes { - meta := entity.NodeMetaByNodeType(node.Type) - if meta != nil { - sps := meta.ExecutableMeta.StreamingParadigms - if _, ok := sps[entity.Stream]; ok { - if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream { - producers[node.Key] = true - } - } - - if sps[entity.Transform] || sps[entity.Collect] { - if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput { - consumers[node.Key] = true - } - } + if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream { + producers[node.Key] = true } + + if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput { + consumers[node.Key] = true + } + } if len(producers) == 0 || len(consumers) == 0 { @@ -290,7 +310,7 @@ func (w *WorkflowSchema) doRequireStreaming() bool { return false } -func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig { +func (w *WorkflowSchema) FanInMergeConfigs() map[string]compose.FanInMergeConfig { // what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed // then we find those nodes that have 'transform' or 'collect' as streaming paradigm, // and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node. @@ -301,21 +321,15 @@ func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig fanInNodes := make(map[vo.NodeKey]bool) for _, node := range w.Nodes { - meta := entity.NodeMetaByNodeType(node.Type) - if meta != nil { - sps := meta.ExecutableMeta.StreamingParadigms - if sps[entity.Transform] || sps[entity.Collect] { - if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput { - var predecessor *vo.NodeKey - for _, source := range node.InputSources { - if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 { - if predecessor != nil { - fanInNodes[node.Key] = true - break - } - predecessor = &source.Source.Ref.FromNodeKey - } + if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput { + var predecessor *vo.NodeKey + for _, source := range node.InputSources { + if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 { + if predecessor != nil { + fanInNodes[node.Key] = true + break } + predecessor = &source.Source.Ref.FromNodeKey } } } diff --git a/backend/domain/workflow/service/service_impl.go b/backend/domain/workflow/service/service_impl.go index ef6a374b..11e9abcb 100644 --- a/backend/domain/workflow/service/service_impl.go +++ b/backend/domain/workflow/service/service_impl.go @@ -37,8 +37,11 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" - "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" + "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/cache" "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen" @@ -73,7 +76,7 @@ func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmd return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel) } -func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) { +func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) { // Initialize result maps nodeMetaMap := make(map[string][]*entity.NodeTypeMeta) @@ -82,7 +85,7 @@ func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]b if meta.Disabled { return false } - nodeType := meta.Type + nodeType := meta.Key if nodeTypes == nil || len(nodeTypes) == 0 { return true // No filter, include all } @@ -192,10 +195,10 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI if startNode != nil && endNode != nil { break } - if node.Type == vo.BlockTypeBotStart { + if node.Type == entity.NodeTypeEntry.IDStr() { startNode = node } - if node.Type == vo.BlockTypeBotEnd { + if node.Type == entity.NodeTypeExit.IDStr() { endNode = node } } @@ -207,7 +210,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI if err != nil { return nil, err } - nInfo, err := adaptor.VariableToNamedTypeInfo(v) + nInfo, err := convert.VariableToNamedTypeInfo(v) if err != nil { return nil, err } @@ -220,7 +223,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI if endNode != nil { outputs, err = slices.TransformWithErrorCheck(endNode.Data.Inputs.InputParameters, func(a *vo.Param) (*vo.NamedTypeInfo, error) { - return adaptor.BlockInputToNamedTypeInfo(a.Name, a.Input) + return convert.BlockInputToNamedTypeInfo(a.Name, a.Input) }) if err != nil { logs.Warn(fmt.Sprintf("transform end node inputs to named info failed, err=%v", err)) @@ -316,6 +319,34 @@ func (i *impl) GetWorkflowReference(ctx context.Context, id int64) (map[int64]*v return ret, nil } +type workflowIdentity struct { + ID string `json:"id"` + Version string `json:"version"` +} + +func getAllSubWorkflowIdentities(c *vo.Canvas) []*workflowIdentity { + workflowEntities := make([]*workflowIdentity, 0) + + var collectSubWorkFlowEntities func(nodes []*vo.Node) + collectSubWorkFlowEntities = func(nodes []*vo.Node) { + for _, n := range nodes { + if n.Type == entity.NodeTypeSubWorkflow.IDStr() { + workflowEntities = append(workflowEntities, &workflowIdentity{ + ID: n.Data.Inputs.WorkflowID, + Version: n.Data.Inputs.WorkflowVersion, + }) + } + if len(n.Blocks) > 0 { + collectSubWorkFlowEntities(n.Blocks) + } + } + } + + collectSubWorkFlowEntities(c.Nodes) + + return workflowEntities +} + func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.ValidateTreeConfig) ([]*cloudworkflow.ValidateTreeInfo, error) { wfValidateInfos := make([]*cloudworkflow.ValidateTreeInfo, 0) issues, err := validateWorkflowTree(ctx, validateConfig) @@ -337,7 +368,7 @@ func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.Val fmt.Errorf("failed to unmarshal canvas schema: %w", err)) } - subWorkflowIdentities := c.GetAllSubWorkflowIdentities() + subWorkflowIdentities := getAllSubWorkflowIdentities(c) if len(subWorkflowIdentities) > 0 { var ids []int64 @@ -421,25 +452,21 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m } for _, n := range canvas.Nodes { - if n.Type == vo.BlockTypeBotSubWorkflow { - nodeSchema := &compose.NodeSchema{ + if n.Type == entity.NodeTypeSubWorkflow.IDStr() { + nodeSchema := &schema.NodeSchema{ Key: vo.NodeKey(n.ID), Type: entity.NodeTypeSubWorkflow, Name: n.Data.Meta.Title, } - err := adaptor.SetInputsForNodeSchema(n, nodeSchema) - if err != nil { - return nil, err - } - blockType, err := entityNodeTypeToBlockType(nodeSchema.Type) + err := convert.SetInputsForNodeSchema(n, nodeSchema) if err != nil { return nil, err } prop := &vo.NodeProperty{ - Type: string(blockType), - IsEnableUserQuery: nodeSchema.IsEnableUserQuery(), - IsEnableChatHistory: nodeSchema.IsEnableChatHistory(), - IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(), + Type: nodeSchema.Type.IDStr(), + IsEnableUserQuery: isEnableUserQuery(nodeSchema), + IsEnableChatHistory: isEnableChatHistory(nodeSchema), + IsRefGlobalVariable: isRefGlobalVariable(nodeSchema), } nodePropertyMap[string(nodeSchema.Key)] = prop wid, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64) @@ -478,20 +505,16 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m prop.SubWorkflow = ret } else { - nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n) + nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n, canvas) if err != nil { return nil, err } for _, nodeSchema := range nodeSchemas { - blockType, err := entityNodeTypeToBlockType(nodeSchema.Type) - if err != nil { - return nil, err - } nodePropertyMap[string(nodeSchema.Key)] = &vo.NodeProperty{ - Type: string(blockType), - IsEnableUserQuery: nodeSchema.IsEnableUserQuery(), - IsEnableChatHistory: nodeSchema.IsEnableChatHistory(), - IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(), + Type: nodeSchema.Type.IDStr(), + IsEnableUserQuery: isEnableUserQuery(nodeSchema), + IsEnableChatHistory: isEnableChatHistory(nodeSchema), + IsRefGlobalVariable: isRefGlobalVariable(nodeSchema), } } @@ -500,6 +523,60 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m return nodePropertyMap, nil } +func isEnableUserQuery(s *schema.NodeSchema) bool { + if s == nil { + return false + } + if s.Type != entity.NodeTypeEntry { + return false + } + + if len(s.OutputSources) == 0 { + return false + } + + for _, source := range s.OutputSources { + fieldPath := source.Path + if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") { + return true + } + } + + return false +} + +func isEnableChatHistory(s *schema.NodeSchema) bool { + if s == nil { + return false + } + + switch s.Type { + + case entity.NodeTypeLLM: + llmParam := s.Configs.(*llm.Config).LLMParams + return llmParam.EnableChatHistory + case entity.NodeTypeIntentDetector: + llmParam := s.Configs.(*intentdetector.Config).LLMParams + return llmParam.EnableChatHistory + default: + return false + } +} + +func isRefGlobalVariable(s *schema.NodeSchema) bool { + for _, source := range s.InputSources { + if source.IsRefGlobalVariable() { + return true + } + } + for _, source := range s.OutputSources { + if source.IsRefGlobalVariable() { + return true + } + } + return false +} + func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) { var canvas vo.Canvas if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil { @@ -510,7 +587,7 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer var getRefFn func([]*vo.Node) error getRefFn = func(nodes []*vo.Node) error { for _, node := range nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { referredID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) if err != nil { return vo.WrapError(errno.ErrSchemaConversionFail, err) @@ -521,19 +598,21 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer ReferType: vo.ReferTypeSubWorkflow, ReferringBizType: vo.ReferringBizTypeWorkflow, }] = struct{}{} - } else if node.Type == vo.BlockTypeBotLLM { - if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { - for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { - referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64) - if err != nil { - return vo.WrapError(errno.ErrSchemaConversionFail, err) + } else if node.Type == entity.NodeTypeLLM.IDStr() { + if node.Data.Inputs.LLM != nil { + if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { + for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { + referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64) + if err != nil { + return vo.WrapError(errno.ErrSchemaConversionFail, err) + } + wfRefs[entity.WorkflowReferenceKey{ + ReferredID: referredID, + ReferringID: referringID, + ReferType: vo.ReferTypeTool, + ReferringBizType: vo.ReferringBizTypeWorkflow, + }] = struct{}{} } - wfRefs[entity.WorkflowReferenceKey{ - ReferredID: referredID, - ReferringID: referringID, - ReferType: vo.ReferTypeTool, - ReferringBizType: vo.ReferringBizTypeWorkflow, - }] = struct{}{} } } } else if len(node.Blocks) > 0 { @@ -832,7 +911,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6 validateAndBuildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error { for _, node := range nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { var ( v *vo.DraftInfo wfID int64 @@ -883,7 +962,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6 } - if node.Type == vo.BlockTypeBotLLM { + if node.Type == entity.NodeTypeLLM.IDStr() { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { var ( @@ -1086,7 +1165,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe var buildWorkflowReference func(nodes []*vo.Node, wf *copiedWorkflow) error buildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error { for _, node := range nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { + if node.Type == entity.NodeTypeSubWorkflow.IDStr() { var ( v *vo.DraftInfo wfID int64 @@ -1121,7 +1200,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe } } - if node.Type == vo.BlockTypeBotLLM { + if node.Type == entity.NodeTypeLLM.IDStr() { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { var ( @@ -1323,8 +1402,8 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int var collectDependence func(nodes []*vo.Node) error collectDependence = func(nodes []*vo.Node) error { for _, node := range nodes { - switch node.Type { - case vo.BlockTypeBotAPI: + switch entity.IDStrToNodeType(node.Type) { + case entity.NodeTypePlugin: apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { return e.Name, e }) @@ -1347,7 +1426,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int ds.PluginIDs = append(ds.PluginIDs, pID) } - case vo.BlockTypeBotDatasetWrite, vo.BlockTypeBotDataset: + case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever: datasetListInfoParam := node.Data.Inputs.DatasetParam[0] datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) for _, id := range datasetIDs { @@ -1357,7 +1436,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int } ds.KnowledgeIDs = append(ds.KnowledgeIDs, k) } - case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: + case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate: dsList := node.Data.Inputs.DatabaseInfoList if len(dsList) == 0 { return fmt.Errorf("database info is requird") @@ -1369,7 +1448,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int } ds.DatabaseIDs = append(ds.DatabaseIDs, dsID) } - case vo.BlockTypeBotLLM: + case entity.NodeTypeLLM: if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil { for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList { pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx] @@ -1396,7 +1475,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int } } - case vo.BlockTypeBotSubWorkflow: + case entity.NodeTypeSubWorkflow: wfID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) if err != nil { return err @@ -1567,8 +1646,8 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r ) for _, node := range nodes { - switch node.Type { - case vo.BlockTypeBotSubWorkflow: + switch entity.IDStrToNodeType(node.Type) { + case entity.NodeTypeSubWorkflow: if !hasWorkflowRelated { continue } @@ -1580,7 +1659,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10) node.Data.Inputs.WorkflowVersion = wf.Version } - case vo.BlockTypeBotAPI: + case entity.NodeTypePlugin: if !hasPluginRelated { continue } @@ -1623,7 +1702,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r apiIDParam.Input.Value.Content = strconv.FormatInt(refApiID, 10) } - case vo.BlockTypeBotLLM: + case entity.NodeTypeLLM: if hasWorkflowRelated && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx] @@ -1669,7 +1748,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r } } - case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: + case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever: if !hasKnowledgeRelated { continue } @@ -1685,7 +1764,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r } } - case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: + case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate: if !hasDatabaseRelated { continue } @@ -1713,3 +1792,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r } return nil } + +func RegisterAllNodeAdaptors() { + adaptor.RegisterAllNodeAdaptors() +} diff --git a/backend/domain/workflow/service/utils.go b/backend/domain/workflow/service/utils.go index 840bfd57..9665d430 100644 --- a/backend/domain/workflow/service/utils.go +++ b/backend/domain/workflow/service/utils.go @@ -24,11 +24,9 @@ import ( cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate" - "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/types/errno" ) @@ -199,158 +197,3 @@ func isIncremental(prev version, next version) bool { return next.Patch > prev.Patch } - -func replaceRelatedWorkflowOrPluginInWorkflowNodes(nodes []*vo.Node, relatedWorkflows map[int64]entity.IDVersionPair, relatedPlugins map[int64]vo.PluginEntity) error { - for _, node := range nodes { - if node.Type == vo.BlockTypeBotSubWorkflow { - workflowID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) - if err != nil { - return err - } - if wf, ok := relatedWorkflows[workflowID]; ok { - node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10) - node.Data.Inputs.WorkflowVersion = wf.Version - } - } - if node.Type == vo.BlockTypeBotAPI { - apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { - return e.Name, e - }) - pluginIDParam, ok := apiParams["pluginID"] - if !ok { - return fmt.Errorf("plugin id param is not found") - } - - pID, err := strconv.ParseInt(pluginIDParam.Input.Value.Content.(string), 10, 64) - if err != nil { - return err - } - - pluginVersionParam, ok := apiParams["pluginVersion"] - if !ok { - return fmt.Errorf("plugin version param is not found") - } - - if refPlugin, ok := relatedPlugins[pID]; ok { - pluginIDParam.Input.Value.Content = refPlugin.PluginID - if refPlugin.PluginVersion != nil { - pluginVersionParam.Input.Value.Content = *refPlugin.PluginVersion - } - - } - } - - if node.Type == vo.BlockTypeBotLLM { - if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { - for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { - wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx] - workflowID, err := strconv.ParseInt(wf.WorkflowID, 10, 64) - if err != nil { - return err - } - if refWf, ok := relatedWorkflows[workflowID]; ok { - wf.WorkflowID = strconv.FormatInt(refWf.ID, 10) - wf.WorkflowVersion = refWf.Version - } - - } - } - if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil { - for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList { - pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx] - pluginID, err := strconv.ParseInt(pl.PluginID, 10, 64) - if err != nil { - return err - } - if refPlugin, ok := relatedPlugins[pluginID]; ok { - pl.PluginID = strconv.FormatInt(refPlugin.PluginID, 10) - if refPlugin.PluginVersion != nil { - pl.PluginVersion = *refPlugin.PluginVersion - } - - } - - } - } - - } - if len(node.Blocks) > 0 { - err := replaceRelatedWorkflowOrPluginInWorkflowNodes(node.Blocks, relatedWorkflows, relatedPlugins) - if err != nil { - return err - } - } - - } - return nil -} - -// entityNodeTypeToBlockType converts an entity.NodeType to the corresponding vo.BlockType. -func entityNodeTypeToBlockType(nodeType entity.NodeType) (vo.BlockType, error) { - switch nodeType { - case entity.NodeTypeEntry: - return vo.BlockTypeBotStart, nil - case entity.NodeTypeExit: - return vo.BlockTypeBotEnd, nil - case entity.NodeTypeLLM: - return vo.BlockTypeBotLLM, nil - case entity.NodeTypePlugin: - return vo.BlockTypeBotAPI, nil - case entity.NodeTypeCodeRunner: - return vo.BlockTypeBotCode, nil - case entity.NodeTypeKnowledgeRetriever: - return vo.BlockTypeBotDataset, nil - case entity.NodeTypeSelector: - return vo.BlockTypeCondition, nil - case entity.NodeTypeSubWorkflow: - return vo.BlockTypeBotSubWorkflow, nil - case entity.NodeTypeDatabaseCustomSQL: - return vo.BlockTypeDatabase, nil - case entity.NodeTypeOutputEmitter: - return vo.BlockTypeBotMessage, nil - case entity.NodeTypeTextProcessor: - return vo.BlockTypeBotText, nil - case entity.NodeTypeQuestionAnswer: - return vo.BlockTypeQuestion, nil - case entity.NodeTypeBreak: - return vo.BlockTypeBotBreak, nil - case entity.NodeTypeVariableAssigner: - return vo.BlockTypeBotAssignVariable, nil - case entity.NodeTypeVariableAssignerWithinLoop: - return vo.BlockTypeBotLoopSetVariable, nil - case entity.NodeTypeLoop: - return vo.BlockTypeBotLoop, nil - case entity.NodeTypeIntentDetector: - return vo.BlockTypeBotIntent, nil - case entity.NodeTypeKnowledgeIndexer: - return vo.BlockTypeBotDatasetWrite, nil - case entity.NodeTypeBatch: - return vo.BlockTypeBotBatch, nil - case entity.NodeTypeContinue: - return vo.BlockTypeBotContinue, nil - case entity.NodeTypeInputReceiver: - return vo.BlockTypeBotInput, nil - case entity.NodeTypeDatabaseUpdate: - return vo.BlockTypeDatabaseUpdate, nil - case entity.NodeTypeDatabaseQuery: - return vo.BlockTypeDatabaseSelect, nil - case entity.NodeTypeDatabaseDelete: - return vo.BlockTypeDatabaseDelete, nil - case entity.NodeTypeHTTPRequester: - return vo.BlockTypeBotHttp, nil - case entity.NodeTypeDatabaseInsert: - return vo.BlockTypeDatabaseInsert, nil - case entity.NodeTypeVariableAggregator: - return vo.BlockTypeBotVariableMerge, nil - case entity.NodeTypeJsonSerialization: - return vo.BlockTypeJsonSerialization, nil - case entity.NodeTypeJsonDeserialization: - return vo.BlockTypeJsonDeserialization, nil - case entity.NodeTypeKnowledgeDeleter: - return vo.BlockTypeBotDatasetDelete, nil - - default: - return "", vo.WrapError(errno.ErrSchemaConversionFail, - fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType)) - } -} diff --git a/backend/types/errno/workflow.go b/backend/types/errno/workflow.go index a2f29ba3..81c643a6 100644 --- a/backend/types/errno/workflow.go +++ b/backend/types/errno/workflow.go @@ -82,7 +82,7 @@ func init() { code.Register( ErrMissingRequiredParam, - "Missing required parameters {param}. Please review the API documentation and ensure all mandatory fields are included in your request.", + "Missing required parameters:’{param}‘. Please review the API documentation and ensure all mandatory fields are included in your request.", code.WithAffectStability(false), )