fix: workflow tool closes stream writer correctly (#1839)

This commit is contained in:
shentongmartin 2025-08-27 16:29:42 +08:00 committed by GitHub
parent 263a75b1c0
commit 5562800958
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 742 additions and 620 deletions

View File

@ -1815,212 +1815,212 @@ func TestUpdateWorkflowMeta(t *testing.T) {
}) })
} }
//func TestSimpleInvokableToolWithReturnVariables(t *testing.T) { func TestSimpleInvokableToolWithReturnVariables(t *testing.T) {
// mockey.PatchConvey("simple invokable tool with return variables", t, func() { mockey.PatchConvey("simple invokable tool with return variables", t, func() {
// r := newWfTestRunner(t) r := newWfTestRunner(t)
// defer r.closeFn() defer r.closeFn()
//
// toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1"))
//
// chatModel := &testutil.UTChatModel{
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
// if index == 0 {
// return &schema.Message{
// Role: schema.Assistant,
// ToolCalls: []schema.ToolCall{
// {
// ID: "1",
// Function: schema.FunctionCall{
// Name: "ts_test_wf_test_wf",
// Arguments: "{}",
// },
// },
// },
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 10,
// CompletionTokens: 11,
// TotalTokens: 21,
// },
// },
// }, nil
// } else if index == 1 {
// return &schema.Message{
// Role: schema.Assistant,
// Content: "final_answer",
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 5,
// CompletionTokens: 6,
// TotalTokens: 11,
// },
// },
// }, nil
// } else {
// return nil, fmt.Errorf("unexpected index: %d", index)
// }
// },
// }
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
//
// id := r.load("function_call/llm_with_workflow_as_tool.json")
// defer func() {
// post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
// WorkflowID: id,
// })
// }()
//
// exeID := r.testRun(id, map[string]string{
// "input": "this is the user input",
// })
//
// e := r.getProcess(id, exeID)
// e.assertSuccess()
// assert.Equal(t, map[string]any{
// "output": "final_answer",
// }, mustUnmarshalToMap(t, e.output))
// e.tokenEqual(15, 17)
//
// mockey.PatchConvey("check behavior if stream run", func() {
// chatModel.Reset()
//
// defer r.runServer()()
//
// r.publish(id, "v0.0.1", true)
//
// sseReader := r.openapiStream(id, map[string]any{
// "input": "hello",
// })
// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
// return nil
// })
// assert.NoError(t, err)
//
// // check workflow references are correct
// refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{
// WorkflowID: toolID,
// })
// assert.Equal(t, 1, len(refs.Data.WorkflowList))
// assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID)
// })
// })
//}
//func TestReturnDirectlyStreamableTool(t *testing.T) { toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1"))
// mockey.PatchConvey("return directly streamable tool", t, func() {
// r := newWfTestRunner(t) chatModel := &testutil.UTChatModel{
// defer r.closeFn() InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
// if index == 0 {
// outerModel := &testutil.UTChatModel{ return &schema.Message{
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { Role: schema.Assistant,
// if index == 0 { ToolCalls: []schema.ToolCall{
// return schema.StreamReaderFromArray([]*schema.Message{ {
// { ID: "1",
// Role: schema.Assistant, Function: schema.FunctionCall{
// ToolCalls: []schema.ToolCall{ Name: "ts_test_wf_test_wf",
// { Arguments: "{}",
// ID: "1", },
// Function: schema.FunctionCall{ },
// Name: "ts_test_wf_test_wf", },
// Arguments: `{"input": "input for inner model"}`, ResponseMeta: &schema.ResponseMeta{
// }, Usage: &schema.TokenUsage{
// }, PromptTokens: 10,
// }, CompletionTokens: 11,
// ResponseMeta: &schema.ResponseMeta{ TotalTokens: 21,
// Usage: &schema.TokenUsage{ },
// PromptTokens: 10, },
// CompletionTokens: 11, }, nil
// TotalTokens: 21, } else if index == 1 {
// }, return &schema.Message{
// }, Role: schema.Assistant,
// }, Content: "final_answer",
// }), nil ResponseMeta: &schema.ResponseMeta{
// } else { Usage: &schema.TokenUsage{
// return nil, fmt.Errorf("unexpected index: %d", index) PromptTokens: 5,
// } CompletionTokens: 6,
// }, TotalTokens: 11,
// } },
// },
// innerModel := &testutil.UTChatModel{ }, nil
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) { } else {
// if index == 0 { return nil, fmt.Errorf("unexpected index: %d", index)
// return schema.StreamReaderFromArray([]*schema.Message{ }
// { },
// Role: schema.Assistant, }
// Content: "I ", r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{ id := r.load("function_call/llm_with_workflow_as_tool.json")
// PromptTokens: 5, defer func() {
// CompletionTokens: 6, post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
// TotalTokens: 11, WorkflowID: id,
// }, })
// }, }()
// },
// { exeID := r.testRun(id, map[string]string{
// Role: schema.Assistant, "input": "this is the user input",
// Content: "don't know", })
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{ e := r.getProcess(id, exeID)
// CompletionTokens: 8, e.assertSuccess()
// TotalTokens: 8, assert.Equal(t, map[string]any{
// }, "output": "final_answer",
// }, }, mustUnmarshalToMap(t, e.output))
// }, e.tokenEqual(15, 17)
// {
// Role: schema.Assistant, mockey.PatchConvey("check behavior if stream run", func() {
// Content: ".", chatModel.Reset()
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{ defer r.runServer()()
// CompletionTokens: 2,
// TotalTokens: 2, r.publish(id, "v0.0.1", true)
// },
// }, sseReader := r.openapiStream(id, map[string]any{
// }, "input": "hello",
// }), nil })
// } else { err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
// return nil, fmt.Errorf("unexpected index: %d", index) t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
// } return nil
// }, })
// } assert.NoError(t, err)
//
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) { // check workflow references are correct
// if params.ModelType == 1706077826 { refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{
// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10) WorkflowID: toolID,
// return innerModel, nil, nil })
// } else { assert.Equal(t, 1, len(refs.Data.WorkflowList))
// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10) assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID)
// return outerModel, nil, nil })
// } })
// }).AnyTimes() }
//
// r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1")) func TestReturnDirectlyStreamableTool(t *testing.T) {
// id := r.load("function_call/llm_workflow_stream_tool.json") mockey.PatchConvey("return directly streamable tool", t, func() {
// r := newWfTestRunner(t)
// exeID := r.testRun(id, map[string]string{ defer r.closeFn()
// "input": "this is the user input",
// }) outerModel := &testutil.UTChatModel{
// e := r.getProcess(id, exeID) StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
// e.assertSuccess() if index == 0 {
// assert.Equal(t, "this is the streaming output I don't know.", e.output) return schema.StreamReaderFromArray([]*schema.Message{
// e.tokenEqual(15, 27) {
// Role: schema.Assistant,
// mockey.PatchConvey("check behavior if stream run", func() { ToolCalls: []schema.ToolCall{
// outerModel.Reset() {
// innerModel.Reset() ID: "1",
// defer r.runServer()() Function: schema.FunctionCall{
// r.publish(id, "v0.0.1", true) Name: "ts_test_wf_test_wf",
// sseReader := r.openapiStream(id, map[string]any{ Arguments: `{"input": "input for inner model"}`,
// "input": "hello", },
// }) },
// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error { },
// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data)) ResponseMeta: &schema.ResponseMeta{
// return nil Usage: &schema.TokenUsage{
// }) PromptTokens: 10,
// assert.NoError(t, err) CompletionTokens: 11,
// }) TotalTokens: 21,
// }) },
//} },
},
}), nil
} else {
return nil, fmt.Errorf("unexpected index: %d", index)
}
},
}
innerModel := &testutil.UTChatModel{
StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
if index == 0 {
return schema.StreamReaderFromArray([]*schema.Message{
{
Role: schema.Assistant,
Content: "I ",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 5,
CompletionTokens: 6,
TotalTokens: 11,
},
},
},
{
Role: schema.Assistant,
Content: "don't know",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
CompletionTokens: 8,
TotalTokens: 8,
},
},
},
{
Role: schema.Assistant,
Content: ".",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
CompletionTokens: 2,
TotalTokens: 2,
},
},
},
}), nil
} else {
return nil, fmt.Errorf("unexpected index: %d", index)
}
},
}
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
if params.ModelType == 1706077826 {
innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
return innerModel, nil, nil
} else {
outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
return outerModel, nil, nil
}
}).AnyTimes()
r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1"))
id := r.load("function_call/llm_workflow_stream_tool.json")
exeID := r.testRun(id, map[string]string{
"input": "this is the user input",
})
e := r.getProcess(id, exeID)
e.assertSuccess()
assert.Equal(t, "this is the streaming output I don't know.", e.output)
e.tokenEqual(15, 27)
mockey.PatchConvey("check behavior if stream run", func() {
outerModel.Reset()
innerModel.Reset()
defer r.runServer()()
r.publish(id, "v0.0.1", true)
sseReader := r.openapiStream(id, map[string]any{
"input": "hello",
})
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
return nil
})
assert.NoError(t, err)
})
})
}
func TestSimpleInterruptibleTool(t *testing.T) { func TestSimpleInterruptibleTool(t *testing.T) {
mockey.PatchConvey("test simple interruptible tool", t, func() { mockey.PatchConvey("test simple interruptible tool", t, func() {
@ -2082,233 +2082,231 @@ func TestSimpleInterruptibleTool(t *testing.T) {
}) })
} }
//func TestStreamableToolWithMultipleInterrupts(t *testing.T) { func TestStreamableToolWithMultipleInterrupts(t *testing.T) {
// mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() { mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() {
// r := newWfTestRunner(t) r := newWfTestRunner(t)
// defer r.closeFn() defer r.closeFn()
//
// outerModel := &testutil.UTChatModel{
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
// if index == 0 {
// return schema.StreamReaderFromArray([]*schema.Message{
// {
// Role: schema.Assistant,
// ToolCalls: []schema.ToolCall{
// {
// ID: "1",
// Function: schema.FunctionCall{
// Name: "ts_test_wf_test_wf",
// Arguments: `{"input": "what's your name and age"}`,
// },
// },
// },
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 6,
// CompletionTokens: 7,
// TotalTokens: 13,
// },
// },
// },
// }), nil
// } else if index == 1 {
// return schema.StreamReaderFromArray([]*schema.Message{
// {
// Role: schema.Assistant,
// Content: "I now know your ",
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 5,
// CompletionTokens: 8,
// TotalTokens: 13,
// },
// },
// },
// {
// Role: schema.Assistant,
// Content: "name is Eino and age is 1.",
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// CompletionTokens: 10,
// TotalTokens: 17,
// },
// },
// },
// }), nil
// } else {
// return nil, fmt.Errorf("unexpected index: %d", index)
// }
// },
// }
//
// innerModel := &testutil.UTChatModel{
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
// if index == 0 {
// return &schema.Message{
// Role: schema.Assistant,
// Content: `{"question": "what's your age?"}`,
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 6,
// CompletionTokens: 7,
// TotalTokens: 13,
// },
// },
// }, nil
// } else if index == 1 {
// return &schema.Message{
// Role: schema.Assistant,
// Content: `{"fields": {"name": "eino", "age": 1}}`,
// ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// PromptTokens: 8,
// CompletionTokens: 10,
// TotalTokens: 18,
// },
// },
// }, nil
// } else {
// return nil, fmt.Errorf("unexpected index: %d", index)
// }
// },
// }
//
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
// if params.ModelType == 1706077827 {
// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
// return outerModel, nil, nil
// } else {
// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
// return innerModel, nil, nil
// }
// }).AnyTimes()
//
// r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1"))
// id := r.load("function_call/llm_workflow_stream_tool_1.json")
//
// exeID := r.testRun(id, map[string]string{
// "input": "this is the user input",
// })
//
// e := r.getProcess(id, exeID)
// assert.NotNil(t, e.event)
// e.tokenEqual(0, 0)
//
// r.testResume(id, exeID, e.event.ID, "my name is eino")
// e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
// assert.NotNil(t, e2.event)
// e2.tokenEqual(0, 0)
//
// r.testResume(id, exeID, e2.event.ID, "1 year old")
// e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
// e3.assertSuccess()
// assert.Equal(t, "the name is eino, age is 1", e3.output)
// e3.tokenEqual(20, 24)
// })
//}
//func TestNodeWithBatchEnabled(t *testing.T) { outerModel := &testutil.UTChatModel{
// mockey.PatchConvey("test node with batch enabled", t, func() { StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
// r := newWfTestRunner(t) if index == 0 {
// defer r.closeFn() return schema.StreamReaderFromArray([]*schema.Message{
// {
// r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1")) Role: schema.Assistant,
// ToolCalls: []schema.ToolCall{
// chatModel := &testutil.UTChatModel{ {
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { ID: "1",
// if index == 0 { Function: schema.FunctionCall{
// return &schema.Message{ Name: "ts_test_wf_test_wf",
// Role: schema.Assistant, Arguments: `{"input": "what's your name and age"}`,
// Content: "answer。for index 0", },
// ResponseMeta: &schema.ResponseMeta{ },
// Usage: &schema.TokenUsage{ },
// PromptTokens: 5, ResponseMeta: &schema.ResponseMeta{
// CompletionTokens: 6, Usage: &schema.TokenUsage{
// TotalTokens: 11, PromptTokens: 6,
// }, CompletionTokens: 7,
// }, TotalTokens: 13,
// }, nil },
// } else if index == 1 { },
// return &schema.Message{ },
// Role: schema.Assistant, }), nil
// Content: "answerfor index 1", } else if index == 1 {
// ResponseMeta: &schema.ResponseMeta{ return schema.StreamReaderFromArray([]*schema.Message{
// Usage: &schema.TokenUsage{ {
// PromptTokens: 5, Role: schema.Assistant,
// CompletionTokens: 6, Content: "I now know your ",
// TotalTokens: 11, ResponseMeta: &schema.ResponseMeta{
// }, Usage: &schema.TokenUsage{
// }, PromptTokens: 5,
// }, nil CompletionTokens: 8,
// } else { TotalTokens: 13,
// return nil, fmt.Errorf("unexpected index: %d", index) },
// } },
// }, },
// } {
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes() Role: schema.Assistant,
// Content: "name is Eino and age is 1.",
// id := r.load("batch/node_batches.json") ResponseMeta: &schema.ResponseMeta{
// Usage: &schema.TokenUsage{
// exeID := r.testRun(id, map[string]string{ CompletionTokens: 10,
// "input": `["first input", "second input"]`, TotalTokens: 17,
// }) },
// e := r.getProcess(id, exeID) },
// e.assertSuccess() },
// assert.Equal(t, map[string]any{ }), nil
// "output": []any{ } else {
// map[string]any{ return nil, fmt.Errorf("unexpected index: %d", index)
// "output": []any{ }
// "answer", },
// "for index 0", }
// },
// "input": "answer。for index 0", innerModel := &testutil.UTChatModel{
// }, InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
// map[string]any{ if index == 0 {
// "output": []any{ return &schema.Message{
// "answer", Role: schema.Assistant,
// "for index 1", Content: `{"question": "what's your age?"}`,
// }, ResponseMeta: &schema.ResponseMeta{
// "input": "answerfor index 1", Usage: &schema.TokenUsage{
// }, PromptTokens: 6,
// }, CompletionTokens: 7,
// }, mustUnmarshalToMap(t, e.output)) TotalTokens: 13,
// e.tokenEqual(10, 12) },
// },
// // verify this workflow has previously succeeded a test run }, nil
// result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) } else if index == 1 {
// assert.True(t, len(result.Output) > 0) return &schema.Message{
// Role: schema.Assistant,
// // verify querying this node's result for a particular test run Content: `{"fields": {"name": "eino", "age": 1}}`,
// result = r.getNodeExeHistory(id, exeID, "178876", nil) ResponseMeta: &schema.ResponseMeta{
// assert.True(t, len(result.Output) > 0) Usage: &schema.TokenUsage{
// PromptTokens: 8,
// mockey.PatchConvey("test node debug with batch mode", func() { CompletionTokens: 10,
// exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`})) TotalTokens: 18,
// e = r.getProcess(id, exeID) },
// e.assertSuccess() },
// assert.Equal(t, map[string]any{ }, nil
// "outputList": []any{ } else {
// map[string]any{ return nil, fmt.Errorf("unexpected index: %d", index)
// "input": "output_1", }
// "output": []any{"output_1"}, },
// }, }
// map[string]any{
// "input": "output_2", r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
// "output": []any{"output_2"}, if params.ModelType == 1706077827 {
// }, outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
// }, return outerModel, nil, nil
// }, mustUnmarshalToMap(t, e.output)) } else {
// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
// // verify querying this node's result for this node debug run return innerModel, nil, nil
// result := r.getNodeExeHistory(id, exeID, "178876", nil) }
// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output)) }).AnyTimes()
//
// // verify querying this node's has succeeded any node debug run r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1"))
// result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput)) id := r.load("function_call/llm_workflow_stream_tool_1.json")
// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
// }) exeID := r.testRun(id, map[string]string{
// }) "input": "this is the user input",
//} })
e := r.getProcess(id, exeID)
assert.NotNil(t, e.event)
e.tokenEqual(0, 0)
r.testResume(id, exeID, e.event.ID, "my name is eino")
e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
assert.NotNil(t, e2.event)
e2.tokenEqual(0, 0)
r.testResume(id, exeID, e2.event.ID, "1 year old")
e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
e3.assertSuccess()
assert.Equal(t, "the name is eino, age is 1", e3.output)
e3.tokenEqual(20, 24)
})
}
func TestNodeWithBatchEnabled(t *testing.T) {
mockey.PatchConvey("test node with batch enabled", t, func() {
r := newWfTestRunner(t)
defer r.closeFn()
r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1"))
chatModel := &testutil.UTChatModel{
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
if index == 0 {
return &schema.Message{
Role: schema.Assistant,
Content: "answer。for index 0",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 5,
CompletionTokens: 6,
TotalTokens: 11,
},
},
}, nil
} else if index == 1 {
return &schema.Message{
Role: schema.Assistant,
Content: "answerfor index 1",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 5,
CompletionTokens: 6,
TotalTokens: 11,
},
},
}, nil
} else {
return nil, fmt.Errorf("unexpected index: %d", index)
}
},
}
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
id := r.load("batch/node_batches.json")
exeID := r.testRun(id, map[string]string{
"input": `["first input", "second input"]`,
})
e := r.getProcess(id, exeID)
e.assertSuccess()
outputMap := mustUnmarshalToMap(t, e.output)
assert.Contains(t, outputMap["output"], map[string]any{
"output": []any{
"answer",
"for index 0",
},
"input": "answer。for index 0",
})
assert.Contains(t, outputMap["output"], map[string]any{
"output": []any{
"answer",
"for index 1",
},
"input": "answerfor index 1",
})
assert.Equal(t, 2, len(outputMap["output"].([]any)))
e.tokenEqual(10, 12)
// verify this workflow has previously succeeded a test run
result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
assert.True(t, len(result.Output) > 0)
// verify querying this node's result for a particular test run
result = r.getNodeExeHistory(id, exeID, "178876", nil)
assert.True(t, len(result.Output) > 0)
mockey.PatchConvey("test node debug with batch mode", func() {
exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`}))
e = r.getProcess(id, exeID)
e.assertSuccess()
assert.Equal(t, map[string]any{
"outputList": []any{
map[string]any{
"input": "output_1",
"output": []any{"output_1"},
},
map[string]any{
"input": "output_2",
"output": []any{"output_2"},
},
},
}, mustUnmarshalToMap(t, e.output))
// verify querying this node's result for this node debug run
result := r.getNodeExeHistory(id, exeID, "178876", nil)
assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
// verify querying this node's has succeeded any node debug run
result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
})
})
}
func TestStartNodeDefaultValues(t *testing.T) { func TestStartNodeDefaultValues(t *testing.T) {
mockey.PatchConvey("default values", t, func() { mockey.PatchConvey("default values", t, func() {

View File

@ -40,7 +40,7 @@ type Workflow interface {
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) SyncExecuteWorkflow(ctx context.Context, config workflowModel.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Option
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func())
} }
type ExecuteConfig = workflowModel.ExecuteConfig type ExecuteConfig = workflowModel.ExecuteConfig

View File

@ -66,7 +66,7 @@ func (i *impl) WithExecuteConfig(cfg workflowModel.ExecuteConfig) einoCompose.Op
return i.DomainSVC.WithExecuteConfig(cfg) return i.DomainSVC.WithExecuteConfig(cfg)
} }
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
return i.DomainSVC.WithMessagePipe() return i.DomainSVC.WithMessagePipe()
} }

View File

@ -74,6 +74,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
var composeOpts []compose.Option var composeOpts []compose.Option
var pipeMsgOpt compose.Option var pipeMsgOpt compose.Option
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage] var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
var workflowMsgCloser func()
if r.containWfTool { if r.containWfTool {
cfReq := crossworkflow.ExecuteConfig{ cfReq := crossworkflow.ExecuteConfig{
AgentID: &req.Identity.AgentID, AgentID: &req.Identity.AgentID,
@ -88,7 +89,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
} }
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq) wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
composeOpts = append(composeOpts, wfConfig) composeOpts = append(composeOpts, wfConfig)
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe() pipeMsgOpt, workflowMsgSr, workflowMsgCloser = crossworkflow.DefaultSVC().WithMessagePipe()
composeOpts = append(composeOpts, pipeMsgOpt) composeOpts = append(composeOpts, pipeMsgOpt)
} }
@ -120,6 +121,9 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
sw.Send(nil, errors.New("internal server error")) sw.Send(nil, errors.New("internal server error"))
} }
if workflowMsgCloser != nil {
workflowMsgCloser()
}
sw.Close() sw.Close()
}() }()
_, _ = r.runner.Stream(ctx, req, composeOpts...) _, _ = r.runner.Stream(ctx, req, composeOpts...)
@ -136,6 +140,7 @@ func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.Str
if swT != nil { if swT != nil {
swT.Close() swT.Close()
} }
wfStream.Close()
}() }()
for { for {
msg, err := wfStream.Recv() msg, err := wfStream.Recv()

View File

@ -48,7 +48,7 @@ type Executable interface {
type AsTool interface { type AsTool interface {
WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]ToolFromWorkflow, error) WorkflowAsModelTool(ctx context.Context, policies []*vo.GetPolicy) ([]ToolFromWorkflow, error)
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func())
WithExecuteConfig(cfg workflowModel.ExecuteConfig) compose.Option WithExecuteConfig(cfg workflowModel.ExecuteConfig) compose.Option
WithResumeToolWorkflow(resumingEvent *entity.ToolInterruptEvent, resumeData string, WithResumeToolWorkflow(resumingEvent *entity.ToolInterruptEvent, resumeData string,
allInterruptEvents map[string]*entity.ToolInterruptEvent) compose.Option allInterruptEvents map[string]*entity.ToolInterruptEvent) compose.Option

View File

@ -85,6 +85,7 @@ type ToolResponseInfo struct {
FunctionInfo FunctionInfo
CallID string CallID string
Response string Response string
Complete bool
} }
type ToolType = workflow.PluginType type ToolType = workflow.PluginType

View File

@ -23,7 +23,6 @@ import (
"strconv" "strconv"
einoCompose "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin" "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
@ -47,7 +46,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
workflowSC = r.schema workflowSC = r.schema
eventChan = r.eventChan eventChan = r.eventChan
resumedEvent = r.interruptEvent resumedEvent = r.interruptEvent
sw = r.streamWriter sw = r.container
) )
if wb.AppID != nil && exeCfg.AppID == nil { if wb.AppID != nil && exeCfg.AppID == nil {
@ -148,7 +147,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
var ( var (
resumeEvent = r.interruptEvent resumeEvent = r.interruptEvent
eventChan = r.eventChan eventChan = r.eventChan
sw = r.streamWriter container = r.container
) )
subHandler := execute.NewSubWorkflowHandler( subHandler := execute.NewSubWorkflowHandler(
parentHandler, parentHandler,
@ -186,7 +185,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
opts = append(opts, WrapOpt(subO, ns.Key)) opts = append(opts, WrapOpt(subO, ns.Key))
} }
} else if subNS.Type == entity.NodeTypeLLM { } else if subNS.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw) llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -209,7 +208,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key)) opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
} }
} else if subNS.Type == entity.NodeTypeLLM { } else if subNS.Type == entity.NodeTypeLLM {
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw) llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, container)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -224,7 +223,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
} }
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event, func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
sw *schema.StreamWriter[*entity.Message]) ( container *execute.StreamContainer) (
opts []einoCompose.Option, err error) { opts []einoCompose.Option, err error) {
// this is a LLM node. // this is a LLM node.
// check if it has any tools, if no tools, then no callback options needed // check if it has any tools, if no tools, then no callback options needed
@ -280,6 +279,12 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key)) opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt) opts = append(opts, opt)
} }
if container != nil {
toolMsgOpt := llm.WithToolWorkflowStreamContainer(container)
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
} }
if fcParams.PluginFCParam != nil { if fcParams.PluginFCParam != nil {
for _, p := range fcParams.PluginFCParam.PluginList { for _, p := range fcParams.PluginFCParam.PluginList {
@ -321,11 +326,5 @@ func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventCh
} }
} }
if sw != nil {
toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw)
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
return opts, nil return opts, nil
} }

View File

@ -38,6 +38,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
) )
type WorkflowRunner struct { type WorkflowRunner struct {
@ -45,7 +46,8 @@ type WorkflowRunner struct {
input string input string
resumeReq *entity.ResumeRequest resumeReq *entity.ResumeRequest
schema *schema2.WorkflowSchema schema *schema2.WorkflowSchema
streamWriter *schema.StreamWriter[*entity.Message] sw *schema.StreamWriter[*entity.Message]
container *execute.StreamContainer
config model.ExecuteConfig config model.ExecuteConfig
executeID int64 executeID int64
@ -84,12 +86,18 @@ func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, conf
opt(options) opt(options)
} }
var container *execute.StreamContainer
if options.streamWriter != nil {
container = execute.NewStreamContainer(options.streamWriter)
}
return &WorkflowRunner{ return &WorkflowRunner{
basic: b, basic: b,
input: options.input, input: options.input,
resumeReq: options.resumeReq, resumeReq: options.resumeReq,
schema: sc, schema: sc,
streamWriter: options.streamWriter, sw: options.streamWriter,
container: container,
config: config, config: config,
} }
} }
@ -108,14 +116,16 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
resumeReq = r.resumeReq resumeReq = r.resumeReq
wb = r.basic wb = r.basic
sc = r.schema sc = r.schema
sw = r.streamWriter sw = r.sw
container = r.container
config = r.config config = r.config
) )
if r.resumeReq == nil { if r.resumeReq == nil {
executeID, err = repo.GenID(ctx) executeID, err = repo.GenID(ctx)
if err != nil { if err != nil {
return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err) return ctx, 0, nil, nil, vo.WrapError(errno.ErrIDGenError,
fmt.Errorf("failed to generate workflow execute ID: %w", err))
} }
} else { } else {
executeID = resumeReq.ExecuteID executeID = resumeReq.ExecuteID
@ -148,6 +158,15 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
r.eventChan = eventChan r.eventChan = eventChan
r.interruptEvent = interruptEvent r.interruptEvent = interruptEvent
if container != nil {
go container.PipeAll()
defer func() {
if err != nil {
container.Done()
}
}()
}
ctx, composeOpts, err := r.designateOptions(ctx) ctx, composeOpts, err := r.designateOptions(ctx)
if err != nil { if err != nil {
return ctx, 0, nil, nil, err return ctx, 0, nil, nil, err
@ -277,8 +296,8 @@ func (r *WorkflowRunner) Prepare(ctx context.Context) (
} }
}() }()
defer func() { defer func() {
if sw != nil { if container != nil {
sw.Close() container.Done()
} }
}() }()

View File

@ -33,17 +33,22 @@ import (
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
) )
const answerKey = "output" const answerKey = "output"
type invokableWorkflow struct { type invokableWorkflow struct {
info *schema.ToolInfo workflowTool
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error) invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
terminatePlan vo.TerminatePlan }
type workflowTool struct {
info *schema.ToolInfo
wfEntity *entity.Workflow wfEntity *entity.Workflow
sc *schema2.WorkflowSchema sc *schema2.WorkflowSchema
repo wf.Repository repo wf.Repository
terminatePlan vo.TerminatePlan
} }
func NewInvokableWorkflow(info *schema.ToolInfo, func NewInvokableWorkflow(info *schema.ToolInfo,
@ -54,12 +59,14 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
repo wf.Repository, repo wf.Repository,
) wf.ToolFromWorkflow { ) wf.ToolFromWorkflow {
return &invokableWorkflow{ return &invokableWorkflow{
workflowTool: workflowTool{
info: info, info: info,
invoke: invoke,
terminatePlan: terminatePlan,
wfEntity: wfEntity, wfEntity: wfEntity,
sc: sc, sc: sc,
repo: repo, repo: repo,
terminatePlan: terminatePlan,
},
invoke: invoke,
} }
} }
@ -77,6 +84,52 @@ func resumeOnce(rInfo *entity.ResumeRequest, callID string, allIEs map[string]*e
} }
} }
// prepare builds everything needed to run this workflow as a tool: it decides
// between a fresh run and a resume, wires an optional message stream back to
// the parent LLM node, converts the JSON arguments into typed inputs, and
// finally calls WorkflowRunner.Prepare.
//
// Returns the cancellable run context, the generated execute ID, the converted
// input map (nil when resuming or when the workflow declares no input params),
// and the compose options to pass to the underlying invoke/stream function.
func (wt *workflowTool) prepare(ctx context.Context, rInfo *entity.ResumeRequest, argumentsInJSON string, opts ...tool.Option) (
	cancelCtx context.Context, executeID int64, input map[string]any, callOpts []einoCompose.Option, err error) {
	cfg := execute.GetExecuteConfig(opts...)

	var runOpts []WorkflowRunnerOption
	// A not-yet-resumed resume request takes precedence over the raw input;
	// otherwise the tool arguments become the workflow input.
	if rInfo != nil && !rInfo.Resumed {
		runOpts = append(runOpts, WithResumeReq(rInfo))
	} else {
		runOpts = append(runOpts, WithInput(argumentsInJSON))
	}

	// If the caller (the parent LLM node) provided a stream container, create a
	// dedicated pipe for this tool run: the reader is registered as a child of
	// the container BEFORE Prepare so no messages are dropped, and the writer
	// is handed to the runner.
	if container := execute.GetParentStreamContainer(opts...); container != nil {
		sr, sw := schema.Pipe[*entity.Message](10)
		container.AddChild(sr)
		runOpts = append(runOpts, WithStreamWriter(sw))
	}

	var ws *nodes.ConversionWarnings
	// Only a fresh run (or an already-consumed resume) with declared input
	// params needs argument parsing; a pending resume reuses its saved input.
	if (rInfo == nil || rInfo.Resumed) && len(wt.wfEntity.InputParams) > 0 {
		if err = sonic.UnmarshalString(argumentsInJSON, &input); err != nil {
			err = vo.WrapError(errno.ErrSerializationDeserializationFail, err)
			return
		}
		// Locate the entry node: its output types describe the workflow's
		// expected input schema.
		var entryNode *schema2.NodeSchema
		for _, node := range wt.sc.Nodes {
			if node.Type == entity.NodeTypeEntry {
				entryNode = node
				break
			}
		}
		if entryNode == nil {
			// Invariant: a published tool workflow always has an entry node.
			panic("entry node not found in tool workflow")
		}
		input, ws, err = nodes.ConvertInputs(ctx, input, entryNode.OutputTypes)
		if err != nil {
			return
		} else if ws != nil {
			// Conversion warnings are non-fatal; log and continue.
			logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
		}
	}

	cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(wt.wfEntity.GetBasic(), wt.sc, cfg, runOpts...).Prepare(ctx)
	return
}
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
rInfo, allIEs := execute.GetResumeRequest(opts...) rInfo, allIEs := execute.GetResumeRequest(opts...)
var ( var (
@ -97,52 +150,9 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", einoCompose.InterruptAndRerun return "", einoCompose.InterruptAndRerun
} }
cfg := execute.GetExecuteConfig(opts...)
defer resumeOnce(rInfo, callID, allIEs) defer resumeOnce(rInfo, callID, allIEs)
var runOpts []WorkflowRunnerOption cancelCtx, executeID, in, callOpts, err := i.prepare(ctx, rInfo, argumentsInJSON, opts...)
if rInfo != nil && !rInfo.Resumed {
runOpts = append(runOpts, WithResumeReq(rInfo))
} else {
runOpts = append(runOpts, WithInput(argumentsInJSON))
}
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
runOpts = append(runOpts, WithStreamWriter(sw))
}
var (
cancelCtx context.Context
executeID int64
callOpts []einoCompose.Option
in map[string]any
err error
ws *nodes.ConversionWarnings
)
if rInfo == nil && len(i.wfEntity.InputParams) > 0 {
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
return "", err
}
var entryNode *schema2.NodeSchema
for _, node := range i.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
break
}
}
if entryNode == nil {
panic("entry node not found in tool workflow")
}
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
if err != nil {
return "", err
} else if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -198,12 +208,8 @@ func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
} }
type streamableWorkflow struct { type streamableWorkflow struct {
info *schema.ToolInfo workflowTool
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *schema2.WorkflowSchema
repo wf.Repository
} }
func NewStreamableWorkflow(info *schema.ToolInfo, func NewStreamableWorkflow(info *schema.ToolInfo,
@ -214,12 +220,14 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
repo wf.Repository, repo wf.Repository,
) wf.ToolFromWorkflow { ) wf.ToolFromWorkflow {
return &streamableWorkflow{ return &streamableWorkflow{
workflowTool: workflowTool{
info: info, info: info,
stream: stream,
terminatePlan: terminatePlan,
wfEntity: wfEntity, wfEntity: wfEntity,
sc: sc, sc: sc,
repo: repo, repo: repo,
terminatePlan: terminatePlan,
},
stream: stream,
} }
} }
@ -247,52 +255,9 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, einoCompose.InterruptAndRerun return nil, einoCompose.InterruptAndRerun
} }
cfg := execute.GetExecuteConfig(opts...)
defer resumeOnce(rInfo, callID, allIEs) defer resumeOnce(rInfo, callID, allIEs)
var runOpts []WorkflowRunnerOption cancelCtx, executeID, in, callOpts, err := s.prepare(ctx, rInfo, argumentsInJSON, opts...)
if rInfo != nil && !rInfo.Resumed {
runOpts = append(runOpts, WithResumeReq(rInfo))
} else {
runOpts = append(runOpts, WithInput(argumentsInJSON))
}
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
runOpts = append(runOpts, WithStreamWriter(sw))
}
var (
cancelCtx context.Context
executeID int64
callOpts []einoCompose.Option
in map[string]any
err error
ws *nodes.ConversionWarnings
)
if rInfo == nil && len(s.wfEntity.InputParams) > 0 {
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
return nil, err
}
var entryNode *schema2.NodeSchema
for _, node := range s.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
break
}
}
if entryNode == nil {
panic("entry node not found in tool workflow")
}
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
if err != nil {
return nil, err
} else if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
}
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -370,6 +370,10 @@ func (w *WorkflowHandler) OnError(ctx context.Context, info *callbacks.RunInfo,
interruptEvent.EventType, interruptEvent.NodeKey) interruptEvent.EventType, interruptEvent.NodeKey)
} }
if c.TokenCollector != nil { // wait until all streaming chunks are collected
_ = c.TokenCollector.wait()
}
done := make(chan struct{}) done := make(chan struct{})
w.ch <- &Event{ w.ch <- &Event{
@ -1309,6 +1313,7 @@ func (t *ToolHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo,
FunctionInfo: t.info, FunctionInfo: t.info,
CallID: compose.GetToolCallID(ctx), CallID: compose.GetToolCallID(ctx),
Response: output.Response, Response: output.Response,
Complete: true,
}, },
} }
@ -1347,6 +1352,7 @@ func (t *ToolHandler) OnEndWithStreamOutput(ctx context.Context, info *callbacks
toolResponse: &entity.ToolResponseInfo{ toolResponse: &entity.ToolResponseInfo{
FunctionInfo: t.info, FunctionInfo: t.info,
CallID: callID, CallID: callID,
Complete: true,
}, },
} }
} }

View File

@ -76,6 +76,20 @@ func (t *TokenCollector) add(i int) {
return return
} }
// startStreamCounting registers one in-flight stream-counting goroutine on
// this collector and on every ancestor, so wait() on any of them blocks until
// the corresponding finishStreamCounting call.
func (t *TokenCollector) startStreamCounting() {
	for c := t; c != nil; c = c.Parent {
		c.wg.Add(1)
	}
}
// finishStreamCounting releases the slot taken by startStreamCounting on this
// collector and on every ancestor. Must be called exactly once per
// startStreamCounting call.
func (t *TokenCollector) finishStreamCounting() {
	for c := t; c != nil; c = c.Parent {
		c.wg.Done()
	}
}
func getTokenCollector(ctx context.Context) *TokenCollector { func getTokenCollector(ctx context.Context) *TokenCollector {
c := GetExeCtx(ctx) c := GetExeCtx(ctx)
if c == nil { if c == nil {
@ -92,7 +106,6 @@ func GetTokenCallbackHandler() callbacks.Handler {
return ctx return ctx
} }
c.add(1) c.add(1)
//c.wg.Add(1)
return ctx return ctx
}, },
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
@ -114,6 +127,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
output.Close() output.Close()
return ctx return ctx
} }
c.startStreamCounting()
safego.Go(ctx, func() { safego.Go(ctx, func() {
defer func() { defer func() {
output.Close() output.Close()
@ -141,6 +155,7 @@ func GetTokenCallbackHandler() callbacks.Handler {
if newC.TotalTokens > 0 { if newC.TotalTokens > 0 {
c.addTokenUsage(newC) c.addTokenUsage(newC)
} }
c.finishStreamCounting()
}) })
return ctx return ctx
}, },

View File

@ -789,6 +789,7 @@ func HandleExecuteEvent(ctx context.Context,
logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d", logs.CtxInfof(ctx, "[handleExecuteEvent] finish, returned event type: %v, workflow id: %d",
event.Type, event.Context.RootWorkflowBasic.ID) event.Type, event.Context.RootWorkflowBasic.ID)
cancelTicker.Stop() // Clean up timer cancelTicker.Stop() // Clean up timer
waitUntilToolFinish(ctx)
if timeoutFn != nil { if timeoutFn != nil {
timeoutFn() timeoutFn()
} }
@ -880,6 +881,7 @@ func cacheToolStreamingResponse(ctx context.Context, event *Event) {
c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse c[event.NodeKey][event.toolResponse.CallID].output = event.toolResponse
} }
c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response c[event.NodeKey][event.toolResponse.CallID].output.Response += event.toolResponse.Response
c[event.NodeKey][event.toolResponse.CallID].output.Complete = event.toolResponse.Complete
} }
func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo { func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
@ -887,6 +889,35 @@ func getFCInfos(ctx context.Context, nodeKey vo.NodeKey) map[string]*fcInfo {
return c[nodeKey] return c[nodeKey]
} }
// waitUntilToolFinish blocks until every function-call (tool) cached in the
// context has produced a complete response, or until the retry budget is
// exhausted (~1s). It is called before finishing workflow event handling so
// that streaming tool outputs are not cut off mid-stream.
//
// Fixes over the previous version: the type assertion no longer panics when
// the context carries no fc cache; the poll loop sleeps between checks instead
// of burning its 1000-retry budget in a tight spin; and the function returns
// once everything is complete instead of looping forever.
func waitUntilToolFinish(ctx context.Context) {
	const (
		maxRetries   = 1000
		pollInterval = time.Millisecond
	)

	// Comma-ok assertion: a missing or differently-typed value simply means
	// there is nothing to wait for.
	c, ok := ctx.Value(fcCacheKey{}).(map[vo.NodeKey]map[string]*fcInfo)
	if !ok || len(c) == 0 {
		return
	}

	for cnt := 0; cnt <= maxRetries; cnt++ {
		pending := false
	scan:
		for _, m := range c {
			for _, info := range m {
				// NOTE(review): this reads fcInfo.output concurrently with the
				// writer goroutine, as the original did — confirm the cache is
				// only mutated on the same event loop.
				if info.output == nil || !info.output.Complete {
					pending = true
					break scan
				}
			}
		}
		if !pending {
			return
		}
		time.Sleep(pollInterval)
	}
}
func (f *fcInfo) inputString() string { func (f *fcInfo) inputString() string {
if f.input == nil { if f.input == nil {
return "" return ""

View File

@ -0,0 +1,74 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package execute
import (
"errors"
"io"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
)
// StreamContainer fans in messages from multiple child streams (one per tool
// workflow run, registered via AddChild) into a single parent stream writer.
// PipeAll must be running before children are added; Done finishes the fan-in
// and closes the parent writer.
type StreamContainer struct {
	// sw is the parent writer every child message is forwarded to; closed by Done.
	sw *schema.StreamWriter[*entity.Message]
	// subStreams carries child readers from AddChild to PipeAll; closed by Done.
	subStreams chan *schema.StreamReader[*entity.Message]
	// wg counts drained children plus one slot held by PipeAll until Done runs.
	wg sync.WaitGroup
}
// NewStreamContainer builds a StreamContainer that forwards every child
// stream's messages to sw.
func NewStreamContainer(sw *schema.StreamWriter[*entity.Message]) *StreamContainer {
	c := &StreamContainer{sw: sw}
	c.subStreams = make(chan *schema.StreamReader[*entity.Message])
	return c
}
// AddChild registers a new child stream to be drained into the parent writer.
// The WaitGroup is incremented BEFORE the send so Done cannot observe a zero
// count between registration and pickup. subStreams is unbuffered, so this
// blocks until PipeAll receives the reader — PipeAll must already be running,
// and AddChild must not be called after Done.
func (sc *StreamContainer) AddChild(sr *schema.StreamReader[*entity.Message]) {
	sc.wg.Add(1)
	sc.subStreams <- sr
}
// PipeAll receives child stream readers registered via AddChild and, for each,
// spawns a goroutine that forwards every message (and any non-EOF receive
// error) to the parent writer. It blocks until subStreams is closed by Done,
// so it is meant to run in its own goroutine. The extra wg.Add(1) taken here
// is released by Done, preventing Done's Wait from completing before Done
// itself has been invoked.
func (sc *StreamContainer) PipeAll() {
	sc.wg.Add(1)
	for sr := range sc.subStreams {
		// Pass sr explicitly so each goroutine owns its own reader even on Go
		// versions before 1.22, where the range variable is shared across
		// iterations (the previous closure captured the loop variable).
		go func(sr *schema.StreamReader[*entity.Message]) {
			defer sr.Close()
			for {
				msg, err := sr.Recv()
				if err != nil {
					if errors.Is(err, io.EOF) {
						sc.wg.Done()
						return
					}
				}
				// Non-EOF errors are forwarded downstream; the loop keeps
				// reading until EOF arrives, as in the original.
				sc.sw.Send(msg, err)
			}
		}(sr)
	}
}
// Done finishes the fan-in: it releases the slot PipeAll reserved, waits for
// every child-draining goroutine to hit EOF, closes subStreams so PipeAll's
// range loop exits, and finally closes the parent writer. Call exactly once,
// after all children's writers have been (or will be) closed; a second call
// would panic the WaitGroup.
func (sc *StreamContainer) Done() {
	sc.wg.Done()
	sc.wg.Wait()
	close(sc.subStreams)
	sc.sw.Close()
}

View File

@ -27,7 +27,7 @@ import (
type workflowToolOption struct { type workflowToolOption struct {
resumeReq *entity.ResumeRequest resumeReq *entity.ResumeRequest
sw *schema.StreamWriter[*entity.Message] streamContainer *StreamContainer
exeCfg workflowModel.ExecuteConfig exeCfg workflowModel.ExecuteConfig
allInterruptEvents map[string]*entity.ToolInterruptEvent allInterruptEvents map[string]*entity.ToolInterruptEvent
parentTokenCollector *TokenCollector parentTokenCollector *TokenCollector
@ -40,9 +40,9 @@ func WithResume(req *entity.ResumeRequest, all map[string]*entity.ToolInterruptE
}) })
} }
func WithIntermediateStreamWriter(sw *schema.StreamWriter[*entity.Message]) tool.Option { func WithParentStreamContainer(sc *StreamContainer) tool.Option {
return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) { return tool.WrapImplSpecificOptFn(func(opts *workflowToolOption) {
opts.sw = sw opts.streamContainer = sc
}) })
} }
@ -57,9 +57,9 @@ func GetResumeRequest(opts ...tool.Option) (*entity.ResumeRequest, map[string]*e
return opt.resumeReq, opt.allInterruptEvents return opt.resumeReq, opt.allInterruptEvents
} }
func GetIntermediateStreamWriter(opts ...tool.Option) *schema.StreamWriter[*entity.Message] { func GetParentStreamContainer(opts ...tool.Option) *StreamContainer {
opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...) opt := tool.GetImplSpecificOptions(&workflowToolOption{}, opts...)
return opt.sw return opt.streamContainer
} }
func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig { func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
@ -67,11 +67,22 @@ func GetExecuteConfig(opts ...tool.Option) workflowModel.ExecuteConfig {
return opt.exeCfg return opt.exeCfg
} }
// WithMessagePipe returns an Option which is meant to be passed to the tool workflow, as well as a StreamReader to read the messages from the tool workflow. // WithMessagePipe returns an Option which is meant to be passed to the tool workflow,
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode. The workflow tools will emit messages to this stream. // as well as a StreamReader to read the messages from the tool workflow.
// This Option will apply to ALL workflow tools to be executed by eino's ToolsNode.
// The workflow tools will emit messages to this stream.
// The caller can receive from the returned StreamReader to get the messages from the tool workflow. // The caller can receive from the returned StreamReader to get the messages from the tool workflow.
func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { func WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
sr, sw := schema.Pipe[*entity.Message](10) sr, sw := schema.Pipe[*entity.Message](10)
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithIntermediateStreamWriter(sw))) container := &StreamContainer{
return opt, sr sw: sw,
subStreams: make(chan *schema.StreamReader[*entity.Message]),
}
go container.PipeAll()
opt := compose.WithToolsNodeOption(compose.WithToolOption(WithParentStreamContainer(container)))
return opt, sr, func() {
container.Done()
}
} }

View File

@ -446,9 +446,6 @@ func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.Nod
return nil, err return nil, err
} }
fmt.Println("save interruptEvent in state within batch: ", iEvent)
fmt.Println("save composite info in state within batch: ", compState)
return nil, compose.InterruptAndRerun return nil, compose.InterruptAndRerun
} else { } else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {

View File

@ -92,7 +92,6 @@ func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo.
t, ok := tInfo[k] t, ok := tInfo[k]
if !ok { if !ok {
// for input fields not explicitly defined, just pass the string value through // for input fields not explicitly defined, just pass the string value through
logs.CtxWarnf(ctx, "input %s not found in type info", k)
if !options.skipUnknownFields { if !options.skipUnknownFields {
out[k] = in[k] out[k] = in[k]
} }
@ -323,7 +322,6 @@ func convertToObject(ctx context.Context, in any, path string, t *vo.TypeInfo, o
propType, ok := t.Properties[k] propType, ok := t.Properties[k]
if !ok { if !ok {
// for input fields not explicitly defined, just pass the value through // for input fields not explicitly defined, just pass the value through
logs.CtxWarnf(ctx, "input %s.%s not found in type info", path, k)
if !options.skipUnknownFields { if !options.skipUnknownFields {
out[k] = v out[k] = v
} }

View File

@ -143,6 +143,7 @@ const (
knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix" knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
templateNodeKey = "template" templateNodeKey = "template"
llmNodeKey = "llm" llmNodeKey = "llm"
reactGraphName = "workflow_llm_react_agent"
outputConvertNodeKey = "output_convert" outputConvertNodeKey = "output_convert"
) )
@ -620,6 +621,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
ToolCallingModel: m, ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: tools}, ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName, ModelNodeName: agentModelName,
GraphName: reactGraphName,
} }
if len(toolsReturnDirectly) > 0 { if len(toolsReturnDirectly) > 0 {
@ -635,7 +637,7 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
} }
agentNode, opts := reactAgent.ExportGraph() agentNode, opts := reactAgent.ExportGraph()
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent")) opts = append(opts, compose.WithNodeName(reactGraphName))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...) _ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else { } else {
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo) _ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
@ -867,12 +869,12 @@ func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo
} }
type llmOptions struct { type llmOptions struct {
toolWorkflowSW *schema.StreamWriter[*entity.Message] toolWorkflowContainer *execute.StreamContainer
} }
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption { func WithToolWorkflowStreamContainer(container *execute.StreamContainer) nodes.NodeOption {
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) { return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
o.toolWorkflowSW = sw o.toolWorkflowContainer = container
}) })
} }
@ -880,7 +882,8 @@ type llmState = map[string]any
const agentModelName = "agent_model" const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (
composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx) c := execute.GetExeCtx(ctx)
if c != nil { if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent resumingEvent = c.NodeCtx.ResumingEvent
@ -890,7 +893,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
if c != nil && c.RootCtx.ResumeEvent != nil { if c != nil && c.RootCtx.ResumeEvent != nil {
// check if we are not resuming, but previously interrupted. Interrupt immediately. // check if we are not resuming, but previously interrupted. Interrupt immediately.
if resumingEvent == nil { if resumingEvent == nil {
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error { err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
var e error var e error
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey) previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
if e != nil { if e != nil {
@ -899,11 +902,12 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
return nil return nil
}) })
if err != nil { if err != nil {
return nil, nil, err return
} }
if len(previousToolES) > 0 { if len(previousToolES) > 0 {
return nil, nil, compose.InterruptAndRerun err = compose.InterruptAndRerun
return
} }
} }
} }
@ -936,7 +940,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
return e return e
}) })
if err != nil { if err != nil {
return nil, nil, err return
} }
composeOpts = append(composeOpts, compose.WithToolsNodeOption( composeOpts = append(composeOpts, compose.WithToolsNodeOption(
compose.WithToolOption( compose.WithToolOption(
@ -986,27 +990,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeO
} }
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...) llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil { if container := llmOpts.toolWorkflowContainer; container != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe() composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(
composeOpts = append(composeOpts, toolMsgOpt) execute.WithParentStreamContainer(container))))
safego.Go(ctx, func() {
defer toolMsgSR.Close()
for {
msg, err := toolMsgSR.Recv()
if err != nil {
if err == io.EOF {
return
}
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
return
}
logs.Infof("received message from tool workflow: %+v", msg)
llmOpts.toolWorkflowSW.Send(msg, nil)
}
})
} }
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources) resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)

View File

@ -33,7 +33,7 @@ type asToolImpl struct {
repo workflow.Repository repo workflow.Repository
} }
func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message]) { func (a *asToolImpl) WithMessagePipe() (einoCompose.Option, *schema.StreamReader[*entity.Message], func()) {
return execute.WithMessagePipe() return execute.WithMessagePipe()
} }

View File

@ -1,3 +1,19 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: interface.go // Source: interface.go
// //
@ -484,12 +500,13 @@ func (mr *MockServiceMockRecorder) WithExecuteConfig(cfg any) *gomock.Call {
} }
// WithMessagePipe mocks base method. // WithMessagePipe mocks base method.
func (m *MockService) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) { func (m *MockService) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message], func()) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithMessagePipe") ret := m.ctrl.Call(m, "WithMessagePipe")
ret0, _ := ret[0].(compose.Option) ret0, _ := ret[0].(compose.Option)
ret1, _ := ret[1].(*schema.StreamReader[*entity.Message]) ret1, _ := ret[1].(*schema.StreamReader[*entity.Message])
return ret0, ret1 ret2, _ := ret[2].(func())
return ret0, ret1, ret2
} }
// WithMessagePipe indicates an expected call of WithMessagePipe. // WithMessagePipe indicates an expected call of WithMessagePipe.