From 55628009586dbbe04ca646a4580f24d7bb4bf659 Mon Sep 17 00:00:00 2001 From: shentongmartin Date: Wed, 27 Aug 2025 16:29:42 +0800 Subject: [PATCH] fix: workflow tool closes stream writer correctly (#1839) --- .../api/handler/coze/workflow_service_test.go | 860 +++++++++--------- .../crossdomain/contract/workflow/workflow.go | 2 +- backend/crossdomain/impl/workflow/workflow.go | 2 +- .../internal/agentflow/agent_flow_runner.go | 7 +- .../domain/workflow/component_interface.go | 2 +- backend/domain/workflow/entity/message.go | 1 + .../internal/compose/designate_option.go | 23 +- .../workflow/internal/compose/workflow_run.go | 51 +- .../internal/compose/workflow_tool.go | 181 ++-- .../workflow/internal/execute/callback.go | 6 + .../internal/execute/collect_token.go | 17 +- .../workflow/internal/execute/event_handle.go | 31 + .../internal/execute/stream_container.go | 74 ++ .../workflow/internal/execute/tool_option.go | 31 +- .../workflow/internal/nodes/batch/batch.go | 3 - .../domain/workflow/internal/nodes/convert.go | 2 - .../domain/workflow/internal/nodes/llm/llm.go | 46 +- .../domain/workflow/service/as_tool_impl.go | 2 +- .../mock/domain/workflow/interface.go | 21 +- 19 files changed, 742 insertions(+), 620 deletions(-) create mode 100644 backend/domain/workflow/internal/execute/stream_container.go diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index adb87485..ddeffff3 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -1815,212 +1815,212 @@ func TestUpdateWorkflowMeta(t *testing.T) { }) } -//func TestSimpleInvokableToolWithReturnVariables(t *testing.T) { -// mockey.PatchConvey("simple invokable tool with return variables", t, func() { -// r := newWfTestRunner(t) -// defer r.closeFn() -// -// toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1")) -// -// chatModel := &testutil.UTChatModel{ -// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { -// if index == 0 { -// return &schema.Message{ -// Role: schema.Assistant, -// ToolCalls: []schema.ToolCall{ -// { -// ID: "1", -// Function: schema.FunctionCall{ -// Name: "ts_test_wf_test_wf", -// Arguments: "{}", -// }, -// }, -// }, -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 10, -// CompletionTokens: 11, -// TotalTokens: 21, -// }, -// }, -// }, nil -// } else if index == 1 { -// return &schema.Message{ -// Role: schema.Assistant, -// Content: "final_answer", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 5, -// CompletionTokens: 6, -// TotalTokens: 11, -// }, -// }, -// }, nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() -// -// id := r.load("function_call/llm_with_workflow_as_tool.json") -// defer func() { -// post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{ -// WorkflowID: id, -// }) -// }() -// -// exeID := r.testRun(id, map[string]string{ -// "input": "this is the user input", -// }) -// -// e := r.getProcess(id, exeID) -// e.assertSuccess() -// assert.Equal(t, map[string]any{ -// "output": "final_answer", -// }, mustUnmarshalToMap(t, e.output)) -// e.tokenEqual(15, 17) -// -// mockey.PatchConvey("check behavior if stream run", func() { -// chatModel.Reset() -// -// defer r.runServer()() -// -// r.publish(id, "v0.0.1", true) -// -// sseReader := r.openapiStream(id, map[string]any{ -// "input": "hello", -// }) -// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { -// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) -// return nil -// }) -// assert.NoError(t, err) -// -// // check workflow references are correct -// refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{ -// WorkflowID: toolID, -// }) -// assert.Equal(t, 1, len(refs.Data.WorkflowList)) -// assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID) -// }) -// }) -//} +func TestSimpleInvokableToolWithReturnVariables(t *testing.T) { + mockey.PatchConvey("simple invokable tool with return variables", t, func() { + r := newWfTestRunner(t) + defer r.closeFn() -//func TestReturnDirectlyStreamableTool(t *testing.T) { -// mockey.PatchConvey("return directly streamable tool", t, func() { -// r := newWfTestRunner(t) -// defer r.closeFn() -// -// outerModel := &testutil.UTChatModel{ -// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { -// if index == 0 { -// return schema.StreamReaderFromArray([]*schema.Message{ -// { -// Role: schema.Assistant, -// ToolCalls: []schema.ToolCall{ -// { -// ID: "1", -// Function: schema.FunctionCall{ -// Name: "ts_test_wf_test_wf", -// Arguments: `{"input": "input for inner model"}`, -// }, -// }, -// }, -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 10, -// CompletionTokens: 11, -// TotalTokens: 21, -// }, -// }, -// }, -// }), nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// -// innerModel := &testutil.UTChatModel{ -// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { -// if index == 0 { -// return schema.StreamReaderFromArray([]*schema.Message{ -// { -// Role: schema.Assistant, -// Content: "I ", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 5, -// CompletionTokens: 6, -// TotalTokens: 11, -// }, -// }, -// }, -// { -// Role: schema.Assistant, -// Content: "don't know", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// CompletionTokens: 8, -// TotalTokens: 8, -// }, -// }, -// }, -// { -// Role: schema.Assistant, -// Content: ".", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// CompletionTokens: 2, -// TotalTokens: 2, -// }, -// }, -// }, -// }), nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// -// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) { -// if params.ModelType == 1706077826 { -// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10) -// return innerModel, nil, nil -// } else { -// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10) -// return outerModel, nil, nil -// } -// }).AnyTimes() -// -// r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1")) -// id := r.load("function_call/llm_workflow_stream_tool.json") -// -// exeID := r.testRun(id, map[string]string{ -// "input": "this is the user input", -// }) -// e := r.getProcess(id, exeID) -// e.assertSuccess() -// assert.Equal(t, "this is the streaming output I don't know.", e.output) -// e.tokenEqual(15, 27) -// -// mockey.PatchConvey("check behavior if stream run", func() { -// outerModel.Reset() -// innerModel.Reset() -// defer r.runServer()() -// r.publish(id, "v0.0.1", true) -// sseReader := r.openapiStream(id, map[string]any{ -// "input": "hello", -// }) -// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { -// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) -// return nil -// }) -// assert.NoError(t, err) -// }) -// }) -//} + toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1")) + + chatModel := &testutil.UTChatModel{ + InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { + if index == 0 { + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "ts_test_wf_test_wf", + Arguments: "{}", + }, + }, + }, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 10, + CompletionTokens: 11, + TotalTokens: 21, + }, + }, + }, nil + } else if index == 1 { + return &schema.Message{ + Role: schema.Assistant, + Content: "final_answer", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 5, + CompletionTokens: 6, + TotalTokens: 11, + }, + }, + }, nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() + + id := r.load("function_call/llm_with_workflow_as_tool.json") + defer func() { + post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{ + WorkflowID: id, + }) + }() + + exeID := r.testRun(id, map[string]string{ + "input": "this is the user input", + }) + + e := r.getProcess(id, exeID) + e.assertSuccess() + assert.Equal(t, map[string]any{ + "output": "final_answer", + }, mustUnmarshalToMap(t, e.output)) + e.tokenEqual(15, 17) + + mockey.PatchConvey("check behavior if stream run", func() { + chatModel.Reset() + + defer r.runServer()() + + r.publish(id, "v0.0.1", true) + + sseReader := r.openapiStream(id, map[string]any{ + "input": "hello", + }) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + + // check workflow references are correct + refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{ + WorkflowID: toolID, + }) + assert.Equal(t, 1, len(refs.Data.WorkflowList)) + assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID) + }) + }) +} + +func TestReturnDirectlyStreamableTool(t *testing.T) { + mockey.PatchConvey("return directly streamable tool", t, func() { + r := newWfTestRunner(t) + defer r.closeFn() + + outerModel := &testutil.UTChatModel{ + StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + if index == 0 { + return schema.StreamReaderFromArray([]*schema.Message{ + { + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "ts_test_wf_test_wf", + Arguments: `{"input": "input for inner model"}`, + }, + }, + }, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 10, + CompletionTokens: 11, + TotalTokens: 21, + }, + }, + }, + }), nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + + innerModel := &testutil.UTChatModel{ + StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + if index == 0 { + return schema.StreamReaderFromArray([]*schema.Message{ + { + Role: schema.Assistant, + Content: "I ", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 5, + CompletionTokens: 6, + TotalTokens: 11, + }, + }, + }, + { + Role: schema.Assistant, + Content: "don't know", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + CompletionTokens: 8, + TotalTokens: 8, + }, + }, + }, + { + Role: schema.Assistant, + Content: ".", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + CompletionTokens: 2, + TotalTokens: 2, + }, + }, + }, + }), nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) { + if params.ModelType == 1706077826 { + innerModel.ModelType = strconv.FormatInt(params.ModelType, 10) + return innerModel, nil, nil + } else { + outerModel.ModelType = strconv.FormatInt(params.ModelType, 10) + return outerModel, nil, nil + } + }).AnyTimes() + + r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1")) + id := r.load("function_call/llm_workflow_stream_tool.json") + + exeID := r.testRun(id, map[string]string{ + "input": "this is the user input", + }) + e := r.getProcess(id, exeID) + e.assertSuccess() + assert.Equal(t, "this is the streaming output I don't know.", e.output) + e.tokenEqual(15, 27) + + mockey.PatchConvey("check behavior if stream run", func() { + outerModel.Reset() + innerModel.Reset() + defer r.runServer()() + r.publish(id, "v0.0.1", true) + sseReader := r.openapiStream(id, map[string]any{ + "input": "hello", + }) + err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { + t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) + return nil + }) + assert.NoError(t, err) + }) + }) +} func TestSimpleInterruptibleTool(t *testing.T) { mockey.PatchConvey("test simple interruptible tool", t, func() { @@ -2082,233 +2082,231 @@ func TestSimpleInterruptibleTool(t *testing.T) { }) } -//func TestStreamableToolWithMultipleInterrupts(t *testing.T) { -// mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() { -// r := newWfTestRunner(t) -// defer r.closeFn() -// -// outerModel := &testutil.UTChatModel{ -// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { -// if index == 0 { -// return schema.StreamReaderFromArray([]*schema.Message{ -// { -// Role: schema.Assistant, -// ToolCalls: []schema.ToolCall{ -// { -// ID: "1", -// Function: schema.FunctionCall{ -// Name: "ts_test_wf_test_wf", -// Arguments: `{"input": "what's your name and age"}`, -// }, -// }, -// }, -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 6, -// CompletionTokens: 7, -// TotalTokens: 13, -// }, -// }, -// }, -// }), nil -// } else if index == 1 { -// return schema.StreamReaderFromArray([]*schema.Message{ -// { -// Role: schema.Assistant, -// Content: "I now know your ", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 5, -// CompletionTokens: 8, -// TotalTokens: 13, -// }, -// }, -// }, -// { -// Role: schema.Assistant, -// Content: "name is Eino and age is 1.", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// CompletionTokens: 10, -// TotalTokens: 17, -// }, -// }, -// }, -// }), nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// -// innerModel := &testutil.UTChatModel{ -// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { -// if index == 0 { -// return &schema.Message{ -// Role: schema.Assistant, -// Content: `{"question": "what's your age?"}`, -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 6, -// CompletionTokens: 7, -// TotalTokens: 13, -// }, -// }, -// }, nil -// } else if index == 1 { -// return &schema.Message{ -// Role: schema.Assistant, -// Content: `{"fields": {"name": "eino", "age": 1}}`, -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 8, -// CompletionTokens: 10, -// TotalTokens: 18, -// }, -// }, -// }, nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// -// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) { -// if params.ModelType == 1706077827 { -// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10) -// return outerModel, nil, nil -// } else { -// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10) -// return innerModel, nil, nil -// } -// }).AnyTimes() -// -// r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1")) -// id := r.load("function_call/llm_workflow_stream_tool_1.json") -// -// exeID := r.testRun(id, map[string]string{ -// "input": "this is the user input", -// }) -// -// e := r.getProcess(id, exeID) -// assert.NotNil(t, e.event) -// e.tokenEqual(0, 0) -// -// r.testResume(id, exeID, e.event.ID, "my name is eino") -// e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID)) -// assert.NotNil(t, e2.event) -// e2.tokenEqual(0, 0) -// -// r.testResume(id, exeID, e2.event.ID, "1 year old") -// e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID)) -// e3.assertSuccess() -// assert.Equal(t, "the name is eino, age is 1", e3.output) -// e3.tokenEqual(20, 24) -// }) -//} +func TestStreamableToolWithMultipleInterrupts(t *testing.T) { + mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() { + r := newWfTestRunner(t) + defer r.closeFn() -//func TestNodeWithBatchEnabled(t *testing.T) { -// mockey.PatchConvey("test node with batch enabled", t, func() { -// r := newWfTestRunner(t) -// defer r.closeFn() -// -// r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1")) -// -// chatModel := &testutil.UTChatModel{ -// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { -// if index == 0 { -// return &schema.Message{ -// Role: schema.Assistant, -// Content: "answer。for index 0", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 5, -// CompletionTokens: 6, -// TotalTokens: 11, -// }, -// }, -// }, nil -// } else if index == 1 { -// return &schema.Message{ -// Role: schema.Assistant, -// Content: "answer,for index 1", -// ResponseMeta: &schema.ResponseMeta{ -// Usage: &schema.TokenUsage{ -// PromptTokens: 5, -// CompletionTokens: 6, -// TotalTokens: 11, -// }, -// }, -// }, nil -// } else { -// return nil, fmt.Errorf("unexpected index: %d", index) -// } -// }, -// } -// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() -// -// id := r.load("batch/node_batches.json") -// -// exeID := r.testRun(id, map[string]string{ -// "input": `["first input", "second input"]`, -// }) -// e := r.getProcess(id, exeID) -// e.assertSuccess() -// assert.Equal(t, map[string]any{ -// "output": []any{ -// map[string]any{ -// "output": []any{ -// "answer", -// "for index 0", -// }, -// "input": "answer。for index 0", -// }, -// map[string]any{ -// "output": []any{ -// "answer", -// "for index 1", -// }, -// "input": "answer,for index 1", -// }, -// }, -// }, mustUnmarshalToMap(t, e.output)) -// e.tokenEqual(10, 12) -// -// // verify this workflow has previously succeeded a test run -// result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) -// assert.True(t, len(result.Output) > 0) -// -// // verify querying this node's result for a particular test run -// result = r.getNodeExeHistory(id, exeID, "178876", nil) -// assert.True(t, len(result.Output) > 0) -// -// mockey.PatchConvey("test node debug with batch mode", func() { -// exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`})) -// e = r.getProcess(id, exeID) -// e.assertSuccess() -// assert.Equal(t, map[string]any{ -// "outputList": []any{ -// map[string]any{ -// "input": "output_1", -// "output": []any{"output_1"}, -// }, -// map[string]any{ -// "input": "output_2", -// "output": []any{"output_2"}, -// }, -// }, -// }, mustUnmarshalToMap(t, e.output)) -// -// // verify querying this node's result for this node debug run -// result := r.getNodeExeHistory(id, exeID, "178876", nil) -// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output)) -// -// // verify querying this node's has succeeded any node debug run -// result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) -// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output)) -// }) -// }) -//} + outerModel := &testutil.UTChatModel{ + StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { + if index == 0 { + return schema.StreamReaderFromArray([]*schema.Message{ + { + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "ts_test_wf_test_wf", + Arguments: `{"input": "what's your name and age"}`, + }, + }, + }, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 6, + CompletionTokens: 7, + TotalTokens: 13, + }, + }, + }, + }), nil + } else if index == 1 { + return schema.StreamReaderFromArray([]*schema.Message{ + { + Role: schema.Assistant, + Content: "I now know your ", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 5, + CompletionTokens: 8, + TotalTokens: 13, + }, + }, + }, + { + Role: schema.Assistant, + Content: "name is Eino and age is 1.", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + CompletionTokens: 10, + TotalTokens: 17, + }, + }, + }, + }), nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + + innerModel := &testutil.UTChatModel{ + InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { + if index == 0 { + return &schema.Message{ + Role: schema.Assistant, + Content: `{"question": "what's your age?"}`, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 6, + CompletionTokens: 7, + TotalTokens: 13, + }, + }, + }, nil + } else if index == 1 { + return &schema.Message{ + Role: schema.Assistant, + Content: `{"fields": {"name": "eino", "age": 1}}`, + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 8, + CompletionTokens: 10, + TotalTokens: 18, + }, + }, + }, nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) { + if params.ModelType == 1706077827 { + outerModel.ModelType = strconv.FormatInt(params.ModelType, 10) + return outerModel, nil, nil + } else { + innerModel.ModelType = strconv.FormatInt(params.ModelType, 10) + return innerModel, nil, nil + } + }).AnyTimes() + + r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1")) + id := r.load("function_call/llm_workflow_stream_tool_1.json") + + exeID := r.testRun(id, map[string]string{ + "input": "this is the user input", + }) + + e := r.getProcess(id, exeID) + assert.NotNil(t, e.event) + e.tokenEqual(0, 0) + + r.testResume(id, exeID, e.event.ID, "my name is eino") + e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID)) + assert.NotNil(t, e2.event) + e2.tokenEqual(0, 0) + + r.testResume(id, exeID, e2.event.ID, "1 year old") + e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID)) + e3.assertSuccess() + assert.Equal(t, "the name is eino, age is 1", e3.output) + e3.tokenEqual(20, 24) + }) +} + +func TestNodeWithBatchEnabled(t *testing.T) { + mockey.PatchConvey("test node with batch enabled", t, func() { + r := newWfTestRunner(t) + defer r.closeFn() + + r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1")) + + chatModel := &testutil.UTChatModel{ + InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { + if index == 0 { + return &schema.Message{ + Role: schema.Assistant, + Content: "answer。for index 0", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 5, + CompletionTokens: 6, + TotalTokens: 11, + }, + }, + }, nil + } else if index == 1 { + return &schema.Message{ + Role: schema.Assistant, + Content: "answer,for index 1", + ResponseMeta: &schema.ResponseMeta{ + Usage: &schema.TokenUsage{ + PromptTokens: 5, + CompletionTokens: 6, + TotalTokens: 11, + }, + }, + }, nil + } else { + return nil, fmt.Errorf("unexpected index: %d", index) + } + }, + } + r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() + + id := r.load("batch/node_batches.json") + + exeID := r.testRun(id, map[string]string{ + "input": `["first input", "second input"]`, + }) + e := r.getProcess(id, exeID) + e.assertSuccess() + outputMap := mustUnmarshalToMap(t, e.output) + assert.Contains(t, outputMap["output"], map[string]any{ + "output": []any{ + "answer", + "for index 0", + }, + "input": "answer。for index 0", + }) + assert.Contains(t, outputMap["output"], map[string]any{ + "output": []any{ + "answer", + "for index 1", + }, + "input": "answer,for index 1", + }) + assert.Equal(t, 2, len(outputMap["output"].([]any))) + e.tokenEqual(10, 12) + + // verify this workflow has previously succeeded a test run + result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) + assert.True(t, len(result.Output) > 0) + + // verify querying this node's result for a particular test run + result = r.getNodeExeHistory(id, exeID, "178876", nil) + assert.True(t, len(result.Output) > 0) + + mockey.PatchConvey("test node debug with batch mode", func() { + exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`})) + e = r.getProcess(id, exeID) + e.assertSuccess() + assert.Equal(t, map[string]any{ + "outputList": []any{ + map[string]any{ + "input": "output_1", + "output": []any{"output_1"}, + }, + map[string]any{ + "input": "output_2", + "output": []any{"output_2"}, + }, + }, + }, mustUnmarshalToMap(t, e.output)) + + // verify querying this node's result for this node debug run + result := r.getNodeExeHistory(id, exeID, "178876", nil) + assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output)) + + // verify querying this node's has succeeded any node debug run + result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) + assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output)) + }) + }) +} func TestStartNodeDefaultValues(t *testing.T) { mockey.PatchConvey("default values", t, func() { diff --git a/backend/crossdomain/contract/workflow/workflow.go b/backend/crossdomain/contract/workflow/workflow.go index e15a0c30..d3020aee 100644 --- a/backend/crossdomain/contract/workflow/workflow.go +++ b/backend/crossdomain/contract/workflow/workflow.go @@ -40,7 +40,7 @@ type Workflow interface { GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option - WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) + WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) } type ExecuteConfig = workflowModel.ExecuteConfig diff --git a/backend/crossdomain/impl/workflow/workflow.go b/backend/crossdomain/impl/workflow/workflow.go index 9a417d61..bcb4b1fc 100644 --- a/backend/crossdomain/impl/workflow/workflow.go +++ b/backend/crossdomain/impl/workflow/workflow.go @@ -66,7 +66,7 @@ func (i *impl) WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Op return i.DomainSVC.WithExecuteConfig(cfg) } -func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { +func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) { return i.DomainSVC.WithMessagePipe() } diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go index a496825a..689ad5a3 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go @@ -74,6 +74,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( var composeOpts []compose.Option var pipeMsgOpt compose.Option var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage] + var workflowMsgCloser func() if r.containWfTool { cfReq := crossworkflow.ExecuteConfig{ AgentID: &req.Identity.AgentID, @@ -88,7 +89,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( } wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq) composeOpts = append(composeOpts, wfConfig) - pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe() + pipeMsgOpt, workflowMsgSr, workflowMsgCloser = crossworkflow.DefaultSVC().WithMessagePipe() composeOpts = append(composeOpts, pipeMsgOpt) } @@ -120,6 +121,9 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( sw.Send(nil, errors.New("internal server error")) } + if workflowMsgCloser != nil { + workflowMsgCloser() + } sw.Close() }() _, _ = r.runner.Stream(ctx, req, composeOpts...) @@ -136,6 +140,7 @@ func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.Str if swT != nil { swT.Close() } + wfStream.Close() }() for { msg, err := wfStream.Recv() diff --git a/backend/domain/workflow/component_interface.go b/backend/domain/workflow/component_interface.go index 8aa75235..e0f0145c 100644 --- a/backend/domain/workflow/component_interface.go +++ b/backend/domain/workflow/component_interface.go @@ -48,7 +48,7 @@ type Executable interface { type AsTool interface { WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]ToolFromWorkflow, error) - WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) + WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) WithExecuteConfig(cfg workflowModel.ExecuteConfig) compose.Option WithResumeToolWorkflow(resumingEvent *entity.ToolInterruptEvent, resumeData string, allInterruptEvents map[string]*entity.ToolInterruptEvent) compose.Option diff --git a/backend/domain/workflow/entity/message.go b/backend/domain/workflow/entity/message.go index d24669ad..2b6b9080 100644 --- a/backend/domain/workflow/entity/message.go +++ b/backend/domain/workflow/entity/message.go @@ -85,6 +85,7 @@ type ToolResponseInfo struct { FunctionInfo CallID string Response string + Complete bool } type ToolType = workflow.PluginType diff --git a/backend/domain/workflow/internal/compose/designate_option.go b/backend/domain/workflow/internal/compose/designate_option.go index 31d69899..851b1630 100644 --- a/backend/domain/workflow/internal/compose/designate_option.go +++ b/backend/domain/workflow/internal/compose/designate_option.go @@ -23,7 +23,6 @@ import ( "strconv" einoCompose "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" @@ -47,7 +46,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, workflowSC = r.schema eventChan = r.eventChan resumedEvent = r.interruptEvent - sw = r.streamWriter + sw = r.container ) if wb.AppID != nil && exeCfg.AppID == nil { @@ -148,7 +147,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, var ( resumeEvent = r.interruptEvent eventChan = r.eventChan - sw = r.streamWriter + container = r.container ) subHandler := execute.NewSubWorkflowHandler( parentHandler, @@ -186,7 +185,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, opts = append(opts, WrapOpt(subO, ns.Key)) } } else if subNS.Type == entity.NodeTypeLLM { - llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw) + llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container) if err != nil { return nil, err } @@ -209,7 +208,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key)) } } else if subNS.Type == entity.NodeTypeLLM { - llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw) + llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container) if err != nil { return nil, err } @@ -224,7 +223,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, } func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event, - sw *schema.StreamWriter[*entity.Message]) ( + container *execute.StreamContainer) ( opts []einoCompose.Option, err error) { // this is a LLM node. // check if it has any tools, if no tools, then no callback options needed @@ -280,6 +279,12 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key)) opts = append(opts, opt) } + + if container != nil { + toolMsgOpt := llm.WithToolWorkflowStreamContainer(container) + opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key)) + opts = append(opts, opt) + } } if fcParams.PluginFCParam != nil { for _, p := range fcParams.PluginFCParam.PluginList { @@ -321,11 +326,5 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh } } - if sw != nil { - toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw) - opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key)) - opts = append(opts, opt) - } - return opts, nil } diff --git a/backend/domain/workflow/internal/compose/workflow_run.go b/backend/domain/workflow/internal/compose/workflow_run.go index 919f25a6..082043fc 100644 --- a/backend/domain/workflow/internal/compose/workflow_run.go +++ b/backend/domain/workflow/internal/compose/workflow_run.go @@ -38,15 +38,17 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/safego" + "github.com/coze-dev/coze-studio/backend/types/errno" ) type WorkflowRunner struct { - basic *entity.WorkflowBasic - input string - resumeReq *entity.ResumeRequest - schema *schema2.WorkflowSchema - streamWriter *schema.StreamWriter[*entity.Message] - config model.ExecuteConfig + basic *entity.WorkflowBasic + input string + resumeReq *entity.ResumeRequest + schema *schema2.WorkflowSchema + sw *schema.StreamWriter[*entity.Message] + container *execute.StreamContainer + config model.ExecuteConfig executeID int64 eventChan chan *execute.Event @@ -84,13 +86,19 @@ func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, conf opt(options) } + var container *execute.StreamContainer + if options.streamWriter != nil { + container = execute.NewStreamContainer(options.streamWriter) + } + return &WorkflowRunner{ - basic: b, - input: options.input, - resumeReq: options.resumeReq, - schema: sc, - streamWriter: options.streamWriter, - config: config, + basic: b, + input: options.input, + resumeReq: options.resumeReq, + schema: sc, + sw: options.streamWriter, + container: container, + config: config, } } @@ -108,14 +116,16 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) ( resumeReq = r.resumeReq wb = r.basic sc = r.schema - sw = r.streamWriter + sw = r.sw + container = r.container config = r.config ) if r.resumeReq == nil { executeID, err = repo.GenID(ctx) if err != nil { - return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err) + return ctx, 0, nil, nil, vo.WrapError(errno.ErrIDGenError, + fmt.Errorf("failed to generate workflow execute ID: %w", err)) } } else { executeID = resumeReq.ExecuteID @@ -148,6 +158,15 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) ( r.eventChan = eventChan r.interruptEvent = interruptEvent + if container != nil { + go container.PipeAll() + defer func() { + if err != nil { + container.Done() + } + }() + } + ctx, composeOpts, err := r.designateOptions(ctx) if err != nil { return ctx, 0, nil, nil, err @@ -277,8 +296,8 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) ( } }() defer func() { - if sw != nil { - sw.Close() + if container != nil { + container.Done() } }() diff --git a/backend/domain/workflow/internal/compose/workflow_tool.go b/backend/domain/workflow/internal/compose/workflow_tool.go index ee8b65d7..26cd0c0d 100644 --- a/backend/domain/workflow/internal/compose/workflow_tool.go +++ b/backend/domain/workflow/internal/compose/workflow_tool.go @@ -33,17 +33,22 @@ import ( schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/sonic" + "github.com/coze-dev/coze-studio/backend/types/errno" ) const answerKey = "output" type invokableWorkflow struct { + workflowTool + invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error) +} + +type workflowTool struct { info *schema.ToolInfo - invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error) - terminatePlan vo.TerminatePlan wfEntity *entity.Workflow sc *schema2.WorkflowSchema repo wf.Repository + terminatePlan vo.TerminatePlan } func NewInvokableWorkflow(info *schema.ToolInfo, @@ -54,12 +59,14 @@ func NewInvokableWorkflow(info *schema.ToolInfo, repo wf.Repository, ) wf.ToolFromWorkflow { return &invokableWorkflow{ - info: info, - invoke: invoke, - terminatePlan: terminatePlan, - wfEntity: wfEntity, - sc: sc, - repo: repo, + workflowTool: workflowTool{ + info: info, + wfEntity: wfEntity, + sc: sc, + repo: repo, + terminatePlan: terminatePlan, + }, + invoke: invoke, } } @@ -77,6 +84,52 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e } } +func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) ( + cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) { + cfg := execute.GetExecuteConfig(opts...) + + var runOpts []WorkflowRunnerOption + if rInfo != nil && !rInfo.Resumed { + runOpts = append(runOpts, WithResumeReq(rInfo)) + } else { + runOpts = append(runOpts, WithInput(argumentsInJSON)) + } + if container := execute.GetParentStreamContainer(opts...); container != nil { + sr, sw := schema.Pipe[*entity.Message](10) + container.AddChild(sr) + runOpts = append(runOpts, WithStreamWriter(sw)) + } + + var ws *nodes.ConversionWarnings + + if (rInfo == nil || rInfo.Resumed) && len(wt.wfEntity.InputParams) > 0 { + if err = sonic.UnmarshalString(argumentsInJSON, &input); err != nil { + err = vo.WrapError(errno.ErrSerializationDeserializationFail, err) + return + } + + var entryNode *schema2.NodeSchema + for _, node := range wt.sc.Nodes { + if node.Type == entity.NodeTypeEntry { + entryNode = node + break + } + } + if entryNode == nil { + panic("entry node not found in tool workflow") + } + input, ws, err = nodes.ConvertInputs(ctx, input, entryNode.OutputTypes) + if err != nil { + return + } else if ws != nil { + logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws) + } + } + + cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx) + return +} + func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { rInfo, allIEs := execute.GetResumeRequest(opts...) var ( @@ -97,52 +150,9 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st return "", einoCompose.InterruptAndRerun } - cfg := execute.GetExecuteConfig(opts...) defer resumeOnce(rInfo, callID, allIEs) - var runOpts []WorkflowRunnerOption - if rInfo != nil && !rInfo.Resumed { - runOpts = append(runOpts, WithResumeReq(rInfo)) - } else { - runOpts = append(runOpts, WithInput(argumentsInJSON)) - } - if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil { - runOpts = append(runOpts, WithStreamWriter(sw)) - } - - var ( - cancelCtx context.Context - executeID int64 - callOpts []einoCompose.Option - in map[string]any - err error - ws *nodes.ConversionWarnings - ) - - if rInfo == nil && len(i.wfEntity.InputParams) > 0 { - if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil { - return "", err - } - - var entryNode *schema2.NodeSchema - for _, node := range i.sc.Nodes { - if node.Type == entity.NodeTypeEntry { - entryNode = node - break - } - } - if entryNode == nil { - panic("entry node not found in tool workflow") - } - in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes) - if err != nil { - return "", err - } else if ws != nil { - logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws) - } - } - - cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx) + cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...) if err != nil { return "", err } @@ -198,12 +208,8 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow { } type streamableWorkflow struct { - info *schema.ToolInfo - stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) - terminatePlan vo.TerminatePlan - wfEntity *entity.Workflow - sc *schema2.WorkflowSchema - repo wf.Repository + workflowTool + stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) } func NewStreamableWorkflow(info *schema.ToolInfo, @@ -214,12 +220,14 @@ func NewStreamableWorkflow(info *schema.ToolInfo, repo wf.Repository, ) wf.ToolFromWorkflow { return &streamableWorkflow{ - info: info, - stream: stream, - terminatePlan: terminatePlan, - wfEntity: wfEntity, - sc: sc, - repo: repo, + workflowTool: workflowTool{ + info: info, + wfEntity: wfEntity, + sc: sc, + repo: repo, + terminatePlan: terminatePlan, + }, + stream: stream, } } @@ -247,52 +255,9 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON return nil, einoCompose.InterruptAndRerun } - cfg := execute.GetExecuteConfig(opts...) defer resumeOnce(rInfo, callID, allIEs) - var runOpts []WorkflowRunnerOption - if rInfo != nil && !rInfo.Resumed { - runOpts = append(runOpts, WithResumeReq(rInfo)) - } else { - runOpts = append(runOpts, WithInput(argumentsInJSON)) - } - if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil { - runOpts = append(runOpts, WithStreamWriter(sw)) - } - - var ( - cancelCtx context.Context - executeID int64 - callOpts []einoCompose.Option - in map[string]any - err error - ws *nodes.ConversionWarnings - ) - - if rInfo == nil && len(s.wfEntity.InputParams) > 0 { - if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil { - return nil, err - } - - var entryNode *schema2.NodeSchema - for _, node := range s.sc.Nodes { - if node.Type == entity.NodeTypeEntry { - entryNode = node - break - } - } - if entryNode == nil { - panic("entry node not found in tool workflow") - } - in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes) - if err != nil { - return nil, err - } else if ws != nil { - logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws) - } - } - - cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx) + cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...) if err != nil { return nil, err } diff --git a/backend/domain/workflow/internal/execute/callback.go b/backend/domain/workflow/internal/execute/callback.go index e017ea5c..a5c3696d 100644 --- a/backend/domain/workflow/internal/execute/callback.go +++ b/backend/domain/workflow/internal/execute/callback.go @@ -370,6 +370,10 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo, interruptEvent.EventType, interruptEvent.NodeKey) } + if c.TokenCollector != nil { // wait until all streaming chunks are collected + _ = c.TokenCollector.wait() + } + done := make(chan struct{}) w.ch <- &Event{ @@ -1309,6 +1313,7 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, FunctionInfo: t.info, CallID: compose.GetToolCallID(ctx), Response: output.Response, + Complete: true, }, } @@ -1347,6 +1352,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks toolResponse: &entity.ToolResponseInfo{ FunctionInfo: t.info, CallID: callID, + Complete: true, }, } } diff --git a/backend/domain/workflow/internal/execute/collect_token.go b/backend/domain/workflow/internal/execute/collect_token.go index 80023fcd..8fc7aed2 100644 --- a/backend/domain/workflow/internal/execute/collect_token.go +++ b/backend/domain/workflow/internal/execute/collect_token.go @@ -76,6 +76,20 @@ func (t *TokenCollector) add(i int) { return } +func (t *TokenCollector) startStreamCounting() { + t.wg.Add(1) + if t.Parent != nil { + t.Parent.startStreamCounting() + } +} + +func (t *TokenCollector) finishStreamCounting() { + t.wg.Done() + if t.Parent != nil { + t.Parent.finishStreamCounting() + } +} + func getTokenCollector(ctx context.Context) *TokenCollector { c := GetExeCtx(ctx) if c == nil { @@ -92,7 +106,6 @@ func GetTokenCallbackHandler() callbacks.Handler { return ctx } c.add(1) - //c.wg.Add(1) return ctx }, OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { @@ -114,6 +127,7 @@ func GetTokenCallbackHandler() callbacks.Handler { output.Close() return ctx } + c.startStreamCounting() safego.Go(ctx, func() { defer func() { output.Close() @@ -141,6 +155,7 @@ func GetTokenCallbackHandler() callbacks.Handler { if newC.TotalTokens > 0 { c.addTokenUsage(newC) } + c.finishStreamCounting() }) return ctx }, diff --git a/backend/domain/workflow/internal/execute/event_handle.go b/backend/domain/workflow/internal/execute/event_handle.go index 88d125dd..b6cfcc8e 100644 --- a/backend/domain/workflow/internal/execute/event_handle.go +++ b/backend/domain/workflow/internal/execute/event_handle.go @@ -789,6 +789,7 @@ func HandleExecuteEvent(ctx context.Context, logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d", event.Type, event.Context.RootWorkflowBasic.ID) cancelTicker.Stop() // Clean up timer + waitUntilToolFinish(ctx) if timeoutFn != nil { timeoutFn() } @@ -880,6 +881,7 @@ func cacheToolStreamingResponse(ctx context.Context, event *Event) { c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse } c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response + c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete } func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo { @@ -887,6 +889,35 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo { return c[nodeKey] } +func waitUntilToolFinish(ctx context.Context) { + var cnt int +outer: + for { + if cnt > 1000 { + return + } + + c := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo) + if len(c) == 0 { + return + } + + for _, m := range c { + for _, info := range m { + if info.output == nil { + cnt++ + continue outer + } + + if !info.output.Complete { + cnt++ + continue outer + } + } + } + } +} + func (f *fcInfo) inputString() string { if f.input == nil { return "" diff --git a/backend/domain/workflow/internal/execute/stream_container.go b/backend/domain/workflow/internal/execute/stream_container.go new file mode 100644 index 00000000..02162331 --- /dev/null +++ b/backend/domain/workflow/internal/execute/stream_container.go @@ -0,0 +1,74 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package execute + +import ( + "errors" + "io" + "sync" + + "github.com/cloudwego/eino/schema" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" +) + +type StreamContainer struct { + sw *schema.StreamWriter[*entity.Message] + subStreams chan *schema.StreamReader[*entity.Message] + wg sync.WaitGroup +} + +func NewStreamContainer(sw *schema.StreamWriter[*entity.Message]) *StreamContainer { + return &StreamContainer{ + sw: sw, + subStreams: make(chan *schema.StreamReader[*entity.Message]), + } +} + +func (sc *StreamContainer) AddChild(sr *schema.StreamReader[*entity.Message]) { + sc.wg.Add(1) + sc.subStreams <- sr +} + +func (sc *StreamContainer) PipeAll() { + sc.wg.Add(1) + + for sr := range sc.subStreams { + go func() { + defer sr.Close() + + for { + msg, err := sr.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + sc.wg.Done() + return + } + } + + sc.sw.Send(msg, err) + } + }() + } +} + +func (sc *StreamContainer) Done() { + sc.wg.Done() + sc.wg.Wait() + close(sc.subStreams) + sc.sw.Close() +} diff --git a/backend/domain/workflow/internal/execute/tool_option.go b/backend/domain/workflow/internal/execute/tool_option.go index 5bff5c00..05e4cebb 100644 --- a/backend/domain/workflow/internal/execute/tool_option.go +++ b/backend/domain/workflow/internal/execute/tool_option.go @@ -27,7 +27,7 @@ import ( type workflowToolOption struct { resumeReq *entity.ResumeRequest - sw *schema.StreamWriter[*entity.Message] + streamContainer *StreamContainer exeCfg workflowModel.ExecuteConfig allInterruptEvents map[string]*entity.ToolInterruptEvent parentTokenCollector *TokenCollector @@ -40,9 +40,9 @@ func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptE }) } -func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option { +func WithParentStreamContainer(sc *StreamContainer) tool.Option { return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) { - opts.sw = sw + opts.streamContainer = sc }) } @@ -57,9 +57,9 @@ func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*e return opt.resumeReq, opt.allInterruptEvents } -func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] { +func GetParentStreamContainer(opts ...tool.Option) *StreamContainer { opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...) - return opt.sw + return opt.streamContainer } func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig { @@ -67,11 +67,22 @@ func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig { return opt.exeCfg } -// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow. -// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream. +// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, +// as well as a StreamReader to read the messages from the tool workflow. +// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. +// The workflow tools will emit messages to this stream. // The caller can receive from the returned StreamReader to get the messages from the tool workflow. -func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { +func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) { sr, sw := schema.Pipe[*entity.Message](10) - opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw))) - return opt, sr + container := &StreamContainer{ + sw: sw, + subStreams: make(chan *schema.StreamReader[*entity.Message]), + } + + go container.PipeAll() + + opt := compose.WithToolsNodeOption(compose.WithToolOption(WithParentStreamContainer(container))) + return opt, sr, func() { + container.Done() + } } diff --git a/backend/domain/workflow/internal/nodes/batch/batch.go b/backend/domain/workflow/internal/nodes/batch/batch.go index 61c422b4..6f8427d3 100644 --- a/backend/domain/workflow/internal/nodes/batch/batch.go +++ b/backend/domain/workflow/internal/nodes/batch/batch.go @@ -446,9 +446,6 @@ func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.Nod return nil, err } - fmt.Println("save interruptEvent in state within batch: ", iEvent) - fmt.Println("save composite info in state within batch: ", compState) - return nil, compose.InterruptAndRerun } else { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { diff --git a/backend/domain/workflow/internal/nodes/convert.go b/backend/domain/workflow/internal/nodes/convert.go index a1915f95..221acae4 100644 --- a/backend/domain/workflow/internal/nodes/convert.go +++ b/backend/domain/workflow/internal/nodes/convert.go @@ -92,7 +92,6 @@ func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo. t, ok := tInfo[k] if !ok { // for input fields not explicitly defined, just pass the string value through - logs.CtxWarnf(ctx, "input %s not found in type info", k) if !options.skipUnknownFields { out[k] = in[k] } @@ -323,7 +322,6 @@ func convertToObject(ctx context.Context, in any, path string, t *vo.TypeInfo, o propType, ok := t.Properties[k] if !ok { // for input fields not explicitly defined, just pass the value through - logs.CtxWarnf(ctx, "input %s.%s not found in type info", path, k) if !options.skipUnknownFields { out[k] = v } diff --git a/backend/domain/workflow/internal/nodes/llm/llm.go b/backend/domain/workflow/internal/nodes/llm/llm.go index bf777335..58a6e762 100644 --- a/backend/domain/workflow/internal/nodes/llm/llm.go +++ b/backend/domain/workflow/internal/nodes/llm/llm.go @@ -143,6 +143,7 @@ const ( knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix" templateNodeKey = "template" llmNodeKey = "llm" + reactGraphName = "workflow_llm_react_agent" outputConvertNodeKey = "output_convert" ) @@ -620,6 +621,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2 ToolCallingModel: m, ToolsConfig: compose.ToolsNodeConfig{Tools: tools}, ModelNodeName: agentModelName, + GraphName: reactGraphName, } if len(toolsReturnDirectly) > 0 { @@ -635,7 +637,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2 } agentNode, opts := reactAgent.ExportGraph() - opts = append(opts, compose.WithNodeName("workflow_llm_react_agent")) + opts = append(opts, compose.WithNodeName(reactGraphName)) _ = g.AddGraphNode(llmNodeKey, agentNode, opts...) } else { _ = g.AddChatModelNode(llmNodeKey, modelWithInfo) @@ -867,12 +869,12 @@ func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo } type llmOptions struct { - toolWorkflowSW *schema.StreamWriter[*entity.Message] + toolWorkflowContainer *execute.StreamContainer } -func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption { +func WithToolWorkflowStreamContainer(container *execute.StreamContainer) nodes.NodeOption { return nodes.WrapImplSpecificOptFn(func(o *llmOptions) { - o.toolWorkflowSW = sw + o.toolWorkflowContainer = container }) } @@ -880,7 +882,8 @@ type llmState = map[string]any const agentModelName = "agent_model" -func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { +func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) ( + composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { c := execute.GetExeCtx(ctx) if c != nil { resumingEvent = c.NodeCtx.ResumingEvent @@ -890,7 +893,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO if c != nil && c.RootCtx.ResumeEvent != nil { // check if we are not resuming, but previously interrupted. Interrupt immediately. if resumingEvent == nil { - err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error { + err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error { var e error previousToolES, e = state.GetToolInterruptEvents(c.NodeKey) if e != nil { @@ -899,11 +902,12 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO return nil }) if err != nil { - return nil, nil, err + return } if len(previousToolES) > 0 { - return nil, nil, compose.InterruptAndRerun + err = compose.InterruptAndRerun + return } } } @@ -936,7 +940,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO return e }) if err != nil { - return nil, nil, err + return } composeOpts = append(composeOpts, compose.WithToolsNodeOption( compose.WithToolOption( @@ -986,27 +990,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO } llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...) - if llmOpts.toolWorkflowSW != nil { - toolMsgOpt, toolMsgSR := execute.WithMessagePipe() - composeOpts = append(composeOpts, toolMsgOpt) - - safego.Go(ctx, func() { - defer toolMsgSR.Close() - for { - msg, err := toolMsgSR.Recv() - if err != nil { - if err == io.EOF { - return - } - logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err) - return - } - - logs.Infof("received message from tool workflow: %+v", msg) - - llmOpts.toolWorkflowSW.Send(msg, nil) - } - }) + if container := llmOpts.toolWorkflowContainer; container != nil { + composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption( + execute.WithParentStreamContainer(container)))) } resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources) diff --git a/backend/domain/workflow/service/as_tool_impl.go b/backend/domain/workflow/service/as_tool_impl.go index 67b4e39c..00c79bd7 100644 --- a/backend/domain/workflow/service/as_tool_impl.go +++ b/backend/domain/workflow/service/as_tool_impl.go @@ -33,7 +33,7 @@ type asToolImpl struct { repo workflow.Repository } -func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message]) { +func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message], func()) { return execute.WithMessagePipe() } diff --git a/backend/internal/mock/domain/workflow/interface.go b/backend/internal/mock/domain/workflow/interface.go index d1cdc950..224a4cd2 100644 --- a/backend/internal/mock/domain/workflow/interface.go +++ b/backend/internal/mock/domain/workflow/interface.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 coze-dev Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + // Code generated by MockGen. DO NOT EDIT. // Source: interface.go // @@ -484,12 +500,13 @@ func (mr *MockServiceMockRecorder) WithExecuteConfig(cfg any) *gomock.Call { } // WithMessagePipe mocks base method. -func (m *MockService) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { +func (m *MockService) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WithMessagePipe") ret0, _ := ret[0].(compose.Option) ret1, _ := ret[1].(*schema.StreamReader[*entity.Message]) - return ret0, ret1 + ret2, _ := ret[2].(func()) + return ret0, ret1, ret2 } // WithMessagePipe indicates an expected call of WithMessagePipe.