fix: workflow tool closes stream writer correctly (#1839)
This commit is contained in:
@@ -1815,212 +1815,212 @@ func TestUpdateWorkflowMeta(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//func TestSimpleInvokableToolWithReturnVariables(t *testing.T) {
|
||||
// mockey.PatchConvey("simple invokable tool with return variables", t, func() {
|
||||
// r := newWfTestRunner(t)
|
||||
// defer r.closeFn()
|
||||
//
|
||||
// toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1"))
|
||||
//
|
||||
// chatModel := &testutil.UTChatModel{
|
||||
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
// if index == 0 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// ToolCalls: []schema.ToolCall{
|
||||
// {
|
||||
// ID: "1",
|
||||
// Function: schema.FunctionCall{
|
||||
// Name: "ts_test_wf_test_wf",
|
||||
// Arguments: "{}",
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 10,
|
||||
// CompletionTokens: 11,
|
||||
// TotalTokens: 21,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else if index == 1 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// Content: "final_answer",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 5,
|
||||
// CompletionTokens: 6,
|
||||
// TotalTokens: 11,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
//
|
||||
// id := r.load("function_call/llm_with_workflow_as_tool.json")
|
||||
// defer func() {
|
||||
// post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
|
||||
// WorkflowID: id,
|
||||
// })
|
||||
// }()
|
||||
//
|
||||
// exeID := r.testRun(id, map[string]string{
|
||||
// "input": "this is the user input",
|
||||
// })
|
||||
//
|
||||
// e := r.getProcess(id, exeID)
|
||||
// e.assertSuccess()
|
||||
// assert.Equal(t, map[string]any{
|
||||
// "output": "final_answer",
|
||||
// }, mustUnmarshalToMap(t, e.output))
|
||||
// e.tokenEqual(15, 17)
|
||||
//
|
||||
// mockey.PatchConvey("check behavior if stream run", func() {
|
||||
// chatModel.Reset()
|
||||
//
|
||||
// defer r.runServer()()
|
||||
//
|
||||
// r.publish(id, "v0.0.1", true)
|
||||
//
|
||||
// sseReader := r.openapiStream(id, map[string]any{
|
||||
// "input": "hello",
|
||||
// })
|
||||
// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
// return nil
|
||||
// })
|
||||
// assert.NoError(t, err)
|
||||
//
|
||||
// // check workflow references are correct
|
||||
// refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{
|
||||
// WorkflowID: toolID,
|
||||
// })
|
||||
// assert.Equal(t, 1, len(refs.Data.WorkflowList))
|
||||
// assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID)
|
||||
// })
|
||||
// })
|
||||
//}
|
||||
func TestSimpleInvokableToolWithReturnVariables(t *testing.T) {
|
||||
mockey.PatchConvey("simple invokable tool with return variables", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
//func TestReturnDirectlyStreamableTool(t *testing.T) {
|
||||
// mockey.PatchConvey("return directly streamable tool", t, func() {
|
||||
// r := newWfTestRunner(t)
|
||||
// defer r.closeFn()
|
||||
//
|
||||
// outerModel := &testutil.UTChatModel{
|
||||
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
// if index == 0 {
|
||||
// return schema.StreamReaderFromArray([]*schema.Message{
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// ToolCalls: []schema.ToolCall{
|
||||
// {
|
||||
// ID: "1",
|
||||
// Function: schema.FunctionCall{
|
||||
// Name: "ts_test_wf_test_wf",
|
||||
// Arguments: `{"input": "input for inner model"}`,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 10,
|
||||
// CompletionTokens: 11,
|
||||
// TotalTokens: 21,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }), nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// innerModel := &testutil.UTChatModel{
|
||||
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
// if index == 0 {
|
||||
// return schema.StreamReaderFromArray([]*schema.Message{
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// Content: "I ",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 5,
|
||||
// CompletionTokens: 6,
|
||||
// TotalTokens: 11,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// Content: "don't know",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// CompletionTokens: 8,
|
||||
// TotalTokens: 8,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// Content: ".",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// CompletionTokens: 2,
|
||||
// TotalTokens: 2,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }), nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
// if params.ModelType == 1706077826 {
|
||||
// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
// return innerModel, nil, nil
|
||||
// } else {
|
||||
// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
// return outerModel, nil, nil
|
||||
// }
|
||||
// }).AnyTimes()
|
||||
//
|
||||
// r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1"))
|
||||
// id := r.load("function_call/llm_workflow_stream_tool.json")
|
||||
//
|
||||
// exeID := r.testRun(id, map[string]string{
|
||||
// "input": "this is the user input",
|
||||
// })
|
||||
// e := r.getProcess(id, exeID)
|
||||
// e.assertSuccess()
|
||||
// assert.Equal(t, "this is the streaming output I don't know.", e.output)
|
||||
// e.tokenEqual(15, 27)
|
||||
//
|
||||
// mockey.PatchConvey("check behavior if stream run", func() {
|
||||
// outerModel.Reset()
|
||||
// innerModel.Reset()
|
||||
// defer r.runServer()()
|
||||
// r.publish(id, "v0.0.1", true)
|
||||
// sseReader := r.openapiStream(id, map[string]any{
|
||||
// "input": "hello",
|
||||
// })
|
||||
// err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
// t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
// return nil
|
||||
// })
|
||||
// assert.NoError(t, err)
|
||||
// })
|
||||
// })
|
||||
//}
|
||||
toolID := r.load("function_call/tool_workflow_1.json", withID(7492075279843737651), withPublish("v0.0.1"))
|
||||
|
||||
chatModel := &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
if index == 0 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
ID: "1",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "ts_test_wf_test_wf",
|
||||
Arguments: "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 11,
|
||||
TotalTokens: 21,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else if index == 1 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: "final_answer",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 5,
|
||||
CompletionTokens: 6,
|
||||
TotalTokens: 11,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
|
||||
id := r.load("function_call/llm_with_workflow_as_tool.json")
|
||||
defer func() {
|
||||
post[workflow.DeleteWorkflowResponse](r, &workflow.DeleteWorkflowRequest{
|
||||
WorkflowID: id,
|
||||
})
|
||||
}()
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": "this is the user input",
|
||||
})
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": "final_answer",
|
||||
}, mustUnmarshalToMap(t, e.output))
|
||||
e.tokenEqual(15, 17)
|
||||
|
||||
mockey.PatchConvey("check behavior if stream run", func() {
|
||||
chatModel.Reset()
|
||||
|
||||
defer r.runServer()()
|
||||
|
||||
r.publish(id, "v0.0.1", true)
|
||||
|
||||
sseReader := r.openapiStream(id, map[string]any{
|
||||
"input": "hello",
|
||||
})
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// check workflow references are correct
|
||||
refs := post[workflow.GetWorkflowReferencesResponse](r, &workflow.GetWorkflowReferencesRequest{
|
||||
WorkflowID: toolID,
|
||||
})
|
||||
assert.Equal(t, 1, len(refs.Data.WorkflowList))
|
||||
assert.Equal(t, id, refs.Data.WorkflowList[0].WorkflowID)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestReturnDirectlyStreamableTool(t *testing.T) {
|
||||
mockey.PatchConvey("return directly streamable tool", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
outerModel := &testutil.UTChatModel{
|
||||
StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
if index == 0 {
|
||||
return schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
ID: "1",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "ts_test_wf_test_wf",
|
||||
Arguments: `{"input": "input for inner model"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 10,
|
||||
CompletionTokens: 11,
|
||||
TotalTokens: 21,
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
innerModel := &testutil.UTChatModel{
|
||||
StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
if index == 0 {
|
||||
return schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I ",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 5,
|
||||
CompletionTokens: 6,
|
||||
TotalTokens: 11,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "don't know",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
CompletionTokens: 8,
|
||||
TotalTokens: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: ".",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
CompletionTokens: 2,
|
||||
TotalTokens: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1706077826 {
|
||||
innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return innerModel, nil, nil
|
||||
} else {
|
||||
outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return outerModel, nil, nil
|
||||
}
|
||||
}).AnyTimes()
|
||||
|
||||
r.load("function_call/tool_workflow_2.json", withID(7492615435881709608), withPublish("v0.0.1"))
|
||||
id := r.load("function_call/llm_workflow_stream_tool.json")
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": "this is the user input",
|
||||
})
|
||||
e := r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
assert.Equal(t, "this is the streaming output I don't know.", e.output)
|
||||
e.tokenEqual(15, 27)
|
||||
|
||||
mockey.PatchConvey("check behavior if stream run", func() {
|
||||
outerModel.Reset()
|
||||
innerModel.Reset()
|
||||
defer r.runServer()()
|
||||
r.publish(id, "v0.0.1", true)
|
||||
sseReader := r.openapiStream(id, map[string]any{
|
||||
"input": "hello",
|
||||
})
|
||||
err := sseReader.ForEach(t.Context(), func(e *sse.Event) error {
|
||||
t.Logf("sse id: %s, type: %s, data: %s", e.ID, e.Type, string(e.Data))
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSimpleInterruptibleTool(t *testing.T) {
|
||||
mockey.PatchConvey("test simple interruptible tool", t, func() {
|
||||
@@ -2082,233 +2082,231 @@ func TestSimpleInterruptibleTool(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
//func TestStreamableToolWithMultipleInterrupts(t *testing.T) {
|
||||
// mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() {
|
||||
// r := newWfTestRunner(t)
|
||||
// defer r.closeFn()
|
||||
//
|
||||
// outerModel := &testutil.UTChatModel{
|
||||
// StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
// if index == 0 {
|
||||
// return schema.StreamReaderFromArray([]*schema.Message{
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// ToolCalls: []schema.ToolCall{
|
||||
// {
|
||||
// ID: "1",
|
||||
// Function: schema.FunctionCall{
|
||||
// Name: "ts_test_wf_test_wf",
|
||||
// Arguments: `{"input": "what's your name and age"}`,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 6,
|
||||
// CompletionTokens: 7,
|
||||
// TotalTokens: 13,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }), nil
|
||||
// } else if index == 1 {
|
||||
// return schema.StreamReaderFromArray([]*schema.Message{
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// Content: "I now know your ",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 5,
|
||||
// CompletionTokens: 8,
|
||||
// TotalTokens: 13,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Role: schema.Assistant,
|
||||
// Content: "name is Eino and age is 1.",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// CompletionTokens: 10,
|
||||
// TotalTokens: 17,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }), nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// innerModel := &testutil.UTChatModel{
|
||||
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
// if index == 0 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// Content: `{"question": "what's your age?"}`,
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 6,
|
||||
// CompletionTokens: 7,
|
||||
// TotalTokens: 13,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else if index == 1 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// Content: `{"fields": {"name": "eino", "age": 1}}`,
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 8,
|
||||
// CompletionTokens: 10,
|
||||
// TotalTokens: 18,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
// if params.ModelType == 1706077827 {
|
||||
// outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
// return outerModel, nil, nil
|
||||
// } else {
|
||||
// innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
// return innerModel, nil, nil
|
||||
// }
|
||||
// }).AnyTimes()
|
||||
//
|
||||
// r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1"))
|
||||
// id := r.load("function_call/llm_workflow_stream_tool_1.json")
|
||||
//
|
||||
// exeID := r.testRun(id, map[string]string{
|
||||
// "input": "this is the user input",
|
||||
// })
|
||||
//
|
||||
// e := r.getProcess(id, exeID)
|
||||
// assert.NotNil(t, e.event)
|
||||
// e.tokenEqual(0, 0)
|
||||
//
|
||||
// r.testResume(id, exeID, e.event.ID, "my name is eino")
|
||||
// e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
|
||||
// assert.NotNil(t, e2.event)
|
||||
// e2.tokenEqual(0, 0)
|
||||
//
|
||||
// r.testResume(id, exeID, e2.event.ID, "1 year old")
|
||||
// e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
|
||||
// e3.assertSuccess()
|
||||
// assert.Equal(t, "the name is eino, age is 1", e3.output)
|
||||
// e3.tokenEqual(20, 24)
|
||||
// })
|
||||
//}
|
||||
func TestStreamableToolWithMultipleInterrupts(t *testing.T) {
|
||||
mockey.PatchConvey("return directly streamable tool with multiple interrupts", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
//func TestNodeWithBatchEnabled(t *testing.T) {
|
||||
// mockey.PatchConvey("test node with batch enabled", t, func() {
|
||||
// r := newWfTestRunner(t)
|
||||
// defer r.closeFn()
|
||||
//
|
||||
// r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1"))
|
||||
//
|
||||
// chatModel := &testutil.UTChatModel{
|
||||
// InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
// if index == 0 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// Content: "answer。for index 0",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 5,
|
||||
// CompletionTokens: 6,
|
||||
// TotalTokens: 11,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else if index == 1 {
|
||||
// return &schema.Message{
|
||||
// Role: schema.Assistant,
|
||||
// Content: "answer,for index 1",
|
||||
// ResponseMeta: &schema.ResponseMeta{
|
||||
// Usage: &schema.TokenUsage{
|
||||
// PromptTokens: 5,
|
||||
// CompletionTokens: 6,
|
||||
// TotalTokens: 11,
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// } else {
|
||||
// return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
// }
|
||||
// },
|
||||
// }
|
||||
// r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
//
|
||||
// id := r.load("batch/node_batches.json")
|
||||
//
|
||||
// exeID := r.testRun(id, map[string]string{
|
||||
// "input": `["first input", "second input"]`,
|
||||
// })
|
||||
// e := r.getProcess(id, exeID)
|
||||
// e.assertSuccess()
|
||||
// assert.Equal(t, map[string]any{
|
||||
// "output": []any{
|
||||
// map[string]any{
|
||||
// "output": []any{
|
||||
// "answer",
|
||||
// "for index 0",
|
||||
// },
|
||||
// "input": "answer。for index 0",
|
||||
// },
|
||||
// map[string]any{
|
||||
// "output": []any{
|
||||
// "answer",
|
||||
// "for index 1",
|
||||
// },
|
||||
// "input": "answer,for index 1",
|
||||
// },
|
||||
// },
|
||||
// }, mustUnmarshalToMap(t, e.output))
|
||||
// e.tokenEqual(10, 12)
|
||||
//
|
||||
// // verify this workflow has previously succeeded a test run
|
||||
// result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
|
||||
// assert.True(t, len(result.Output) > 0)
|
||||
//
|
||||
// // verify querying this node's result for a particular test run
|
||||
// result = r.getNodeExeHistory(id, exeID, "178876", nil)
|
||||
// assert.True(t, len(result.Output) > 0)
|
||||
//
|
||||
// mockey.PatchConvey("test node debug with batch mode", func() {
|
||||
// exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`}))
|
||||
// e = r.getProcess(id, exeID)
|
||||
// e.assertSuccess()
|
||||
// assert.Equal(t, map[string]any{
|
||||
// "outputList": []any{
|
||||
// map[string]any{
|
||||
// "input": "output_1",
|
||||
// "output": []any{"output_1"},
|
||||
// },
|
||||
// map[string]any{
|
||||
// "input": "output_2",
|
||||
// "output": []any{"output_2"},
|
||||
// },
|
||||
// },
|
||||
// }, mustUnmarshalToMap(t, e.output))
|
||||
//
|
||||
// // verify querying this node's result for this node debug run
|
||||
// result := r.getNodeExeHistory(id, exeID, "178876", nil)
|
||||
// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
|
||||
//
|
||||
// // verify querying this node's has succeeded any node debug run
|
||||
// result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
|
||||
// assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
|
||||
// })
|
||||
// })
|
||||
//}
|
||||
outerModel := &testutil.UTChatModel{
|
||||
StreamResultProvider: func(index int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
if index == 0 {
|
||||
return schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
ToolCalls: []schema.ToolCall{
|
||||
{
|
||||
ID: "1",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "ts_test_wf_test_wf",
|
||||
Arguments: `{"input": "what's your name and age"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 6,
|
||||
CompletionTokens: 7,
|
||||
TotalTokens: 13,
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
} else if index == 1 {
|
||||
return schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I now know your ",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 5,
|
||||
CompletionTokens: 8,
|
||||
TotalTokens: 13,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "name is Eino and age is 1.",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
CompletionTokens: 10,
|
||||
TotalTokens: 17,
|
||||
},
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
innerModel := &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
if index == 0 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `{"question": "what's your age?"}`,
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 6,
|
||||
CompletionTokens: 7,
|
||||
TotalTokens: 13,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else if index == 1 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `{"fields": {"name": "eino", "age": 1}}`,
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 8,
|
||||
CompletionTokens: 10,
|
||||
TotalTokens: 18,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelType == 1706077827 {
|
||||
outerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return outerModel, nil, nil
|
||||
} else {
|
||||
innerModel.ModelType = strconv.FormatInt(params.ModelType, 10)
|
||||
return innerModel, nil, nil
|
||||
}
|
||||
}).AnyTimes()
|
||||
|
||||
r.load("function_call/tool_workflow_3.json", withID(7492615435881709611), withPublish("v0.0.1"))
|
||||
id := r.load("function_call/llm_workflow_stream_tool_1.json")
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": "this is the user input",
|
||||
})
|
||||
|
||||
e := r.getProcess(id, exeID)
|
||||
assert.NotNil(t, e.event)
|
||||
e.tokenEqual(0, 0)
|
||||
|
||||
r.testResume(id, exeID, e.event.ID, "my name is eino")
|
||||
e2 := r.getProcess(id, exeID, withPreviousEventID(e.event.ID))
|
||||
assert.NotNil(t, e2.event)
|
||||
e2.tokenEqual(0, 0)
|
||||
|
||||
r.testResume(id, exeID, e2.event.ID, "1 year old")
|
||||
e3 := r.getProcess(id, exeID, withPreviousEventID(e2.event.ID))
|
||||
e3.assertSuccess()
|
||||
assert.Equal(t, "the name is eino, age is 1", e3.output)
|
||||
e3.tokenEqual(20, 24)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNodeWithBatchEnabled(t *testing.T) {
|
||||
mockey.PatchConvey("test node with batch enabled", t, func() {
|
||||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
r.load("batch/sub_workflow_as_batch.json", withID(7469707607914217512), withPublish("v0.0.1"))
|
||||
|
||||
chatModel := &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
if index == 0 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: "answer。for index 0",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 5,
|
||||
CompletionTokens: 6,
|
||||
TotalTokens: 11,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else if index == 1 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: "answer,for index 1",
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
Usage: &schema.TokenUsage{
|
||||
PromptTokens: 5,
|
||||
CompletionTokens: 6,
|
||||
TotalTokens: 11,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
}
|
||||
},
|
||||
}
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
|
||||
id := r.load("batch/node_batches.json")
|
||||
|
||||
exeID := r.testRun(id, map[string]string{
|
||||
"input": `["first input", "second input"]`,
|
||||
})
|
||||
e := r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
outputMap := mustUnmarshalToMap(t, e.output)
|
||||
assert.Contains(t, outputMap["output"], map[string]any{
|
||||
"output": []any{
|
||||
"answer",
|
||||
"for index 0",
|
||||
},
|
||||
"input": "answer。for index 0",
|
||||
})
|
||||
assert.Contains(t, outputMap["output"], map[string]any{
|
||||
"output": []any{
|
||||
"answer",
|
||||
"for index 1",
|
||||
},
|
||||
"input": "answer,for index 1",
|
||||
})
|
||||
assert.Equal(t, 2, len(outputMap["output"].([]any)))
|
||||
e.tokenEqual(10, 12)
|
||||
|
||||
// verify this workflow has previously succeeded a test run
|
||||
result := r.getNodeExeHistory(id, "", "100001", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
|
||||
assert.True(t, len(result.Output) > 0)
|
||||
|
||||
// verify querying this node's result for a particular test run
|
||||
result = r.getNodeExeHistory(id, exeID, "178876", nil)
|
||||
assert.True(t, len(result.Output) > 0)
|
||||
|
||||
mockey.PatchConvey("test node debug with batch mode", func() {
|
||||
exeID = r.nodeDebug(id, "178876", withNDBatch(map[string]string{"item1": `[{"output":"output_1"},{"output":"output_2"}]`}))
|
||||
e = r.getProcess(id, exeID)
|
||||
e.assertSuccess()
|
||||
assert.Equal(t, map[string]any{
|
||||
"outputList": []any{
|
||||
map[string]any{
|
||||
"input": "output_1",
|
||||
"output": []any{"output_1"},
|
||||
},
|
||||
map[string]any{
|
||||
"input": "output_2",
|
||||
"output": []any{"output_2"},
|
||||
},
|
||||
},
|
||||
}, mustUnmarshalToMap(t, e.output))
|
||||
|
||||
// verify querying this node's result for this node debug run
|
||||
result := r.getNodeExeHistory(id, exeID, "178876", nil)
|
||||
assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
|
||||
|
||||
// verify querying this node's has succeeded any node debug run
|
||||
result = r.getNodeExeHistory(id, "", "178876", ptr.Of(workflow.NodeHistoryScene_TestRunInput))
|
||||
assert.Equal(t, mustUnmarshalToMap(t, e.output), mustUnmarshalToMap(t, result.Output))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStartNodeDefaultValues(t *testing.T) {
|
||||
mockey.PatchConvey("default values", t, func() {
|
||||
|
||||
Reference in New Issue
Block a user