feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,181 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
// OutputPortCount reports the number of normal output ports this node
// exposes, and whether an additional exception port is configured.
func (s *NodeSchema) OutputPortCount() (int, bool) {
	exceptionPort := s.ExceptionConfigs != nil &&
		s.ExceptionConfigs.ProcessType != nil &&
		*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch

	switch s.Type {
	case entity.NodeTypeSelector:
		// One port per clause, plus the default branch.
		clauses := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
		return len(clauses) + 1, exceptionPort
	case entity.NodeTypeQuestionAnswer:
		conf := s.Configs.(map[string]any)
		if mustGetKey[qa.AnswerType]("AnswerType", conf) != qa.AnswerByChoices {
			return 1, exceptionPort
		}
		if mustGetKey[qa.ChoiceType]("ChoiceType", conf) != qa.FixedChoices {
			// Dynamic choices: one branch for all choices plus "other".
			return 2, exceptionPort
		}
		// One port per fixed choice, plus the "other" branch.
		return len(mustGetKey[[]string]("FixedChoices", conf)) + 1, exceptionPort
	case entity.NodeTypeIntentDetector:
		// One port per intent, plus the default branch.
		intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
		return len(intents) + 1, exceptionPort
	default:
		return 1, exceptionPort
	}
}
// BranchMapping describes where each of a node's output ports leads.
type BranchMapping struct {
	// Normal holds, per normal output port (indexed), the set of
	// downstream node keys that port connects to.
	Normal []map[string]bool
	// Exception holds the downstream node keys reached through the
	// exception branch; nil when the node has no exception port.
	Exception map[string]bool
}
const (
	// DefaultBranch is the name used for a node's default output branch.
	DefaultBranch = "default"
	// BranchFmt is the printf pattern used to name indexed output branches.
	BranchFmt = "branch_%d"
)
// GetBranch converts this node's port-to-successor mapping into an eino
// *compose.GraphBranch whose condition selects successor nodes at runtime.
//
// bMapping.Normal maps each normal output port (by index) to its successor
// node keys; bMapping.Exception, when non-nil, maps the exception port.
// Returns an error when bMapping is nil, or for a question-answer node whose
// configuration does not produce branches.
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
	if bMapping == nil {
		return nil, errors.New("no branch mapping")
	}

	// endNodes is the union of every possible successor; eino requires the
	// complete candidate set when constructing a multi-branch.
	endNodes := make(map[string]bool)
	for i := range bMapping.Normal {
		for k := range bMapping.Normal[i] {
			endNodes[k] = true
		}
	}
	for k := range bMapping.Exception { // ranging over a nil map is a no-op
		endNodes[k] = true
	}

	switch s.Type {
	case entity.NodeTypeSelector:
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			choice := in[selector.SelectKey].(int)
			// Fix: choice == len(bMapping.Normal) must also be rejected,
			// otherwise the index below panics (was `choice > len(...)`).
			if choice < 0 || choice >= len(bMapping.Normal) {
				return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
			}
			choices := make(map[string]bool, len(bMapping.Normal[choice]))
			for k := range bMapping.Normal[choice] {
				choices[k] = true
			}
			return choices, nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	case entity.NodeTypeQuestionAnswer:
		conf := s.Configs.(map[string]any)
		if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
			choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
			condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
				optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
				if !ok {
					return nil, fmt.Errorf("failed to take option id from input map: %v", in)
				}
				// The "other" option always maps to the last normal branch.
				if optionID.(string) == "other" {
					return bMapping.Normal[len(bMapping.Normal)-1], nil
				}
				if choiceType == qa.DynamicChoices { // all dynamic choices map to branch 0
					return bMapping.Normal[0], nil
				}
				optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
				if !ok {
					return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
				}
				if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
					// Fix: report the actual problem (out-of-range id) instead of
					// reusing the "failed to take option id" message above.
					return nil, fmt.Errorf("option id out of range: %d, input map: %v", optionIDInt, in)
				}
				return bMapping.Normal[optionIDInt], nil
			}
			return compose.NewGraphMultiBranch(condition, endNodes), nil
		}
		return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
	case entity.NodeTypeIntentDetector:
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			// A failed detection routes to the exception branch when present.
			isSuccess, ok := in["isSuccess"]
			if ok && isSuccess != nil && !isSuccess.(bool) {
				return bMapping.Exception, nil
			}
			classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
			if !ok {
				return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
			}
			// The intent detector's default branch uses classificationId=0,
			// but this implementation stores the default branch as the LAST
			// element of bMapping.Normal. So id 0 maps to the last index and
			// every other id is shifted down by 1.
			id, err := cast.ToInt64E(classificationId)
			if err != nil {
				return nil, err
			}
			realID := id - 1
			if realID >= int64(len(bMapping.Normal)) {
				return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
			}
			if realID < 0 {
				realID = int64(len(bMapping.Normal)) - 1
			}
			return bMapping.Normal[realID], nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	default:
		// Generic two-way branch: exception on failure, otherwise port 0.
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			isSuccess, ok := in["isSuccess"]
			if ok && isSuccess != nil && !isSuccess.(bool) {
				return bMapping.Exception, nil
			}
			return bMapping.Normal[0], nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	}
}

View File

@@ -0,0 +1,194 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
// selectorCallbackField is the callback-facing representation of one operand
// of a selector condition: a display key, its declared type, and the runtime
// value observed at execution time.
type selectorCallbackField struct {
	Key   string      `json:"key"`
	Type  vo.DataType `json:"type"`
	Value any         `json:"value"`
}

// selectorCondition mirrors one canvas condition: left operand, operator,
// and an optional right operand (absent for unary operators).
type selectorCondition struct {
	Left     selectorCallbackField  `json:"left"`
	Operator vo.OperatorType        `json:"operator"`
	Right    *selectorCallbackField `json:"right,omitempty"`
}

// selectorBranch groups the conditions of one selector branch together with
// the logic relation (AND/OR) that combines them.
type selectorBranch struct {
	Conditions []*selectorCondition `json:"conditions"`
	Logic      vo.LogicType         `json:"logic"`
	Name       string               `json:"name"`
}
// toSelectorCallbackInput returns a converter that reshapes a selector node's
// resolved input map into the callback payload {"branches": []*selectorBranch},
// reconstructing per-branch conditions (operands, operator, logic) from the
// node's clause configuration and input sources.
//
// Input source paths have two forms:
//   - [branchIndex, left|right]              — a single-condition branch
//   - [branchIndex, clauseIndex, left|right] — a multi-condition branch
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
	return func(_ context.Context, in map[string]any) (map[string]any, error) {
		config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
		count := len(config)
		// One output entry per configured clause; entries are created lazily
		// as their input sources are encountered below.
		output := make([]*selectorBranch, count)
		for _, source := range s.InputSources {
			targetPath := source.Path
			if len(targetPath) == 2 {
				// Single-condition branch: path is [branchIndex, left|right].
				indexStr := targetPath[0]
				index, err := strconv.Atoi(indexStr)
				if err != nil {
					return nil, err
				}
				branch := output[index]
				if branch == nil {
					// First time we see this branch: seed it with one
					// condition carrying the configured operator.
					output[index] = &selectorBranch{
						Conditions: []*selectorCondition{
							{
								Operator: config[index].Single.ToCanvasOperatorType(),
							},
						},
						Logic: selector.ClauseRelationAND.ToVOLogicType(),
					}
				}
				if targetPath[1] == selector.LeftKey {
					leftV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take left value of %s", targetPath)
					}
					if source.Source.Ref.VariableType != nil {
						if *source.Source.Ref.VariableType == vo.ParentIntermediate {
							// Value comes from a parent composite node's
							// intermediate variable; qualify the key with the
							// parent node's name.
							parentNodeKey, ok := sc.Hierarchy[s.Key]
							if !ok {
								return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
							}
							parentNode := sc.GetNode(parentNodeKey)
							output[index].Conditions[0].Left = selectorCallbackField{
								Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
								Value: leftV,
							}
						} else {
							// Other variable kinds carry no display key.
							output[index].Conditions[0].Left = selectorCallbackField{
								Key:   "",
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
								Value: leftV,
							}
						}
					} else {
						// Plain node-output reference: key is "<node>.<path>".
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					}
				} else if targetPath[1] == selector.RightKey {
					rightV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take right value of %s", targetPath)
					}
					output[index].Conditions[0].Right = &selectorCallbackField{
						Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
						Value: rightV,
					}
				}
			} else if len(targetPath) == 3 {
				// Multi-condition branch: path is [branchIndex, clauseIndex, left|right].
				indexStr := targetPath[0]
				index, err := strconv.Atoi(indexStr)
				if err != nil {
					return nil, err
				}
				multi := config[index].Multi
				branch := output[index]
				if branch == nil {
					output[index] = &selectorBranch{
						Conditions: make([]*selectorCondition, len(multi.Clauses)),
						Logic:      multi.Relation.ToVOLogicType(),
					}
				}
				clauseIndexStr := targetPath[1]
				clauseIndex, err := strconv.Atoi(clauseIndexStr)
				if err != nil {
					return nil, err
				}
				clause := multi.Clauses[clauseIndex]
				if output[index].Conditions[clauseIndex] == nil {
					output[index].Conditions[clauseIndex] = &selectorCondition{
						Operator: clause.ToCanvasOperatorType(),
					}
				}
				if targetPath[2] == selector.LeftKey {
					leftV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take left value of %s", targetPath)
					}
					if source.Source.Ref.VariableType != nil {
						if *source.Source.Ref.VariableType == vo.ParentIntermediate {
							// Same parent-intermediate handling as the
							// single-condition case above.
							parentNodeKey, ok := sc.Hierarchy[s.Key]
							if !ok {
								return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
							}
							parentNode := sc.GetNode(parentNodeKey)
							output[index].Conditions[clauseIndex].Left = selectorCallbackField{
								Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
								Value: leftV,
							}
						} else {
							output[index].Conditions[clauseIndex].Left = selectorCallbackField{
								Key:   "",
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
								Value: leftV,
							}
						}
					} else {
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					}
				} else if targetPath[2] == selector.RightKey {
					rightV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take right value of %s", targetPath)
					}
					output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
						Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
						Value: rightV,
					}
				}
			}
			// NOTE(review): paths with other lengths are silently skipped —
			// presumably they cannot occur for selector nodes; confirm upstream.
		}
		return map[string]any{"branches": output}, nil
	}
}

View File

@@ -0,0 +1,335 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"slices"
"strconv"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// designateOptions assembles all eino compose options for a workflow run:
// the root workflow callback handler, one callback handler per node
// (wrapped for nested nodes), sub-workflow and LLM tool options, an
// optional checkpoint id, and a token-usage callback registered once per
// context. Returns the (possibly updated) context plus the option list.
func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, []einoCompose.Option, error) {
	var (
		wb           = r.basic
		exeCfg       = r.config
		executeID    = r.executeID
		workflowSC   = r.schema
		eventChan    = r.eventChan
		resumedEvent = r.interruptEvent
		sw           = r.streamWriter
	)
	const tokenCallbackKey = "token_callback_key"
	// Inherit the workflow's AppID when the execution config has none.
	if wb.AppID != nil && exeCfg.AppID == nil {
		exeCfg.AppID = wb.AppID
	}
	rootHandler := execute.NewRootWorkflowHandler(
		wb,
		executeID,
		workflowSC.requireCheckPoint,
		eventChan,
		resumedEvent,
		exeCfg,
		workflowSC.NodeCount())
	opts := []einoCompose.Option{einoCompose.WithCallbacks(rootHandler)}
	for key := range workflowSC.GetAllNodes() {
		ns := workflowSC.GetAllNodes()[key]
		var nodeOpt einoCompose.Option
		if ns.Type == entity.NodeTypeExit {
			// Exit nodes additionally carry their terminate plan so the
			// handler can report termination semantics.
			nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
				ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
		} else if ns.Type != entity.NodeTypeLambda {
			nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
		}
		// NOTE(review): for Lambda nodes nodeOpt stays the zero Option but is
		// still appended below — presumably harmless; confirm with eino.
		if parent, ok := workflowSC.Hierarchy[key]; !ok { // top level nodes, just add the node handler
			opts = append(opts, nodeOpt)
			if ns.Type == entity.NodeTypeSubWorkflow {
				subOpts, err := r.designateOptionsForSubWorkflow(ctx,
					rootHandler.(*execute.WorkflowHandler),
					ns,
					string(key))
				if err != nil {
					return ctx, nil, err
				}
				opts = append(opts, subOpts...)
			} else if ns.Type == entity.NodeTypeLLM {
				llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
				if err != nil {
					return ctx, nil, err
				}
				opts = append(opts, llmNodeOpts...)
			}
		} else {
			// Nested node: every option must be wrapped so it is delivered
			// through the parent composite node.
			parent := workflowSC.GetAllNodes()[parent]
			opts = append(opts, WrapOpt(nodeOpt, parent.Key))
			if ns.Type == entity.NodeTypeSubWorkflow {
				subOpts, err := r.designateOptionsForSubWorkflow(ctx,
					rootHandler.(*execute.WorkflowHandler),
					ns,
					string(key))
				if err != nil {
					return ctx, nil, err
				}
				for _, subO := range subOpts {
					opts = append(opts, WrapOpt(subO, parent.Key))
				}
			} else if ns.Type == entity.NodeTypeLLM {
				llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
				if err != nil {
					return ctx, nil, err
				}
				for _, subO := range llmNodeOpts {
					opts = append(opts, WrapOpt(subO, parent.Key))
				}
			}
		}
	}
	if workflowSC.requireCheckPoint {
		opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
	}
	// Register the token callback handler only once per context.
	if !ctxcache.HasKey(ctx, tokenCallbackKey) {
		opts = append(opts, einoCompose.WithCallbacks(execute.GetTokenCallbackHandler()))
		ctx = ctxcache.Init(ctx)
		ctxcache.Store(ctx, tokenCallbackKey, true)
	}
	return ctx, opts, nil
}
// nodeCallbackOption builds a compose option attaching a per-node execution
// handler to the node identified by key.
func nodeCallbackOption(key vo.NodeKey, name string, eventChan chan *execute.Event, resumeEvent *entity.InterruptEvent,
	terminatePlan *vo.TerminatePlan) einoCompose.Option {
	handler := execute.NewNodeHandler(string(key), name, eventChan, resumeEvent, terminatePlan)
	return einoCompose.WithCallbacks(handler).DesignateNode(string(key))
}
// WrapOpt nests opt so that it is applied inside the composite node
// identified by parentNodeKey.
func WrapOpt(opt einoCompose.Option, parentNodeKey vo.NodeKey) einoCompose.Option {
	nested := nodes.WithOptsForNested(opt)
	return einoCompose.WithLambdaOption(nested).DesignateNode(string(parentNodeKey))
}
// WrapOptWithIndex nests opt inside the composite node identified by
// parentNodeKey, targeting the child at the given index.
func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index int) einoCompose.Option {
	indexed := nodes.WithOptsForIndexed(index, opt)
	return einoCompose.WithLambdaOption(indexed).DesignateNode(string(parentNodeKey))
}
// designateOptionsForSubWorkflow builds the compose options for one
// sub-workflow node: a sub-workflow callback handler plus per-node handlers
// for every node inside it, recursing into nested sub-workflows. All options
// are wrapped so they are delivered through the sub-workflow node (and, for
// nested nodes, additionally through their composite parent). pathPrefix is
// the node-key path from the root workflow down to this sub-workflow.
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
	parentHandler *execute.WorkflowHandler,
	ns *NodeSchema,
	pathPrefix ...string) (opts []einoCompose.Option, err error) {
	var (
		resumeEvent = r.interruptEvent
		eventChan   = r.eventChan
		sw          = r.streamWriter
	)
	subHandler := execute.NewSubWorkflowHandler(
		parentHandler,
		ns.SubWorkflowBasic,
		resumeEvent,
		ns.SubWorkflowSchema.NodeCount(),
	)
	opts = append(opts, WrapOpt(einoCompose.WithCallbacks(subHandler), ns.Key))
	workflowSC := ns.SubWorkflowSchema
	for key := range workflowSC.GetAllNodes() {
		subNS := workflowSC.GetAllNodes()[key]
		// Extend the path for potential deeper recursion; Clone avoids
		// aliasing pathPrefix's backing array across iterations.
		fullPath := append(slices.Clone(pathPrefix), string(subNS.Key))
		var nodeOpt einoCompose.Option
		if subNS.Type == entity.NodeTypeExit {
			// Exit nodes carry their terminate plan for the handler.
			nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
				ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
		} else {
			nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
		}
		if parent, ok := workflowSC.Hierarchy[key]; !ok { // top level nodes, just add the node handler
			opts = append(opts, WrapOpt(nodeOpt, ns.Key))
			if subNS.Type == entity.NodeTypeSubWorkflow {
				subOpts, err := r.designateOptionsForSubWorkflow(ctx,
					subHandler.(*execute.WorkflowHandler),
					subNS,
					fullPath...)
				if err != nil {
					return nil, err
				}
				for _, subO := range subOpts {
					opts = append(opts, WrapOpt(subO, ns.Key))
				}
			} else if subNS.Type == entity.NodeTypeLLM {
				llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
				if err != nil {
					return nil, err
				}
				for _, subO := range llmNodeOpts {
					opts = append(opts, WrapOpt(subO, ns.Key))
				}
			}
		} else {
			// Node nested inside a composite within this sub-workflow:
			// wrap once for the composite parent, then once more for the
			// sub-workflow node itself.
			parent := workflowSC.GetAllNodes()[parent]
			opts = append(opts, WrapOpt(WrapOpt(nodeOpt, parent.Key), ns.Key))
			if subNS.Type == entity.NodeTypeSubWorkflow {
				subOpts, err := r.designateOptionsForSubWorkflow(ctx,
					subHandler.(*execute.WorkflowHandler),
					subNS,
					fullPath...)
				if err != nil {
					return nil, err
				}
				for _, subO := range subOpts {
					opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
				}
			} else if subNS.Type == entity.NodeTypeLLM {
				llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
				if err != nil {
					return nil, err
				}
				for _, subO := range llmNodeOpts {
					opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
				}
			}
		}
	}
	return opts, nil
}
// llmToolCallbackOptions builds the compose options for an LLM node's tools:
// for every workflow tool and plugin tool configured in FCParam, it resolves
// the tool's metadata, creates a ToolHandler emitting execution events to
// eventChan, and wires the handler into the LLM node as a nested option.
// When sw is non-nil it also attaches a writer for tool-workflow messages.
// Panics if called on a non-LLM node (programmer error).
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
	sw *schema.StreamWriter[*entity.Message]) (
	opts []einoCompose.Option, err error) {
	// this is a LLM node.
	// check if it has any tools, if no tools, then no callback options needed
	// for each tool, extract the entity.FunctionInfo, create the ToolHandler, and add the callback option
	if ns.Type != entity.NodeTypeLLM {
		panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
	}
	fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
	if fcParams != nil {
		if fcParams.WorkflowFCParam != nil {
			// TODO: try to avoid getting the workflow tool all over again
			for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
				wfIDStr := wf.WorkflowID
				wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
				}
				// An empty version means the draft workflow is used.
				locator := vo.FromDraft
				if wf.WorkflowVersion != "" {
					locator = vo.FromSpecificVersion
				}
				wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
					ID:      wfID,
					QType:   locator,
					Version: wf.WorkflowVersion,
				}, vo.WorkflowToolConfig{})
				if err != nil {
					return nil, err
				}
				tInfo, err := wfTool.Info(ctx)
				if err != nil {
					return nil, err
				}
				// Workflow tools reuse the workflow id/name for the API and
				// plugin identity fields.
				funcInfo := entity.FunctionInfo{
					Name:                  tInfo.Name,
					Type:                  entity.WorkflowTool,
					WorkflowName:          wfTool.GetWorkflow().Name,
					WorkflowTerminatePlan: wfTool.TerminatePlan(),
					APIID:                 wfID,
					APIName:               wfTool.GetWorkflow().Name,
					PluginID:              wfID,
					PluginName:            wfTool.GetWorkflow().Name,
				}
				toolHandler := execute.NewToolHandler(eventChan, funcInfo)
				opt := einoCompose.WithCallbacks(toolHandler)
				opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
				opts = append(opts, opt)
			}
		}
		if fcParams.PluginFCParam != nil {
			for _, p := range fcParams.PluginFCParam.PluginList {
				toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
				if err != nil {
					return nil, err
				}
				pluginID, err := strconv.ParseInt(p.PluginID, 10, 64)
				if err != nil {
					return nil, err
				}
				toolInfoResponse, err := plugin.GetPluginService().GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{
					PluginEntity: plugin.Entity{
						PluginID:      pluginID,
						PluginVersion: ptr.Of(p.PluginVersion),
					},
					ToolIDs: []int64{toolID},
				})
				if err != nil {
					return nil, err
				}
				// NOTE(review): assumes ToolInfoList always contains toolID
				// after a successful lookup — a missing entry would yield a
				// zero-value tool name; confirm the service contract.
				funcInfo := entity.FunctionInfo{
					Name:       toolInfoResponse.ToolInfoList[toolID].ToolName,
					Type:       entity.PluginTool,
					PluginID:   pluginID,
					PluginName: toolInfoResponse.PluginName,
					APIID:      toolID,
					APIName:    p.ApiName,
				}
				toolHandler := execute.NewToolHandler(eventChan, funcInfo)
				opt := einoCompose.WithCallbacks(toolHandler)
				opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
				opts = append(opts, opt)
			}
		}
	}
	if sw != nil {
		// Stream tool-workflow messages back to the caller.
		toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw)
		opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
		opts = append(opts, opt)
	}
	return opts, nil
}

View File

@@ -0,0 +1,314 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"maps"
"slices"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// outputValueFiller returns a post-processor that copies the node's output
// map and fills every declared-but-absent output key with nil (FillNil).
// Nodes emitting stream output must handle absent keys in the stream
// themselves.
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
	if len(s.OutputTypes) == 0 {
		// Nothing declared: pass the output through untouched.
		return func(ctx context.Context, output map[string]any) (map[string]any, error) {
			return output, nil
		}
	}
	return func(ctx context.Context, output map[string]any) (map[string]any, error) {
		filled := make(map[string]any, len(output))
		maps.Copy(filled, output)
		for key, typeInfo := range s.OutputTypes {
			if err := FillIfNotRequired(typeInfo, filled, key, FillNil, false); err != nil {
				return nil, err
			}
		}
		return filled, nil
	}
}
// inputValueFiller returns a pre-processor that copies the node's input map
// and fills every declared-but-absent input key using the FillZero strategy.
// Nodes accepting stream input must handle absent keys in the stream
// themselves.
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
	if len(s.InputTypes) == 0 {
		// Nothing declared: pass the input through untouched.
		return func(ctx context.Context, input map[string]any) (map[string]any, error) {
			return input, nil
		}
	}
	return func(ctx context.Context, input map[string]any) (map[string]any, error) {
		filled := make(map[string]any, len(input))
		maps.Copy(filled, input)
		for key, typeInfo := range s.InputTypes {
			if err := FillIfNotRequired(typeInfo, filled, key, FillZero, false); err != nil {
				return nil, err
			}
		}
		return filled, nil
	}
}
// streamInputValueFiller wraps a stream of input maps so that, per chunk,
// nil values under declared input keys are replaced with the type's zero
// value. Each chunk is shallow-copied before modification.
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
	input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
	convert := func(ctx context.Context, chunk map[string]any) (map[string]any, error) {
		out := make(map[string]any, len(chunk))
		maps.Copy(out, chunk)
		for key, typeInfo := range s.InputTypes {
			if err := replaceNilWithZero(typeInfo, out, key); err != nil {
				return nil, err
			}
		}
		return out, nil
	}
	return func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
		return schema.StreamReaderWithConvert(input, func(in map[string]any) (map[string]any, error) {
			return convert(ctx, in)
		})
	}
}
// FillStrategy selects which value is substituted for absent or nil fields.
type FillStrategy string

const (
	// FillZero substitutes the field type's zero value.
	FillZero FillStrategy = "zero"
	// FillNil substitutes an explicit nil.
	FillNil FillStrategy = "nil"
)
// FillIfNotRequired ensures container[k] conforms to tInfo, filling absent or
// nil values according to strategy. Object and array-of-object values are
// recursed into (on shallow copies, so shared sub-maps are not mutated), and
// string-encoded arrays/objects are JSON-decoded in place. Returns an error
// when a required key is absent or a value has an incompatible shape.
// isInsideObject suppresses zero-filling of nested scalar fields.
func FillIfNotRequired(tInfo *vo.TypeInfo, container map[string]any, k string, strategy FillStrategy, isInsideObject bool) error {
	v, ok := container[k]
	if ok {
		if len(tInfo.Properties) == 0 {
			// Scalar or array value (no declared sub-properties).
			if v == nil && strategy == FillZero {
				if isInsideObject {
					// Nested scalars keep nil under FillZero.
					return nil
				}
				v = tInfo.Zero()
				container[k] = v
				return nil
			}
			if v != nil && tInfo.Type == vo.DataTypeArray {
				val, ok := v.([]any)
				if !ok {
					// Tolerate a JSON string encoding of the array.
					valStr, ok := v.(string)
					if ok {
						err := sonic.UnmarshalString(valStr, &val)
						if err != nil {
							return err
						}
						container[k] = val
					} else {
						return fmt.Errorf("layer field %s is not a []any or string", k)
					}
				}
				elemTInfo := tInfo.ElemTypeInfo
				// Work on a copy so the caller's slice is never mutated.
				copiedVal := slices.Clone(val)
				container[k] = copiedVal
				for i := range copiedVal {
					if copiedVal[i] == nil {
						if strategy == FillZero {
							copiedVal[i] = elemTInfo.Zero()
							continue
						}
					}
					if len(elemTInfo.Properties) > 0 {
						// Array of objects: recurse into each element copy.
						subContainer, ok := copiedVal[i].(map[string]any)
						if !ok {
							return fmt.Errorf("map item under array %s is not map[string]any", k)
						}
						newSubContainer := maps.Clone(subContainer)
						for subK, subL := range elemTInfo.Properties {
							if err := FillIfNotRequired(subL, newSubContainer, subK, strategy, true); err != nil {
								return err
							}
						}
						copiedVal[i] = newSubContainer
					}
				}
			}
		} else {
			if v == nil {
				// A nil object is left as-is; its fields are not filled.
				return nil
			}
			// recursively handle the layered object.
			subContainer, ok := v.(map[string]any)
			if !ok {
				// Tolerate a JSON string encoding of the object.
				subContainerStr, ok := v.(string)
				if ok {
					subContainer = make(map[string]any)
					err := sonic.UnmarshalString(subContainerStr, &subContainer)
					if err != nil {
						return err
					}
					container[k] = subContainer
				} else {
					return fmt.Errorf("layer field %s is not a map[string]any or string", k)
				}
			}
			newSubContainer := maps.Clone(subContainer)
			if newSubContainer == nil {
				newSubContainer = make(map[string]any)
			}
			for subK, subT := range tInfo.Properties {
				if err := FillIfNotRequired(subT, newSubContainer, subK, strategy, true); err != nil {
					return err
				}
			}
			container[k] = newSubContainer
		}
	} else {
		// Key absent from the container entirely.
		if tInfo.Required {
			return fmt.Errorf("output field %s is required but not present", k)
		} else {
			var z any
			if strategy == FillZero {
				if !isInsideObject {
					z = tInfo.Zero()
				}
			}
			container[k] = z
			// if it's an object, recursively handle the layeredFieldInfo.
			if len(tInfo.Properties) > 0 {
				z = make(map[string]any)
				container[k] = z
				subContainer := z.(map[string]any)
				for subK, subL := range tInfo.Properties {
					if err := FillIfNotRequired(subL, subContainer, subK, strategy, true); err != nil {
						return err
					}
				}
			}
		}
	}
	return nil
}
// replaceNilWithZero replaces a nil value at container[k] with the zero value
// of tInfo, recursing into objects and arrays of objects (on shallow copies)
// and JSON-decoding string-encoded arrays/objects in place. Unlike
// FillIfNotRequired, an absent key is left absent. Returns an error when a
// value has an incompatible shape.
func replaceNilWithZero(tInfo *vo.TypeInfo, container map[string]any, k string) error {
	v, ok := container[k]
	if !ok {
		// Absent keys are not filled here (stream chunks may be partial).
		return nil
	}
	if len(tInfo.Properties) == 0 {
		// Scalar or array value (no declared sub-properties).
		if v == nil {
			v = tInfo.Zero()
			container[k] = v
			return nil
		}
		if tInfo.Type == vo.DataTypeArray {
			val, ok := v.([]any)
			if !ok {
				// Tolerate a JSON string encoding of the array.
				valStr, ok := v.(string)
				if ok {
					err := sonic.UnmarshalString(valStr, &val)
					if err != nil {
						return err
					}
					container[k] = val
				} else {
					return fmt.Errorf("layer field %s is not a []any or string", k)
				}
			}
			elemTInfo := tInfo.ElemTypeInfo
			// Work on a copy so the caller's slice is never mutated.
			copiedVal := slices.Clone(val)
			container[k] = copiedVal
			for i := range copiedVal {
				if copiedVal[i] == nil {
					copiedVal[i] = elemTInfo.Zero()
					continue
				}
				if len(elemTInfo.Properties) > 0 {
					// Array of objects: recurse into each element copy.
					subContainer, ok := copiedVal[i].(map[string]any)
					if !ok {
						return fmt.Errorf("map item under array %s is not map[string]any", k)
					}
					newSubContainer := maps.Clone(subContainer)
					for subK, subL := range elemTInfo.Properties {
						if err := replaceNilWithZero(subL, newSubContainer, subK); err != nil {
							return err
						}
					}
					copiedVal[i] = newSubContainer
				}
			}
		}
	} else {
		if v == nil {
			// A nil object is left as-is; its fields are not filled.
			return nil
		}
		// recursively handle the layered object.
		subContainer, ok := v.(map[string]any)
		if !ok {
			// Tolerate a JSON string encoding of the object.
			subContainerStr, ok := v.(string)
			if ok {
				subContainer = make(map[string]any)
				err := sonic.UnmarshalString(subContainerStr, &subContainer)
				if err != nil {
					return err
				}
				container[k] = subContainer
			} else {
				return fmt.Errorf("layer field %s is not a map[string]any or string", k)
			}
		}
		newSubContainer := maps.Clone(subContainer)
		for subK, subT := range tInfo.Properties {
			if err := replaceNilWithZero(subT, newSubContainer, subK); err != nil {
				return err
			}
		}
		container[k] = newSubContainer
	}
	return nil
}

View File

@@ -0,0 +1,300 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// TestNodeSchema_OutputValueFiller is a table-driven test for
// NodeSchema.outputValueFiller: every declared-but-absent output key must be
// filled with nil (recursively for nested objects), present keys must be kept
// unchanged, and an absent Required key must produce an error.
func TestNodeSchema_OutputValueFiller(t *testing.T) {
	type fields struct {
		In      map[string]any      // the node's raw output map
		Outputs map[string]*vo.TypeInfo // the declared output types
	}
	tests := []struct {
		name    string
		fields  fields
		want    map[string]any
		wantErr string // substring expected in the returned error, if any
	}{
		// Each scalar type below is declared but absent from In, and must be
		// filled with nil by the FillNil strategy.
		{
			name: "string field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeString,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "integer field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeInteger,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "number field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeNumber,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "boolean field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeBoolean,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "time field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeTime,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "object field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeObject,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "array field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeArray,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		{
			name: "file field",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeFile,
					},
				},
			},
			want: map[string]any{
				"key": nil,
			},
		},
		// A Required key missing from the output is an error, not a fill.
		{
			name: "required field not present",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type:     vo.DataTypeString,
						Required: true,
					},
				},
			},
			wantErr: "is required but not present",
		},
		// Nested objects: absent sub-keys are filled recursively with nil.
		{
			name: "layered: object.string",
			fields: fields{
				In: map[string]any{
					"key": map[string]any{},
				},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							"sub_key": {
								Type: vo.DataTypeString,
							},
						},
					},
				},
			},
			want: map[string]any{
				"key": map[string]any{
					"sub_key": nil,
				},
			},
		},
		{
			name: "layered: object.object",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							"sub_key": {
								Type: vo.DataTypeObject,
							},
						},
					},
				},
			},
			want: map[string]any{
				"key": map[string]any{
					"sub_key": nil,
				},
			},
		},
		{
			name: "layered: object.object.array",
			fields: fields{
				In: map[string]any{},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							"sub_key": {
								Type: vo.DataTypeObject,
								Properties: map[string]*vo.TypeInfo{
									"sub_key2": {
										Type: vo.DataTypeArray,
									},
								},
							},
						},
					},
				},
			},
			want: map[string]any{
				"key": map[string]any{
					"sub_key": map[string]any{
						"sub_key2": nil,
					},
				},
			},
		},
		// Present values must be passed through unchanged.
		{
			name: "key present",
			fields: fields{
				In: map[string]any{
					"key": "value",
				},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeString,
					},
				},
			},
			want: map[string]any{
				"key": "value",
			},
		},
		{
			name: "layered key present",
			fields: fields{
				In: map[string]any{
					"key": map[string]any{},
				},
				Outputs: map[string]*vo.TypeInfo{
					"key": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							"sub_key": {
								Type: vo.DataTypeObject,
								Properties: map[string]*vo.TypeInfo{
									"sub_key2": {
										Type: vo.DataTypeArray,
									},
								},
							},
						},
					},
				},
			},
			want: map[string]any{
				"key": map[string]any{
					"sub_key": map[string]any{
						"sub_key2": nil,
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &NodeSchema{
				OutputTypes: tt.fields.Outputs,
			}
			got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
			if len(tt.wantErr) > 0 {
				assert.Error(t, err)
				assert.ErrorContains(t, err, tt.wantErr)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.want, got)
			}
		})
	}
}

View File

@@ -0,0 +1,791 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strings"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// nodeRunConfig holds everything needed to run one workflow node: identity,
// timeout/retry/error policy, pre/post processors, callback converters and the
// underlying invoke/stream/transform functions. O is the node-specific option
// type threaded through the eino lambda.
type nodeRunConfig[O any] struct {
	nodeKey  vo.NodeKey      // unique key of the node within the workflow
	nodeName string          // display name, used in logs and error wrapping
	nodeType entity.NodeType // node type, used as the lambda type label

	timeoutMS      int64                                   // per-execution timeout in milliseconds; 0 means no timeout
	maxRetry       int64                                   // max retries on failure; 0 means no retry
	errProcessType vo.ErrorProcessType                     // how to handle errors: throw, default output, or exception branch
	dataOnErr      func(ctx context.Context) map[string]any // produces the fallback output when errProcessType is Default

	callbackEnabled bool // whether framework OnStart/OnEnd/OnError callbacks fire for this node

	// preProcessors/postProcessors run on the input/output maps, in order.
	preProcessors  []func(ctx context.Context, input map[string]any) (map[string]any, error)
	postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
	// streamPreProcessors transform the input stream before the node consumes it.
	streamPreProcessors []func(ctx context.Context,
		input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any]

	// optional converters applied to input/output before callback emission
	callbackInputConverter  func(context.Context, map[string]any) (map[string]any, error)
	callbackOutputConverter func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)

	// init hooks run once per execution before any processing
	init []func(context.Context) (context.Context, error)

	// the node's actual run functions; any of these may be nil
	i compose.Invoke[map[string]any, map[string]any, O]
	s compose.Stream[map[string]any, map[string]any, O]
	t compose.Transform[map[string]any, map[string]any, O]
}
// newNodeRunConfig assembles a nodeRunConfig from the node schema: it resolves
// timeout/retry/error policy from ExceptionConfigs (falling back to the node
// meta defaults), wires up the standard pre/post processors, and registers an
// init hook that enforces the per-execution node-count limit.
func newNodeRunConfig[O any](ns *NodeSchema,
	i compose.Invoke[map[string]any, map[string]any, O],
	s compose.Stream[map[string]any, map[string]any, O],
	t compose.Transform[map[string]any, map[string]any, O],
	opts *newNodeOptions) *nodeRunConfig[O] {
	meta := entity.NodeMetaByNodeType(ns.Type)
	var (
		timeoutMS      = meta.DefaultTimeoutMS
		maxRetry       int64
		errProcessType = vo.ErrorProcessTypeThrow
		dataOnErr      func(ctx context.Context) map[string]any
	)
	// ExceptionConfigs, when present, overrides the defaults above.
	if ns.ExceptionConfigs != nil {
		timeoutMS = ns.ExceptionConfigs.TimeoutMS
		maxRetry = ns.ExceptionConfigs.MaxRetry
		if ns.ExceptionConfigs.ProcessType != nil {
			errProcessType = *ns.ExceptionConfigs.ProcessType
		}
		if len(ns.ExceptionConfigs.DataOnErr) > 0 {
			// Parsed lazily at error time so the parse happens with the run context.
			dataOnErr = func(ctx context.Context) map[string]any {
				return parseDefaultOutputOrFallback(ctx, ns.ExceptionConfigs.DataOnErr, ns.OutputTypes)
			}
		}
	}
	// Standard pre-processing: type conversion first, then stream-marker trimming.
	preProcessors := []func(ctx context.Context, input map[string]any) (map[string]any, error){
		preTypeConverter(ns.InputTypes),
		keyFinishedMarkerTrimmer(),
	}
	if meta.PreFillZero {
		preProcessors = append(preProcessors, ns.inputValueFiller())
	}
	var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
	if meta.PostFillNil {
		postProcessors = append(postProcessors, ns.outputValueFiller())
	}
	// Streaming variant of the pre-processors: type conversion applied per chunk.
	streamPreProcessors := []func(ctx context.Context,
		input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any]{
		func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
			f := func(in map[string]any) (map[string]any, error) {
				return preTypeConverter(ns.InputTypes)(ctx, in)
			}
			return schema.StreamReaderWithConvert(input, f)
		},
	}
	if meta.PreFillZero {
		streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
	}
	// Guard against runaway workflows: count this node execution and fail if
	// the per-execution cap is exceeded.
	opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
		current, exceeded := execute.IncrAndCheckExecutedNodes(ctx)
		if exceeded {
			return nil, fmt.Errorf("exceeded max executed node count: %d, current: %d", execute.GetStaticConfig().MaxNodeCountPerExecution, current)
		}
		return ctx, nil
	})
	return &nodeRunConfig[O]{
		nodeKey:                 ns.Key,
		nodeName:                ns.Name,
		nodeType:                ns.Type,
		timeoutMS:               timeoutMS,
		maxRetry:                maxRetry,
		errProcessType:          errProcessType,
		dataOnErr:               dataOnErr,
		callbackEnabled:         meta.CallbackEnabled,
		preProcessors:           preProcessors,
		postProcessors:          postProcessors,
		streamPreProcessors:     streamPreProcessors,
		callbackInputConverter:  opts.callbackInputConverter,
		callbackOutputConverter: opts.callbackOutputConverter,
		init:                    opts.init,
		i:                       i,
		s:                       s,
		t:                       t,
	}
}
// newNodeRunConfigWOOpt adapts option-less invoke/stream/transform functions
// ("WOOpt" = without options) to the option-aware signatures expected by
// newNodeRunConfig, silently discarding the variadic options.
func newNodeRunConfigWOOpt(ns *NodeSchema,
	i compose.InvokeWOOpt[map[string]any, map[string]any],
	s compose.StreamWOOpt[map[string]any, map[string]any],
	t compose.TransformWOOpts[map[string]any, map[string]any],
	opts *newNodeOptions) *nodeRunConfig[any] {
	var (
		invokeFn    compose.Invoke[map[string]any, map[string]any, any]
		streamFn    compose.Stream[map[string]any, map[string]any, any]
		transformFn compose.Transform[map[string]any, map[string]any, any]
	)
	if i != nil {
		invokeFn = func(ctx context.Context, in map[string]any, _ ...any) (map[string]any, error) {
			return i(ctx, in)
		}
	}
	if s != nil {
		streamFn = func(ctx context.Context, in map[string]any, _ ...any) (*schema.StreamReader[map[string]any], error) {
			return s(ctx, in)
		}
	}
	if t != nil {
		transformFn = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (*schema.StreamReader[map[string]any], error) {
			return t(ctx, in)
		}
	}
	return newNodeRunConfig[any](ns, invokeFn, streamFn, transformFn, opts)
}
// newNodeOptions collects optional hooks supplied when constructing a node.
type newNodeOptions struct {
	// callbackInputConverter transforms the input map before OnStart callbacks.
	callbackInputConverter func(context.Context, map[string]any) (map[string]any, error)
	// callbackOutputConverter transforms the output map before OnEnd callbacks.
	callbackOutputConverter func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)
	// init hooks run once per node execution, before any processing.
	init []func(context.Context) (context.Context, error)
}

// newNodeOption mutates newNodeOptions (functional-options pattern).
type newNodeOption func(*newNodeOptions)
// withCallbackInputConverter sets the converter applied to the input map
// before it is reported to OnStart callbacks.
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
	return func(o *newNodeOptions) { o.callbackInputConverter = f }
}
// withCallbackOutputConverter sets the converter applied to the output map
// before it is reported to OnEnd callbacks.
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
	return func(o *newNodeOptions) { o.callbackOutputConverter = f }
}
// withInit appends an init hook that runs once per node execution.
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
	return func(o *newNodeOptions) { o.init = append(o.init, f) }
}
// invokableNode builds a Node from an option-less invoke function.
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
	o := &newNodeOptions{}
	for _, apply := range opts {
		apply(o)
	}
	return newNodeRunConfigWOOpt(ns, i, nil, nil, o).toNode()
}
// invokableNodeWO builds a Node from an option-aware invoke function.
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
	o := &newNodeOptions{}
	for _, apply := range opts {
		apply(o)
	}
	return newNodeRunConfig(ns, i, nil, nil, o).toNode()
}
// invokableTransformableNode builds a Node that supports both invoke and
// stream-transform execution, from option-less functions.
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
	t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
	o := &newNodeOptions{}
	for _, apply := range opts {
		apply(o)
	}
	return newNodeRunConfigWOOpt(ns, i, nil, t, o).toNode()
}
// invokableStreamableNodeWO builds a Node that supports both invoke and
// stream execution, from option-aware functions.
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
	o := &newNodeOptions{}
	for _, apply := range opts {
		apply(o)
	}
	return newNodeRunConfig(ns, i, s, nil, o).toNode()
}
// invoke wraps the node's invoke function with the full node lifecycle:
// init hooks, pre-processing, the OnStart callback, retryable execution,
// the OnEnd callback, error-to-fallback-output conversion, and post-processing.
// Returns nil when the node has no invoke function.
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
	if nc.i == nil {
		return nil
	}
	return func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
		ctx, runner := newNodeRunner(ctx, nc)
		// The deferred block centralizes panic recovery, the OnEnd callback,
		// and error handling, so the main body can simply return on error.
		defer func() {
			if panicErr := recover(); panicErr != nil {
				err = safego.NewPanicErr(panicErr, debug.Stack())
			}
			if err == nil {
				err = runner.onEnd(ctx, output)
			}
			if err != nil {
				// onError may convert the error into a fallback output when the
				// node is configured with default-output or exception-branch handling.
				errOutput, hasErrOutput := runner.onError(ctx, err)
				if hasErrOutput {
					output = errOutput
					err = nil
					if output, err = runner.postProcess(ctx, output); err != nil {
						logs.CtxErrorf(ctx, "postProcess failed after returning error output: %v", err)
					}
				}
			}
		}()
		for _, i := range runner.init {
			if ctx, err = i(ctx); err != nil {
				return nil, err
			}
		}
		if input, err = runner.preProcess(ctx, input); err != nil {
			return nil, err
		}
		if ctx, err = runner.onStart(ctx, input); err != nil {
			return nil, err
		}
		if output, err = runner.invoke(ctx, input, opts...); err != nil {
			return nil, err
		}
		return runner.postProcess(ctx, output)
	}
}
// stream wraps the node's stream function with the node lifecycle: init hooks,
// pre-processing, the OnStart callback, retryable execution, the streaming
// OnEnd callback and error-to-fallback-output conversion. Returns nil when the
// node has no stream function.
// NOTE(review): unlike invoke, the error path does not run postProcessors on
// the fallback output — presumably intentional for streams; confirm.
func (nc *nodeRunConfig[O]) stream() func(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
	if nc.s == nil {
		return nil
	}
	return func(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
		ctx, runner := newNodeRunner(ctx, nc)
		// Centralized panic recovery, OnEnd callback and error handling.
		defer func() {
			if panicErr := recover(); panicErr != nil {
				err = safego.NewPanicErr(panicErr, debug.Stack())
			}
			if err == nil {
				output, err = runner.onEndStream(ctx, output)
			}
			if err != nil {
				errOutput, hasErrOutput := runner.onError(ctx, err)
				if hasErrOutput {
					// Wrap the single fallback map as a one-chunk stream.
					output = schema.StreamReaderFromArray([]map[string]any{errOutput})
					err = nil
				}
			}
		}()
		for _, i := range runner.init {
			if ctx, err = i(ctx); err != nil {
				return nil, err
			}
		}
		if input, err = runner.preProcess(ctx, input); err != nil {
			return nil, err
		}
		if ctx, err = runner.onStart(ctx, input); err != nil {
			return nil, err
		}
		return runner.stream(ctx, input, opts...)
	}
}
// transform wraps the node's transform function (stream in, stream out) with
// the node lifecycle: init hooks, streaming pre-processors, the streaming
// OnStart callback, retryable execution, the streaming OnEnd callback and
// error-to-fallback-output conversion. Returns nil when the node has no
// transform function.
func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
	if nc.t == nil {
		return nil
	}
	return func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
		ctx, runner := newNodeRunner(ctx, nc)
		// Centralized panic recovery, OnEnd callback and error handling.
		defer func() {
			if panicErr := recover(); panicErr != nil {
				err = safego.NewPanicErr(panicErr, debug.Stack())
			}
			if err == nil {
				output, err = runner.onEndStream(ctx, output)
			}
			if err != nil {
				errOutput, hasErrOutput := runner.onError(ctx, err)
				if hasErrOutput {
					// Wrap the single fallback map as a one-chunk stream.
					output = schema.StreamReaderFromArray([]map[string]any{errOutput})
					err = nil
				}
			}
		}()
		for _, i := range runner.init {
			if ctx, err = i(ctx); err != nil {
				return nil, err
			}
		}
		// Streaming pre-processors wrap the input reader (no map-level preProcess here).
		for _, p := range runner.streamPreProcessors {
			input = p(ctx, input)
		}
		if ctx, input, err = runner.onStartStream(ctx, input); err != nil {
			return nil, err
		}
		return runner.transform(ctx, input, opts...)
	}
}
// toNode packages the configured run functions into an eino lambda Node.
// Construction failure here indicates a programming error (e.g. no run
// function supplied), so it panics rather than returning an error.
func (nc *nodeRunConfig[O]) toNode() *Node {
	opts := []compose.LambdaOpt{compose.WithLambdaType(string(nc.nodeType))}
	if nc.callbackEnabled {
		opts = append(opts, compose.WithLambdaCallbackEnable(true))
	}
	lambda, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
	if err != nil {
		panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
	}
	return &Node{Lambda: lambda}
}
// nodeRunner is the per-execution state for one node run: it embeds the shared
// config and tracks whether the run was interrupted (interrupt errors bypass
// retry and error handling) plus the timeout cancel function, if any.
type nodeRunner[O any] struct {
	*nodeRunConfig[O]
	interrupted bool               // set when an interrupt/rerun error was observed
	cancelFn    context.CancelFunc // cancels the timeout context; nil when no timeout
}
// newNodeRunner creates the per-execution runner for cfg, installing a
// timeout on the context when the node configures one.
func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.Context, *nodeRunner[O]) {
	r := &nodeRunner[O]{nodeRunConfig: cfg}
	if timeout := cfg.timeoutMS; timeout > 0 {
		ctx, r.cancelFn = context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
	}
	return ctx, r
}
// onStart fires the framework OnStart callback for this node when callbacks
// are enabled, converting the input first when a converter is configured.
// On converter failure the raw input is still reported so the callback chain
// observes the start, and the conversion error is returned.
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
	if !r.callbackEnabled {
		return ctx, nil
	}
	if r.callbackInputConverter == nil {
		return callbacks.OnStart(ctx, input), nil
	}
	converted, convErr := r.callbackInputConverter(ctx, input)
	if convErr != nil {
		return callbacks.OnStart(ctx, input), convErr
	}
	return callbacks.OnStart(ctx, converted), nil
}
// onStartStream fires the streaming OnStart callback. Without a converter the
// input stream is handed to the callback machinery directly (which returns a
// replacement stream to consume). With a converter, the stream is copied: one
// copy feeds the callback through the converter, the other is returned for the
// node to consume untouched.
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
	context.Context, *schema.StreamReader[map[string]any], error) {
	if !r.callbackEnabled {
		return ctx, input, nil
	}
	if r.callbackInputConverter != nil {
		copied := input.Copy(2)
		realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
			return func(in map[string]any) (map[string]any, error) {
				return r.callbackInputConverter(ctx, in)
			}
		}
		callbackS := schema.StreamReaderWithConvert(copied[0], realConverter(ctx))
		newCtx, unused := callbacks.OnStartWithStreamInput(ctx, callbackS)
		// The callback's returned stream is not consumed here; close it so the
		// copy it wraps does not block the upstream writer.
		unused.Close()
		return newCtx, copied[1], nil
	}
	newCtx, newInput := callbacks.OnStartWithStreamInput(ctx, input)
	return newCtx, newInput, nil
}
// preProcess runs the configured pre-processors over the input, in order,
// skipping nil entries and stopping at the first error.
func (r *nodeRunner[O]) preProcess(ctx context.Context, input map[string]any) (map[string]any, error) {
	for _, p := range r.preProcessors {
		if p == nil {
			continue
		}
		var err error
		if input, err = p(ctx, input); err != nil {
			return nil, err
		}
	}
	return input, nil
}
// postProcess runs the configured post-processors over the output, in order,
// skipping nil entries and stopping at the first error.
func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any) (map[string]any, error) {
	for _, p := range r.postProcessors {
		if p == nil {
			continue
		}
		var err error
		if output, err = p(ctx, output); err != nil {
			return nil, err
		}
	}
	return output, nil
}
// invoke runs the node's invoke function, retrying up to maxRetry times on
// failure. Interrupt/rerun errors are never retried (they mark the runner as
// interrupted so error handling is skipped). Context cancellation/timeout is
// checked before each attempt.
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
	var n int64 // retries performed so far
	for {
		// Abort promptly if the node timeout or an upstream cancel fired.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		output, err = r.i(ctx, input, opts...)
		if err != nil {
			if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
				r.interrupted = true
				return nil, err
			}
			logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
			if r.maxRetry > n {
				n++
				// Surface the retry count in the execution context for observability.
				if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
					exeCtx.CurrentRetryCount++
				}
				continue
			}
			return nil, err
		}
		return output, nil
	}
}
// stream runs the node's stream function, retrying up to maxRetry times on
// failure. Interrupt/rerun errors are never retried (they mark the runner as
// interrupted so error handling is skipped). Context cancellation/timeout is
// checked before each attempt.
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
	var n int64 // retries performed so far
	for {
		// Abort promptly if the node timeout or an upstream cancel fired.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		output, err = r.s(ctx, input, opts...)
		if err != nil {
			if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
				r.interrupted = true
				return nil, err
			}
			// fix: log tag previously said "[invoke]" (copy-paste from the invoke path).
			logs.CtxErrorf(ctx, "[stream] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
			if r.maxRetry > n {
				n++
				// Surface the retry count in the execution context for observability.
				if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
					exeCtx.CurrentRetryCount++
				}
				continue
			}
			return nil, err
		}
		return output, nil
	}
}
// transform runs the node's transform function, retrying up to maxRetry times
// on failure. Because a stream can only be consumed once, the input is copied
// up front — one copy per attempt — and unconsumed copies are closed on exit.
// Interrupt/rerun errors are never retried.
func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
	if r.maxRetry == 0 {
		return r.t(ctx, input, opts...)
	}
	// One copy per attempt: the initial attempt plus maxRetry retries.
	// fix: previously Copy(maxRetry) was used, but attempt indices run
	// 0..maxRetry inclusive, so the final retry indexed past the end of the
	// copied slice and panicked.
	attempts := r.maxRetry + 1
	copied := input.Copy(int(attempts))
	var n int64 // index of the attempt currently in progress
	defer func() {
		// Close every copy that was never consumed so the upstream writer
		// is not blocked.
		for i := n + 1; i < attempts; i++ {
			copied[i].Close()
		}
	}()
	for {
		// Abort promptly if the node timeout or an upstream cancel fired.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		output, err = r.t(ctx, copied[n], opts...)
		if err != nil {
			if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
				r.interrupted = true
				return nil, err
			}
			// fix: log tag previously said "[invoke]" (copy-paste from the invoke path).
			logs.CtxErrorf(ctx, "[transform] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
			if r.maxRetry > n {
				n++
				// Surface the retry count in the execution context for observability.
				if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
					exeCtx.CurrentRetryCount++
				}
				continue
			}
			return nil, err
		}
		return output, nil
	}
}
// onEnd marks the successful output (isSuccess flag for nodes with error
// branches/defaults) and fires the OnEnd callback when callbacks are enabled,
// converting the output first when a converter is configured.
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
	switch r.errProcessType {
	case vo.ErrorProcessTypeExceptionBranch, vo.ErrorProcessTypeDefault:
		output["isSuccess"] = true
	}
	if !r.callbackEnabled {
		return nil
	}
	if r.callbackOutputConverter == nil {
		_ = callbacks.OnEnd(ctx, output)
		return nil
	}
	converted, err := r.callbackOutputConverter(ctx, output)
	if err != nil {
		return err
	}
	_ = callbacks.OnEnd(ctx, converted)
	return nil
}
// onEndStream is the streaming counterpart of onEnd. For nodes with error
// branches/defaults it prepends an {isSuccess: true} chunk to the output
// stream, then fires the streaming OnEnd callback. With a converter, the
// output is copied: one copy feeds the callback through the converter, the
// other is returned downstream unmodified.
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
	*schema.StreamReader[map[string]any], error) {
	if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
		flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
		output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
	}
	if !r.callbackEnabled {
		return output, nil
	}
	if r.callbackOutputConverter != nil {
		copied := output.Copy(2)
		realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
			return func(in map[string]any) (*nodes.StructuredCallbackOutput, error) {
				return r.callbackOutputConverter(ctx, in)
			}
		}
		callbackS := schema.StreamReaderWithConvert(copied[0], realConverter(ctx))
		_, unused := callbacks.OnEndWithStreamOutput(ctx, callbackS)
		// The callback's returned stream is not consumed here; close it so the
		// copy it wraps does not block the upstream writer.
		unused.Close()
		return copied[1], nil
	}
	_, newOutput := callbacks.OnEndWithStreamOutput(ctx, output)
	return newOutput, nil
}
// onError classifies the run error and applies the node's error policy.
// It returns a fallback output and true when execution should continue
// (default-output or exception-branch handling); otherwise (nil, false) and
// the error propagates. Interrupt errors always propagate untouched.
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
	// An interrupt is not a failure: report it and let it propagate.
	if r.interrupted {
		if r.callbackEnabled {
			_ = callbacks.OnError(ctx, err)
		}
		return nil, false
	}
	// Normalize to a WorkflowError, mapping context timeout/cancel specially.
	var sErr vo.WorkflowError
	if !errors.As(err, &sErr) {
		if errors.Is(err, context.DeadlineExceeded) {
			sErr = vo.NodeTimeoutErr
		} else if errors.Is(err, context.Canceled) {
			sErr = vo.CancelErr
		} else {
			sErr = vo.WrapError(errno.ErrWorkflowExecuteFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
		}
	}
	code := int(sErr.Code())
	msg := sErr.Msg()
	switch r.errProcessType {
	case vo.ErrorProcessTypeDefault:
		// Substitute the configured default output, annotated with the error.
		d := r.dataOnErr(ctx)
		d["errorBody"] = map[string]any{
			"errorMessage": msg,
			"errorCode":    code,
		}
		d["isSuccess"] = false
		if r.callbackEnabled {
			// Downgrade to a warning: the node "succeeded" with fallback data.
			sErr = sErr.ChangeErrLevel(vo.LevelWarn)
			sOutput := &nodes.StructuredCallbackOutput{
				Output:    d,
				RawOutput: d,
				Error:     sErr,
			}
			_ = callbacks.OnEnd(ctx, sOutput)
		}
		return d, true
	case vo.ErrorProcessTypeExceptionBranch:
		// Emit only the error body; the exception branch takes over downstream.
		s := make(map[string]any)
		s["errorBody"] = map[string]any{
			"errorMessage": msg,
			"errorCode":    code,
		}
		s["isSuccess"] = false
		if r.callbackEnabled {
			sErr = sErr.ChangeErrLevel(vo.LevelWarn)
			sOutput := &nodes.StructuredCallbackOutput{
				Output:    s,
				RawOutput: s,
				Error:     sErr,
			}
			_ = callbacks.OnEnd(ctx, sOutput)
		}
		return s, true
	default:
		// ErrorProcessTypeThrow: report and propagate.
		if r.callbackEnabled {
			_ = callbacks.OnError(ctx, sErr)
		}
		return nil, false
	}
}
// parseDefaultOutput unmarshals the JSON default-output string and coerces it
// to the node's declared output types. Conversion warnings are logged, not
// treated as errors.
func parseDefaultOutput(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
	var parsed map[string]any
	if err := sonic.UnmarshalString(data, &parsed); err != nil {
		return nil, err
	}
	converted, warnings, convErr := nodes.ConvertInputs(ctx, parsed, schema_)
	if convErr != nil {
		return nil, convErr
	}
	if warnings != nil {
		logs.CtxWarnf(ctx, "convert output warnings: %v", *warnings)
	}
	return converted, nil
}
// parseDefaultOutputOrFallback parses the default-output JSON; when parsing
// fails it synthesizes a fallback map instead: string-typed fields receive the
// raw data string, everything else its type's zero value.
func parseDefaultOutputOrFallback(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) map[string]any {
	if parsed, err := parseDefaultOutput(ctx, data, schema_); err == nil {
		return parsed
	}
	fallback := make(map[string]any, len(schema_))
	for key, ti := range schema_ {
		if ti.Type == vo.DataTypeString {
			fallback[key] = data
		} else {
			fallback[key] = ti.Zero()
		}
	}
	return fallback
}
// preTypeConverter returns a pre-processor that coerces incoming values to the
// node's declared input types. Conversion warnings are logged, not fatal.
func preTypeConverter(inTypes map[string]*vo.TypeInfo) func(ctx context.Context, in map[string]any) (map[string]any, error) {
	return func(ctx context.Context, in map[string]any) (map[string]any, error) {
		out, ws, err := nodes.ConvertInputs(ctx, in, inTypes)
		if err != nil {
			return nil, err
		}
		if ws != nil {
			logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
		}
		// fix: previously returned `err`, which is provably nil here; return
		// an explicit nil for clarity.
		return out, nil
	}
}
func trimKeyFinishedMarker(ctx context.Context, in map[string]any) (map[string]any, bool, error) {
var (
newIn map[string]any
trimmed bool
)
for k, v := range in {
if vStr, ok := v.(string); ok {
if strings.HasSuffix(vStr, nodes.KeyIsFinished) {
if newIn == nil {
newIn = maps.Clone(in)
}
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
newIn[k] = vStr
trimmed = true
}
} else if vMap, ok := v.(map[string]any); ok {
newMap, subTrimmed, err := trimKeyFinishedMarker(ctx, vMap)
if err != nil {
return nil, false, err
}
if subTrimmed {
if newIn == nil {
newIn = maps.Clone(in)
}
newIn[k] = newMap
trimmed = true
}
}
}
if trimmed {
return newIn, true, nil
}
return in, false, nil
}
// keyFinishedMarkerTrimmer adapts trimKeyFinishedMarker to the pre-processor
// signature, dropping the "trimmed" flag.
func keyFinishedMarkerTrimmer() func(ctx context.Context, in map[string]any) (map[string]any, error) {
	return func(ctx context.Context, in map[string]any) (map[string]any, error) {
		trimmed, _, err := trimKeyFinishedMarker(ctx, in)
		return trimmed, err
	}
}

View File

@@ -0,0 +1,580 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
)
// NodeSchema is the serializable description of one workflow node: its
// identity, node-specific configuration, typed input/output declarations,
// field mappings, and error/stream behavior.
type NodeSchema struct {
	Key  vo.NodeKey      `json:"key"`  // unique key of the node within its workflow
	Name string          `json:"name"` // display name, used in logs and error messages
	Type entity.NodeType `json:"type"` // node type; drives construction in New

	// Configs are node specific configurations with pre-defined config key and config value.
	// Will not participate in request-time field mapping, nor as node's static values.
	// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
	Configs any `json:"configs,omitempty"`

	InputTypes   map[string]*vo.TypeInfo `json:"input_types,omitempty"`   // declared types of the node's inputs
	InputSources []*vo.FieldInfo         `json:"input_sources,omitempty"` // where each input field's value comes from

	OutputTypes   map[string]*vo.TypeInfo `json:"output_types,omitempty"`
	OutputSources []*vo.FieldInfo         `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop

	ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
	StreamConfigs    *StreamConfig    `json:"stream_configs,omitempty"`

	// sub-workflow descriptors, set only for NodeTypeSubWorkflow
	SubWorkflowBasic  *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
	SubWorkflowSchema *WorkflowSchema       `json:"sub_workflow_schema,omitempty"`

	Lambda *compose.Lambda // not serializable, used for internal test.
}
// ExceptionConfig controls timeout, retry and error handling for a node.
type ExceptionConfig struct {
	TimeoutMS   int64                `json:"timeout_ms,omitempty"`   // timeout in milliseconds, 0 means no timeout
	MaxRetry    int64                `json:"max_retry,omitempty"`    // max retry times, 0 means no retry
	ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, nil/0 means throw error
	DataOnErr   string               `json:"data_on_err,omitempty"`  // JSON data to return on error; effective only when ProcessType is Default
}
// StreamConfig describes a node's streaming capabilities.
type StreamConfig struct {
	// whether this node has the ability to produce genuine streaming output.
	// does not include nodes that only pass a stream down as they receive it.
	CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
	// whether this node prioritizes streaming input over non-streaming input.
	// does not include nodes that accept both without preference.
	// NOTE(review): the json tag is "can_process_stream", which does not match
	// the field name — presumably kept for serialized-data compatibility; confirm.
	RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
// Node is a runnable workflow node, wrapped as an eino lambda.
type Node struct {
	Lambda *compose.Lambda
}
// New materializes this NodeSchema into a runnable Node.
//
// It dispatches on s.Type: each case converts s.Configs into the node-type's
// own config struct, instantiates the node implementation, and wraps it with
// the matching invokable/streamable/transformable adapter. `inner` is the
// compiled inner workflow required by composite nodes (Batch, Loop); `sc` is
// the enclosing workflow schema; `deps` carries dependency info used to
// resolve full input sources for source-aware nodes.
//
// Panics during construction are recovered, and all failures are wrapped with
// ErrCreateNodeFail plus the node name for diagnosis.
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
	sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
	defer func() {
		if panicErr := recover(); panicErr != nil {
			err = safego.NewPanicErr(panicErr, debug.Stack())
		}
		if err != nil {
			err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
		}
	}()
	// Source-aware node types need their full input sources resolved first.
	if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
		if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
			return nil, err
		}
	}
	switch s.Type {
	case entity.NodeTypeLambda:
		// Test-only node type: uses the pre-built lambda directly.
		if s.Lambda == nil {
			return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
		}
		return &Node{Lambda: s.Lambda}, nil
	case entity.NodeTypeLLM:
		conf, err := s.ToLLMConfig(ctx)
		if err != nil {
			return nil, err
		}
		l, err := llm.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
	case entity.NodeTypeSelector:
		conf := s.ToSelectorConfig()
		sl, err := selector.NewSelector(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
	case entity.NodeTypeBatch:
		// Composite node: executes the compiled inner workflow per batch item.
		if inner == nil {
			return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
		}
		conf, err := s.ToBatchConfig(inner)
		if err != nil {
			return nil, err
		}
		b, err := batch.NewBatch(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
	case entity.NodeTypeVariableAggregator:
		conf, err := s.ToVariableAggregatorConfig()
		if err != nil {
			return nil, err
		}
		va, err := variableaggregator.NewVariableAggregator(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, va.Invoke, va.Transform,
			withCallbackInputConverter(va.ToCallbackInput),
			withCallbackOutputConverter(va.ToCallbackOutput),
			withInit(va.Init)), nil
	case entity.NodeTypeTextProcessor:
		conf, err := s.ToTextProcessorConfig()
		if err != nil {
			return nil, err
		}
		tp, err := textprocessor.NewTextProcessor(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, tp.Invoke), nil
	case entity.NodeTypeHTTPRequester:
		conf, err := s.ToHTTPRequesterConfig()
		if err != nil {
			return nil, err
		}
		hr, err := httprequester.NewHTTPRequester(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
	case entity.NodeTypeContinue:
		// Continue is a pure marker node: it produces an empty output.
		i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
			return map[string]any{}, nil
		}
		return invokableNode(s, i), nil
	case entity.NodeTypeBreak:
		b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
		if err != nil {
			return nil, err
		}
		return invokableNode(s, b.DoBreak), nil
	case entity.NodeTypeVariableAssigner:
		handler := variable.GetVariableHandler()
		conf, err := s.ToVariableAssignerConfig(handler)
		if err != nil {
			return nil, err
		}
		va, err := variableassigner.NewVariableAssigner(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, va.Assign), nil
	case entity.NodeTypeVariableAssignerWithinLoop:
		conf, err := s.ToVariableAssignerInLoopConfig()
		if err != nil {
			return nil, err
		}
		va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, va.Assign), nil
	case entity.NodeTypeLoop:
		// Composite node: executes the compiled inner workflow per iteration.
		conf, err := s.ToLoopConfig(inner)
		if err != nil {
			return nil, err
		}
		l, err := loop.NewLoop(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
	case entity.NodeTypeQuestionAnswer:
		conf, err := s.ToQAConfig(ctx)
		if err != nil {
			return nil, err
		}
		qA, err := qa.NewQuestionAnswer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
	case entity.NodeTypeInputReceiver:
		conf, err := s.ToInputReceiverConfig()
		if err != nil {
			return nil, err
		}
		inputR, err := receiver.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
	case entity.NodeTypeOutputEmitter:
		conf, err := s.ToOutputEmitterConfig(sc)
		if err != nil {
			return nil, err
		}
		e, err := emitter.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
	case entity.NodeTypeEntry:
		conf, err := s.ToEntryConfig(ctx)
		if err != nil {
			return nil, err
		}
		e, err := entry.NewEntry(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, e.Invoke), nil
	case entity.NodeTypeExit:
		terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
		if terminalPlan == vo.ReturnVariables {
			// Return-variables exits pass the input through unchanged.
			i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
				if in == nil {
					return map[string]any{}, nil
				}
				return in, nil
			}
			return invokableNode(s, i), nil
		}
		// Otherwise the exit behaves like an output emitter.
		conf, err := s.ToOutputEmitterConfig(sc)
		if err != nil {
			return nil, err
		}
		e, err := emitter.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
	case entity.NodeTypeDatabaseCustomSQL:
		conf, err := s.ToDatabaseCustomSQLConfig()
		if err != nil {
			return nil, err
		}
		sqlER, err := database.NewCustomSQL(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, sqlER.Execute), nil
	case entity.NodeTypeDatabaseQuery:
		conf, err := s.ToDatabaseQueryConfig()
		if err != nil {
			return nil, err
		}
		query, err := database.NewQuery(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseInsert:
		conf, err := s.ToDatabaseInsertConfig()
		if err != nil {
			return nil, err
		}
		insert, err := database.NewInsert(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseUpdate:
		conf, err := s.ToDatabaseUpdateConfig()
		if err != nil {
			return nil, err
		}
		update, err := database.NewUpdate(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseDelete:
		conf, err := s.ToDatabaseDeleteConfig()
		if err != nil {
			return nil, err
		}
		del, err := database.NewDelete(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
	case entity.NodeTypeKnowledgeIndexer:
		conf, err := s.ToKnowledgeIndexerConfig()
		if err != nil {
			return nil, err
		}
		w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, w.Store), nil
	case entity.NodeTypeKnowledgeRetriever:
		conf, err := s.ToKnowledgeRetrieveConfig()
		if err != nil {
			return nil, err
		}
		r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Retrieve), nil
	case entity.NodeTypeKnowledgeDeleter:
		conf, err := s.ToKnowledgeDeleterConfig()
		if err != nil {
			return nil, err
		}
		r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Delete), nil
	case entity.NodeTypeCodeRunner:
		conf, err := s.ToCodeRunnerConfig()
		if err != nil {
			return nil, err
		}
		r, err := code.NewCodeRunner(ctx, conf)
		if err != nil {
			return nil, err
		}
		// ctxcache backs intermediate values shared within one execution.
		initFn := func(ctx context.Context) (context.Context, error) {
			return ctxcache.Init(ctx), nil
		}
		return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
	case entity.NodeTypePlugin:
		conf, err := s.ToPluginConfig()
		if err != nil {
			return nil, err
		}
		r, err := plugin.NewPlugin(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Invoke), nil
	case entity.NodeTypeCreateConversation:
		conf, err := s.ToCreateConversationConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewCreateConversation(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Create), nil
	case entity.NodeTypeMessageList:
		conf, err := s.ToMessageListConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewMessageList(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.List), nil
	case entity.NodeTypeClearMessage:
		conf, err := s.ToClearMessageConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewClearMessage(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Clear), nil
	case entity.NodeTypeIntentDetector:
		conf, err := s.ToIntentDetectorConfig(ctx)
		if err != nil {
			return nil, err
		}
		r, err := intentdetector.NewIntentDetector(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Invoke), nil
	case entity.NodeTypeSubWorkflow:
		conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
		if err != nil {
			return nil, err
		}
		r, err := subworkflow.NewSubWorkflow(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
	case entity.NodeTypeJsonSerialization:
		conf, err := s.ToJsonSerializationConfig()
		if err != nil {
			return nil, err
		}
		js, err := json.NewJsonSerializer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, js.Invoke), nil
	case entity.NodeTypeJsonDeserialization:
		conf, err := s.ToJsonDeserializationConfig()
		if err != nil {
			return nil, err
		}
		jd, err := json.NewJsonDeserializer(ctx, conf)
		if err != nil {
			return nil, err
		}
		initFn := func(ctx context.Context) (context.Context, error) {
			return ctxcache.Init(ctx), nil
		}
		return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
	default:
		// Reaching here means a node type was added without a construction
		// case; the deferred recover turns this into an ErrCreateNodeFail.
		panic("not implemented")
	}
}
// IsEnableUserQuery reports whether this node is an entry node that exposes a
// user-query field: an output source whose single-segment path is
// "BOT_USER_INPUT" or "USER_INPUT". Safe to call on a nil receiver.
func (s *NodeSchema) IsEnableUserQuery() bool {
	if s == nil || s.Type != entity.NodeTypeEntry || len(s.OutputSources) == 0 {
		return false
	}
	for _, src := range s.OutputSources {
		p := src.Path
		if len(p) != 1 {
			continue
		}
		switch p[0] {
		case "BOT_USER_INPUT", "USER_INPUT":
			return true
		}
	}
	return false
}
// IsEnableChatHistory reports whether the node's LLM parameters have chat
// history enabled. Only LLM and intent-detector nodes carry LLMParams; every
// other node type returns false. Safe to call on a nil receiver.
func (s *NodeSchema) IsEnableChatHistory() bool {
	if s == nil {
		return false
	}
	switch s.Type {
	case entity.NodeTypeLLM, entity.NodeTypeIntentDetector:
		// Both node types store their model settings under the same config
		// key, so the two previously-duplicated cases are merged.
		return mustGetKey[*model.LLMParams]("LLMParams", s.Configs).EnableChatHistory
	default:
		return false
	}
}
// IsRefGlobalVariable reports whether any of the node's input or output
// sources references a global variable.
func (s *NodeSchema) IsRefGlobalVariable() bool {
	// Guard against a nil receiver for consistency with IsEnableUserQuery and
	// IsEnableChatHistory, which both tolerate nil; previously this method
	// would panic when called on a nil *NodeSchema.
	if s == nil {
		return false
	}
	for _, source := range s.InputSources {
		if source.IsRefGlobalVariable() {
			return true
		}
	}
	for _, source := range s.OutputSources {
		if source.IsRefGlobalVariable() {
			return true
		}
	}
	return false
}
// requireCheckpoint reports whether executing this node needs graph
// checkpointing support (i.e. it may interrupt and later resume).
func (s *NodeSchema) requireCheckpoint() bool {
	// The original if-chain tested mutually exclusive node types, so a
	// switch expresses the same decision.
	switch s.Type {
	case entity.NodeTypeQuestionAnswer, entity.NodeTypeInputReceiver:
		// These nodes always pause to wait for user input.
		return true
	case entity.NodeTypeLLM:
		// An LLM node only needs a checkpoint when it may call workflows as
		// tools (which can themselves interrupt).
		fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
		return fcParams != nil && fcParams.WorkflowFCParam != nil
	case entity.NodeTypeSubWorkflow:
		// Delegate to the sub-workflow schema: it needs a checkpoint if any
		// of its own nodes does. Init() prepares that computed flag.
		s.SubWorkflowSchema.Init()
		return s.SubWorkflowSchema.requireCheckPoint
	default:
		return false
	}
}

View File

@@ -0,0 +1,939 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflow2 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// State is the per-execution mutable state shared by all nodes of a workflow
// graph. It is JSON-serialized for checkpointing so interrupted workflows can
// resume later; fields tagged `json:"-"` are runtime-only and not persisted.
type State struct {
	// Answers / Questions record each question-answer node's conversation so far.
	Answers   map[vo.NodeKey][]string       `json:"answers,omitempty"`
	Questions map[vo.NodeKey][]*qa.Question `json:"questions,omitempty"`
	// Inputs stores a node's input (or resume-provided input) keyed by node,
	// so it can be replayed after an interruption.
	Inputs map[vo.NodeKey]map[string]any `json:"inputs,omitempty"`
	// NodeExeContexts / WorkflowExeContext hold execution contexts; not serialized.
	NodeExeContexts    map[vo.NodeKey]*execute.Context `json:"-"`
	WorkflowExeContext *execute.Context                `json:"-"`
	// InterruptEvents holds the pending interrupt event per node.
	InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
	// NestedWorkflowStates holds checkpoint state for composite (batch/loop/sub-workflow) nodes.
	NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
	// ExecutedNodes marks nodes that have finished executing.
	ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
	// SourceInfos records each node's resolved input source layout (per group).
	SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
	// GroupChoices records, per node, which element each dynamic group selected
	// (-1 means none); used to resolve "maybe stream" fields.
	GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
	// ToolInterruptEvents holds interrupt events raised by tools, keyed by LLM node then tool call ID.
	ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
	// LLMToResumeData holds resume payloads destined for interrupted LLM nodes.
	LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
	// AppVariableStore caches application-scoped variables; shared with parent
	// workflow state when nested (see GenState).
	AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"`
}
// init registers every type that can appear inside State (directly or through
// its maps) with eino's checkpoint serializer, so an interrupted graph's state
// can be saved and later restored. Registration errors are ignored because
// re-registering the same type is harmless at startup.
func init() {
	_ = compose.RegisterSerializableType[*State]("schema_state")
	_ = compose.RegisterSerializableType[[]*qa.Question]("qa_question_list")
	_ = compose.RegisterSerializableType[qa.Question]("qa_question")
	_ = compose.RegisterSerializableType[vo.NodeKey]("node_key")
	_ = compose.RegisterSerializableType[*execute.Context]("exe_context")
	_ = compose.RegisterSerializableType[execute.RootCtx]("root_ctx")
	_ = compose.RegisterSerializableType[*execute.SubWorkflowCtx]("sub_workflow_ctx")
	_ = compose.RegisterSerializableType[*execute.NodeCtx]("node_ctx")
	_ = compose.RegisterSerializableType[*execute.BatchInfo]("batch_info")
	_ = compose.RegisterSerializableType[*execute.TokenCollector]("token_collector")
	_ = compose.RegisterSerializableType[entity.NodeType]("node_type")
	_ = compose.RegisterSerializableType[*entity.InterruptEvent]("interrupt_event")
	_ = compose.RegisterSerializableType[workflow2.EventType]("workflow_event_type")
	_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
	_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
	_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
	_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
	_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
	_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
	_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
	_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
	_ = compose.RegisterSerializableType[*entity.ToolInterruptEvent]("tool_interrupt_event")
	_ = compose.RegisterSerializableType[vo.ExecuteConfig]("execute_config")
	_ = compose.RegisterSerializableType[vo.ExecuteMode]("execute_mode")
	_ = compose.RegisterSerializableType[vo.TaskType]("task_type")
	_ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern")
	_ = compose.RegisterSerializableType[vo.Locator]("wf_locator")
	_ = compose.RegisterSerializableType[vo.BizType]("biz_type")
	_ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables")
}
// SetAppVariableValue caches an application-scoped variable value in the
// shared store (shared with a parent workflow's state when nested).
func (s *State) SetAppVariableValue(key string, value any) {
	s.AppVariableStore.Set(key, value)
}

// GetAppVariableValue returns the cached application-scoped variable value
// and whether it was present.
func (s *State) GetAppVariableValue(key string) (any, bool) {
	return s.AppVariableStore.Get(key)
}

// AddQuestion appends a question asked by the given question-answer node.
func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) {
	s.Questions[nodeKey] = append(s.Questions[nodeKey], question)
}

// AddAnswer appends a user answer for the given question-answer node.
func (s *State) AddAnswer(nodeKey vo.NodeKey, answer string) {
	s.Answers[nodeKey] = append(s.Answers[nodeKey], answer)
}

// GetQuestionsAndAnswers returns the questions asked and the answers received
// so far for the given node.
func (s *State) GetQuestionsAndAnswers(nodeKey vo.NodeKey) ([]*qa.Question, []string) {
	return s.Questions[nodeKey], s.Answers[nodeKey]
}
// GetNodeCtx returns the execution context stored for the node, plus whether
// it was found. The error is always nil here; the three-value signature
// presumably satisfies a store interface — confirm against callers.
func (s *State) GetNodeCtx(key vo.NodeKey) (*execute.Context, bool, error) {
	c, ok := s.NodeExeContexts[key]
	if ok {
		return c, true, nil
	}
	return nil, false, nil
}

// SetNodeCtx stores the execution context for the node. Never fails.
func (s *State) SetNodeCtx(key vo.NodeKey, value *execute.Context) error {
	s.NodeExeContexts[key] = value
	return nil
}

// GetWorkflowCtx returns the workflow-level execution context and whether one
// has been set. The error is always nil.
func (s *State) GetWorkflowCtx() (*execute.Context, bool, error) {
	if s.WorkflowExeContext == nil {
		return nil, false, nil
	}
	return s.WorkflowExeContext, true, nil
}

// SetWorkflowCtx stores the workflow-level execution context. Never fails.
func (s *State) SetWorkflowCtx(value *execute.Context) error {
	s.WorkflowExeContext = value
	return nil
}
// GetInterruptEvent returns the pending interrupt event for the node and
// whether one exists. The error is always nil.
func (s *State) GetInterruptEvent(nodeKey vo.NodeKey) (*entity.InterruptEvent, bool, error) {
	if v, ok := s.InterruptEvents[nodeKey]; ok {
		return v, true, nil
	}
	return nil, false, nil
}

// SetInterruptEvent records a pending interrupt event for the node. Never fails.
func (s *State) SetInterruptEvent(nodeKey vo.NodeKey, value *entity.InterruptEvent) error {
	s.InterruptEvents[nodeKey] = value
	return nil
}

// DeleteInterruptEvent removes the node's pending interrupt event, if any.
func (s *State) DeleteInterruptEvent(nodeKey vo.NodeKey) error {
	delete(s.InterruptEvents, nodeKey)
	return nil
}

// GetNestedWorkflowState returns the checkpoint state of a composite node
// (batch/loop/sub-workflow) and whether one exists. The error is always nil.
func (s *State) GetNestedWorkflowState(key vo.NodeKey) (*nodes.NestedWorkflowState, bool, error) {
	if v, ok := s.NestedWorkflowStates[key]; ok {
		return v, true, nil
	}
	return nil, false, nil
}

// SaveNestedWorkflowState stores a composite node's checkpoint state. Never fails.
func (s *State) SaveNestedWorkflowState(key vo.NodeKey, value *nodes.NestedWorkflowState) error {
	s.NestedWorkflowStates[key] = value
	return nil
}
// SaveDynamicChoice records, for a node, which element index each dynamic
// group selected at runtime (used later to resolve "maybe stream" fields).
func (s *State) SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int) {
	s.GroupChoices[nodeKey] = groupToChoice
}

// GetDynamicChoice returns the recorded group-to-choice map for the node,
// or nil if none was saved.
func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
	return s.GroupChoices[nodeKey]
}
// GetDynamicStreamType resolves, at runtime, whether the chosen element of a
// dynamic group on the given node is a stream. It looks up the group's
// recorded choice, finds the corresponding sub-source, and — when that source
// is itself "maybe stream" — recursively resolves through the upstream node
// until a definite answer is found.
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
	choices, ok := s.GroupChoices[nodeKey]
	if !ok {
		return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
	}
	choice, ok := choices[group]
	if !ok {
		return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
	}
	if choice == -1 { // this group picks none of the elements
		return nodes.FieldNotStream, nil
	}
	sInfos, ok := s.SourceInfos[nodeKey]
	if !ok {
		return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
	}
	groupInfo, ok := sInfos[group]
	if !ok {
		return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
	}
	if groupInfo.SubSources == nil {
		return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
	}
	// Sub-sources are keyed by the element index as a decimal string.
	subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
	if !ok {
		return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
	}
	if subInfo.FieldType != nodes.FieldMaybeStream {
		return subInfo.FieldType, nil
	}
	// A "maybe stream" sub-source must name exactly one upstream field; these
	// are invariant violations, hence panic rather than error.
	if len(subInfo.FromNodeKey) == 0 {
		panic("subInfo is maybe stream, but from node key is empty")
	}
	if len(subInfo.FromPath) > 1 || len(subInfo.FromPath) == 0 {
		panic("subInfo is maybe stream, but from path is more than 1 segments or is empty")
	}
	// Recurse into the upstream node that actually produces the value.
	return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
}
// GetAllDynamicStreamTypes resolves the stream type of every dynamic group
// recorded for nodeKey. A node with no recorded choices yields an empty,
// non-nil map.
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
	types := make(map[string]nodes.FieldStreamType)
	choices, found := s.GroupChoices[nodeKey]
	if !found {
		return types, nil
	}
	for group := range choices {
		streamType, err := s.GetDynamicStreamType(nodeKey, group)
		if err != nil {
			return nil, err
		}
		types[group] = streamType
	}
	return types, nil
}
// SetToolInterruptEvent records an interrupt event raised by a tool call made
// from the given LLM node, keyed by tool call ID. Never fails.
func (s *State) SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error {
	if _, ok := s.ToolInterruptEvents[llmNodeKey]; !ok {
		s.ToolInterruptEvents[llmNodeKey] = make(map[string]*entity.ToolInterruptEvent)
	}
	s.ToolInterruptEvents[llmNodeKey][toolCallID] = ie
	return nil
}

// GetToolInterruptEvents returns all pending tool interrupt events for the
// LLM node, keyed by tool call ID (nil map if none). The error is always nil.
func (s *State) GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error) {
	return s.ToolInterruptEvents[llmNodeKey], nil
}

// ResumeToolInterruptEvent consumes the resume payload stored for the LLM
// node and clears both the payload and the tool call's interrupt event.
// Returns an error if no resume data was stored for that node.
func (s *State) ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error) {
	resumeData, ok := s.LLMToResumeData[llmNodeKey]
	if !ok {
		return "", fmt.Errorf("resume data not found for llm node %s", llmNodeKey)
	}
	delete(s.ToolInterruptEvents[llmNodeKey], toolCallID)
	delete(s.LLMToResumeData, llmNodeKey)
	return resumeData, nil
}
// NodeExecuted reports whether the node identified by key has finished
// executing. The virtual START node is always considered executed.
func (s *State) NodeExecuted(key vo.NodeKey) bool {
	if key == compose.START {
		return true
	}
	_, done := s.ExecutedNodes[key]
	return done
}
// GenState returns the state generator installed on each workflow graph. When
// the graph runs nested inside another workflow (a parent *State is reachable
// via ProcessState on the incoming ctx), the child shares the parent's
// AppVariableStore so application variables stay consistent across nesting;
// otherwise a fresh store is created.
func GenState() compose.GenLocalState[*State] {
	return func(ctx context.Context) (state *State) {
		var parentState *State
		// ProcessState fails harmlessly when there is no enclosing graph
		// state; parentState simply stays nil in that case.
		_ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
			parentState = s
			return nil
		})
		var appVariableStore *variableassigner.AppVariables
		if parentState == nil {
			appVariableStore = variableassigner.NewAppVariables()
		} else {
			appVariableStore = parentState.AppVariableStore
		}
		return &State{
			Answers:              make(map[vo.NodeKey][]string),
			Questions:            make(map[vo.NodeKey][]*qa.Question),
			Inputs:               make(map[vo.NodeKey]map[string]any),
			NodeExeContexts:      make(map[vo.NodeKey]*execute.Context),
			InterruptEvents:      make(map[vo.NodeKey]*entity.InterruptEvent),
			NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
			ExecutedNodes:        make(map[vo.NodeKey]bool),
			SourceInfos:          make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
			GroupChoices:         make(map[vo.NodeKey]map[string]int),
			ToolInterruptEvents:  make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
			LLMToResumeData:      make(map[vo.NodeKey]string),
			AppVariableStore:     appVariableStore,
		}
	}
}
// StatePreHandler builds the graph option that installs this node's state
// pre-handler(s). Interrupt-capable node types (QA, input receiver,
// batch/loop) get handlers that stash the first-run input in State and replay
// it on resume; variable-referencing inputs get a handler that injects
// variable values. Non-stream handlers are chained when any exist or when the
// node is invoked non-streaming; otherwise stream handlers are chained.
// Returns nil when the node needs no pre-handler.
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
	var (
		handlers       []compose.StatePreHandler[map[string]any, *State]
		streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
	)
	if s.Type == entity.NodeTypeQuestionAnswer {
		handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
			// even on first execution before any interruption, the input could be empty
			// so we need to check if we have stored any questions in state, to decide whether this is the first execution
			isFirst := false
			if _, ok := state.Questions[s.Key]; !ok {
				isFirst = true
			}
			if isFirst {
				state.Inputs[s.Key] = in
				return in, nil
			}
			// Resumed execution: replay the stored input plus the Q&A history
			// accumulated so far under well-known keys.
			out := make(map[string]any)
			for k, v := range state.Inputs[s.Key] {
				out[k] = v
			}
			out[qa.QuestionsKey] = state.Questions[s.Key]
			out[qa.AnswersKey] = state.Answers[s.Key]
			return out, nil
		})
	} else if s.Type == entity.NodeTypeInputReceiver {
		// InputReceiver node's only input is set by StateModifier when resuming
		handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
			if userInput, ok := state.Inputs[s.Key]; ok && len(userInput) > 0 {
				return userInput, nil
			}
			return in, nil
		})
	} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
		handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
			if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
				state.Inputs[s.Key] = in
				return in, nil
			}
			// Resumed execution: replay a copy of the stored input.
			out := make(map[string]any)
			for k, v := range state.Inputs[s.Key] {
				out[k] = v
			}
			return out, nil
		})
	}
	// Non-stream path: taken whenever an interrupt handler exists (those are
	// non-stream only) or the node runs non-streaming.
	if len(handlers) > 0 || !stream {
		handlerForVars := s.statePreHandlerForVars()
		if handlerForVars != nil {
			handlers = append(handlers, handlerForVars)
		}
		// Chain the handlers in order, feeding each one's output to the next.
		stateHandler := func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
			var err error
			for _, h := range handlers {
				in, err = h(ctx, in, state)
				if err != nil {
					return nil, err
				}
			}
			return in, nil
		}
		return compose.WithStatePreHandler(stateHandler)
	}
	if s.Type == entity.NodeTypeVariableAggregator {
		// Record the aggregator's resolved source layout so downstream
		// dynamic stream-type resolution can consult it.
		streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
			state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
			return in, nil
		})
	}
	handlerForVars := s.streamStatePreHandlerForVars()
	if handlerForVars != nil {
		streamHandlers = append(streamHandlers, handlerForVars)
	}
	/*handlerForStreamSource := s.streamStatePreHandlerForStreamSources()
	if handlerForStreamSource != nil {
		streamHandlers = append(streamHandlers, handlerForStreamSource)
	}*/
	if len(streamHandlers) > 0 {
		// Chain the stream handlers in order.
		streamHandler := func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
			var err error
			for _, h := range streamHandlers {
				in, err = h(ctx, in, state)
				if err != nil {
					return nil, err
				}
			}
			return in, nil
		}
		return compose.WithStreamStatePreHandler(streamHandler)
	}
	return nil
}
// statePreHandlerForVars returns a pre-handler that resolves the node's
// variable-typed input sources (parent-intermediate, global system/user, and
// app variables) and sets them into a copy of the input map. Returns nil when
// the node has no variable inputs so no handler is installed.
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
	// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
	var vars []*vo.FieldInfo
	for _, input := range s.InputSources {
		if input.Source.Ref != nil && input.Source.Ref.VariableType != nil {
			vars = append(vars, input)
		}
	}
	if len(vars) == 0 {
		return nil
	}
	varStoreHandler := variable.GetVariableHandler()
	intermediateVarStore := &nodes.ParentIntermediateStore{}
	return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
		opts := make([]variable.OptionFn, 0, 1)
		// Scope variable lookups to the current execution when a context is available.
		if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
			exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
			opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
				AgentID:      exeCfg.AgentID,
				AppID:        exeCfg.AppID,
				ConnectorID:  exeCfg.ConnectorID,
				ConnectorUID: exeCfg.ConnectorUID,
			}))
		}
		// Copy the input so the incoming map is never mutated.
		out := make(map[string]any)
		for k, v := range in {
			out[k] = v
		}
		for _, input := range vars {
			if input == nil {
				continue
			}
			var v any
			var err error
			switch *input.Source.Ref.VariableType {
			case vo.ParentIntermediate:
				v, err = intermediateVarStore.Get(ctx, input.Source.Ref.FromPath, opts...)
			case vo.GlobalSystem, vo.GlobalUser:
				v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
			case vo.GlobalAPP:
				// APP variables are read through the State cache so repeated
				// reads within one run hit the store only once.
				var ok bool
				path := strings.Join(input.Source.Ref.FromPath, ".")
				if v, ok = state.GetAppVariableValue(path); !ok {
					v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
					if err != nil {
						return nil, err
					}
					state.SetAppVariableValue(path, v)
				}
			default:
				return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
			}
			if err != nil {
				return nil, err
			}
			nodes.SetMapValue(out, input.Path, v)
		}
		return out, nil
	}
}
// streamStatePreHandlerForVars returns a stream pre-handler that resolves the
// node's variable-typed input sources and merges them into the input stream
// as an extra one-chunk stream. Returns nil when the node has no variable
// inputs so no handler is installed.
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
	// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
	var vars []*vo.FieldInfo
	for _, input := range s.InputSources {
		if input.Source.Ref != nil && input.Source.Ref.VariableType != nil {
			vars = append(vars, input)
		}
	}
	if len(vars) == 0 {
		return nil
	}
	varStoreHandler := variable.GetVariableHandler()
	intermediateVarStore := &nodes.ParentIntermediateStore{}
	return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
		variables := make(map[string]any)
		opts := make([]variable.OptionFn, 0, 1)
		// Fix: guard against a nil execute context before dereferencing it.
		// The previous code read execute.GetExeCtx(ctx).RootCtx unconditionally
		// and would panic when no context was attached; statePreHandlerForVars,
		// statePostHandlerForVars and streamStatePostHandlerForVars all guard.
		if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
			exeCfg := exeCtx.RootCtx.ExeCfg
			opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
				AgentID:      exeCfg.AgentID,
				AppID:        exeCfg.AppID,
				ConnectorID:  exeCfg.ConnectorID,
				ConnectorUID: exeCfg.ConnectorUID,
			}))
		}
		for _, input := range vars {
			if input == nil {
				continue
			}
			var v any
			var err error
			switch *input.Source.Ref.VariableType {
			case vo.ParentIntermediate:
				v, err = intermediateVarStore.Get(ctx, input.Source.Ref.FromPath, opts...)
			case vo.GlobalSystem, vo.GlobalUser:
				v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
			case vo.GlobalAPP:
				// APP variables are read through the State cache so repeated
				// reads within one run hit the store only once.
				var ok bool
				path := strings.Join(input.Source.Ref.FromPath, ".")
				if v, ok = state.GetAppVariableValue(path); !ok {
					v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
					if err != nil {
						return nil, err
					}
					state.SetAppVariableValue(path, v)
				}
			default:
				return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
			}
			if err != nil {
				return nil, err
			}
			nodes.SetMapValue(variables, input.Path, v)
		}
		// Deliver the resolved variables as a single extra chunk merged into
		// the original input stream.
		variablesStream := schema.StreamReaderFromArray([]map[string]any{variables})
		return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{in, variablesStream}), nil
	}
}
// streamStatePreHandlerForStreamSources returns a stream pre-handler that
// strips the KeyIsFinished sentinel suffix from fields known (or resolved at
// runtime) to be streams. Only the variable aggregator, output emitter and a
// streaming-return Exit node can receive streaming input; all other node
// types return nil. Returns nil as well when no input source is (possibly) a
// stream. NOTE(review): the only call site in StatePreHandler is currently
// commented out, so this handler appears to be inactive — confirm before
// relying on it.
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
	// if it does not have source info, do not add this pre handler
	if s.Configs == nil {
		return nil
	}
	switch s.Type {
	case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
		return nil
	case entity.NodeTypeExit:
		terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
		if terminatePlan != vo.ReturnVariables {
			return nil
		}
	default:
		// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
	}
	sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
	if len(sourceInfo) == 0 {
		return nil
	}
	// check the node's input sources, if it does not have any streaming sources, no need to add pre handler
	// if one input is a stream, then in the pre handler, will trim the KeyIsFinished suffix.
	// if one input may be a stream, then in the pre handler, will resolve it first, then handle it.
	type resolvedStreamSource struct {
		intermediate     bool
		mustBeStream     bool
		subStreamSources map[string]resolvedStreamSource
	}
	var (
		anyStream bool
		checker   func(source *nodes.SourceInfo) bool
	)
	// checker reports whether a source (or any nested sub-source) is a stream
	// or may be one.
	checker = func(source *nodes.SourceInfo) bool {
		if source.FieldType != nodes.FieldNotStream {
			return true
		}
		for _, subSource := range source.SubSources {
			if subAnyStream := checker(subSource); subAnyStream {
				return true
			}
		}
		return false
	}
	for _, source := range sourceInfo {
		if hasStream := checker(source); hasStream {
			anyStream = true
			break
		}
	}
	if !anyStream {
		return nil
	}
	return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
		resolved := map[string]resolvedStreamSource{}
		// resolver turns a SourceInfo into its runtime-resolved stream shape:
		// intermediate containers keep only the sub-sources that resolved to
		// streams; leaf sources resolve "maybe stream" via the State's
		// recorded dynamic choices; definite non-streams resolve to nil.
		var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
		resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
			if source.IsIntermediate {
				result = &resolvedStreamSource{
					intermediate:     true,
					subStreamSources: map[string]resolvedStreamSource{},
				}
				for key, subSource := range source.SubSources {
					subResult, subE := resolver(*subSource)
					if subE != nil {
						return nil, subE
					}
					if subResult != nil {
						result.subStreamSources[key] = *subResult
					}
				}
				return result, nil
			}
			streamType := source.FieldType
			if streamType == nodes.FieldMaybeStream {
				streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
				if err != nil {
					return nil, err
				}
			}
			if streamType == nodes.FieldNotStream {
				return nil, nil
			}
			result = &resolvedStreamSource{
				mustBeStream: true,
			}
			return result, nil
		}
		for key, source := range sourceInfo {
			result, err := resolver(*source)
			if err != nil {
				return nil, err
			}
			if result != nil {
				resolved[key] = *result
			}
		}
		// converter rewrites one value according to its resolved shape:
		// recurse through intermediate maps, and trim the KeyIsFinished
		// sentinel suffix from leaf stream chunks (which must be strings).
		var converter func(v any, resolvedSource resolvedStreamSource) any
		converter = func(v any, resolvedSource resolvedStreamSource) any {
			if resolvedSource.intermediate {
				vMap, ok := v.(map[string]any)
				if !ok {
					panic("intermediate value is not map[string]any")
				}
				outMap := make(map[string]any, len(vMap))
				for k := range vMap {
					subResolvedSource, ok := resolvedSource.subStreamSources[k]
					if !ok { // not a stream field
						outMap[k] = vMap[k]
						continue
					}
					subV := converter(vMap[k], subResolvedSource)
					outMap[k] = subV
				}
				return outMap
			}
			vStr, ok := v.(string)
			if !ok {
				panic("stream field is not string")
			}
			return strings.TrimSuffix(vStr, nodes.KeyIsFinished)
		}
		streamConverter := func(inChunk map[string]any) (outChunk map[string]any, chunkErr error) {
			outChunk = make(map[string]any, len(inChunk))
			for k, v := range inChunk {
				if resolvedSource, ok := resolved[k]; !ok {
					outChunk[k] = v // not a stream field
				} else {
					vOut := converter(v, resolvedSource)
					outChunk[k] = vOut
				}
			}
			return outChunk, nil
		}
		return schema.StreamReaderWithConvert(in, streamConverter), nil
	}
}
// StatePostHandler builds the graph option that installs this node's state
// post-handler(s). Every node gets a handler that marks it executed in State;
// nodes with variable-typed output sources additionally get a handler that
// resolves those variables into the output. The stream flag selects between
// the stream and non-stream handler chains.
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
	var (
		handlers       []compose.StatePostHandler[map[string]any, *State]
		streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
	)
	if stream {
		// Mark the node as executed as soon as its output stream is produced.
		streamHandlers = append(streamHandlers, func(ctx context.Context, out *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
			state.ExecutedNodes[s.Key] = true
			return out, nil
		})
		forVars := s.streamStatePostHandlerForVars()
		if forVars != nil {
			streamHandlers = append(streamHandlers, forVars)
		}
		// Chain the stream handlers in order.
		streamHandler := func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
			var err error
			for _, h := range streamHandlers {
				in, err = h(ctx, in, state)
				if err != nil {
					return nil, err
				}
			}
			return in, nil
		}
		return compose.WithStreamStatePostHandler(streamHandler)
	}
	// Non-stream path: same structure with map-valued handlers.
	handlers = append(handlers, func(ctx context.Context, out map[string]any, state *State) (map[string]any, error) {
		state.ExecutedNodes[s.Key] = true
		return out, nil
	})
	forVars := s.statePostHandlerForVars()
	if forVars != nil {
		handlers = append(handlers, forVars)
	}
	handler := func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
		var err error
		for _, h := range handlers {
			in, err = h(ctx, in, state)
			if err != nil {
				return nil, err
			}
		}
		return in, nil
	}
	return compose.WithStatePostHandler(handler)
}
// statePostHandlerForVars returns a post-handler that resolves the node's
// variable-typed OUTPUT sources (global system/user and app variables) and
// sets them into a copy of the output map. Parent-intermediate variables are
// excluded because nodes handle those internally. Returns nil when the node
// has no such output sources.
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
	// checkout the node's output sources, if it has any variable,
	// use the state's variableHandler to get the variables and set them to the output
	var vars []*vo.FieldInfo
	for _, output := range s.OutputSources {
		if output.Source.Ref != nil && output.Source.Ref.VariableType != nil {
			// intermediate vars are handled within nodes themselves
			if *output.Source.Ref.VariableType == vo.ParentIntermediate {
				continue
			}
			vars = append(vars, output)
		}
	}
	if len(vars) == 0 {
		return nil
	}
	varStoreHandler := variable.GetVariableHandler()
	return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
		opts := make([]variable.OptionFn, 0, 1)
		// Scope variable lookups to the current execution when a context is available.
		if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
			exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
			opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
				AgentID:      exeCfg.AgentID,
				AppID:        exeCfg.AppID,
				ConnectorID:  exeCfg.ConnectorID,
				ConnectorUID: exeCfg.ConnectorUID,
			}))
		}
		// Copy the output so the incoming map is never mutated.
		out := make(map[string]any)
		for k, v := range in {
			out[k] = v
		}
		// NOTE: the loop variable is named "input" but iterates output sources.
		for _, input := range vars {
			if input == nil {
				continue
			}
			var v any
			var err error
			switch *input.Source.Ref.VariableType {
			case vo.GlobalSystem, vo.GlobalUser:
				v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
			case vo.GlobalAPP:
				// APP variables are read through the State cache so repeated
				// reads within one run hit the store only once.
				var ok bool
				path := strings.Join(input.Source.Ref.FromPath, ".")
				if v, ok = state.GetAppVariableValue(path); !ok {
					v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
					if err != nil {
						return nil, err
					}
					state.SetAppVariableValue(path, v)
				}
			default:
				return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
			}
			if err != nil {
				return nil, err
			}
			nodes.SetMapValue(out, input.Path, v)
		}
		return out, nil
	}
}
// streamStatePostHandlerForVars returns a stream post-handler that resolves
// the node's variable-typed OUTPUT sources and merges them into the output
// stream as an extra one-chunk stream. Parent-intermediate variables are
// excluded because nodes handle those internally. Returns nil when the node
// has no such output sources.
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
	// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
	var vars []*vo.FieldInfo
	for _, output := range s.OutputSources {
		if output.Source.Ref != nil && output.Source.Ref.VariableType != nil {
			// intermediate vars are handled within nodes themselves
			if *output.Source.Ref.VariableType == vo.ParentIntermediate {
				continue
			}
			vars = append(vars, output)
		}
	}
	if len(vars) == 0 {
		return nil
	}
	varStoreHandler := variable.GetVariableHandler()
	return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
		var (
			variables = make(map[string]any)
			opts      = make([]variable.OptionFn, 0, 1)
		)
		// Scope variable lookups to the current execution when a context is available.
		if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
			exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
			opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
				AgentID:      exeCfg.AgentID,
				AppID:        exeCfg.AppID,
				ConnectorID:  exeCfg.ConnectorID,
				ConnectorUID: exeCfg.ConnectorUID,
			}))
		}
		// NOTE: the loop variable is named "input" but iterates output sources.
		for _, input := range vars {
			if input == nil {
				continue
			}
			var v any
			var err error
			switch *input.Source.Ref.VariableType {
			case vo.GlobalSystem, vo.GlobalUser:
				v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
			case vo.GlobalAPP:
				// APP variables are read through the State cache so repeated
				// reads within one run hit the store only once.
				var ok bool
				path := strings.Join(input.Source.Ref.FromPath, ".")
				if v, ok = state.GetAppVariableValue(path); !ok {
					v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
					if err != nil {
						return nil, err
					}
					state.SetAppVariableValue(path, v)
				}
			default:
				return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
			}
			if err != nil {
				return nil, err
			}
			nodes.SetMapValue(variables, input.Path, v)
		}
		// Deliver the resolved variables as a single extra chunk merged into
		// the original output stream.
		variablesStream := schema.StreamReaderFromArray([]map[string]any{variables})
		return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{in, variablesStream}), nil
	}
}
// GenStateModifierByEventType builds the compose.StateModifier that injects
// resume data into workflow State when resuming from the given interrupt
// event type:
//   - InterruptEventInput: stores the data as the input-receiver node's input
//     (for agent executions the data arrives as newline-separated "key:value"
//     pairs and is first converted to a JSON object);
//   - InterruptEventQuestion: appends the data as the node's answer;
//   - InterruptEventLLM: stashes the data for the interrupted LLM node.
//
// Panics on an unimplemented event type.
func GenStateModifierByEventType(e entity.InterruptEventType,
	nodeKey vo.NodeKey,
	resumeData string,
	exeCfg vo.ExecuteConfig) (stateModifier compose.StateModifier) {
	// TODO: can we unify them all to a map[NodeKey]resumeData?
	switch e {
	case entity.InterruptEventInput:
		stateModifier = func(ctx context.Context, path compose.NodePath, state any) (err error) {
			if exeCfg.BizType == vo.BizTypeAgent {
				m := make(map[string]any)
				for _, line := range strings.Split(resumeData, "\n") {
					k, v, found := strings.Cut(line, ":")
					if !found {
						// Fix: a line without a colon previously made
						// strings.Index return -1 and the subsequent slicing
						// panic; skip malformed lines instead.
						continue
					}
					m[k] = v
				}
				resumeData, err = sonic.MarshalString(m)
				if err != nil {
					return err
				}
			}
			input := map[string]any{
				receiver.ReceivedDataKey: resumeData,
			}
			state.(*State).Inputs[nodeKey] = input
			return nil
		}
	case entity.InterruptEventQuestion:
		stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
			state.(*State).AddAnswer(nodeKey, resumeData)
			return nil
		}
	case entity.InterruptEventLLM:
		stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
			state.(*State).LLMToResumeData[nodeKey] = resumeData
			return nil
		}
	default:
		panic(fmt.Sprintf("unimplemented interrupt event type: %v", e))
	}
	return stateModifier
}

View File

@@ -0,0 +1,292 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
// SetFullSources calculates REAL input sources for a node.
// It may be different from a NodeSchema's InputSources because of the following reasons:
// 1. a inner node under composite node may refer to a field from a node in its parent workflow,
// this is instead routed to and sourced from the inner workflow's start node.
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
// The result is a tree of *nodes.SourceInfo stored under the node's
// "FullSources" config key; intermediate map levels become IsIntermediate
// entries with SubSources. Each leaf records whether its value is a stream
// (resolved via the producing node's IsStreamingField).
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
	fullSource := make(map[string]*nodes.SourceInfo)
	// Flatten every dependency category (static values, variables, full and
	// per-field inputs, plus their no-direct-dependency variants) into one
	// list of FieldInfos to process uniformly below.
	var fieldInfos []vo.FieldInfo
	for _, s := range dep.staticValues {
		fieldInfos = append(fieldInfos, vo.FieldInfo{
			Path:   s.path,
			Source: vo.FieldSource{Val: s.val},
		})
	}
	for _, v := range dep.variableInfos {
		fieldInfos = append(fieldInfos, vo.FieldInfo{
			Path: v.toPath,
			Source: vo.FieldSource{
				Ref: &vo.Reference{
					VariableType: &v.varType,
					// The first path segment selects the variable store and is
					// not part of the lookup path.
					FromPath: v.fromPath[1:],
				},
			},
		})
	}
	for f := range dep.inputsFull {
		// A "full" input maps the producing node's whole output (empty path).
		fieldInfos = append(fieldInfos, vo.FieldInfo{
			Path: []string{""},
			Source: vo.FieldSource{Ref: &vo.Reference{
				FromNodeKey: f,
				FromPath:    []string{""},
			}},
		})
	}
	for f, ms := range dep.inputs {
		for _, m := range ms {
			fieldInfos = append(fieldInfos, vo.FieldInfo{
				Path: m.ToPath(),
				Source: vo.FieldSource{Ref: &vo.Reference{
					FromNodeKey: f,
					FromPath:    m.FromPath(),
				}},
			})
		}
	}
	for f := range dep.inputsNoDirectDependencyFull {
		fieldInfos = append(fieldInfos, vo.FieldInfo{
			Path: []string{""},
			Source: vo.FieldSource{Ref: &vo.Reference{
				FromNodeKey: f,
				FromPath:    []string{""},
			}},
		})
	}
	for f, ms := range dep.inputsNoDirectDependency {
		for _, m := range ms {
			fieldInfos = append(fieldInfos, vo.FieldInfo{
				Path: m.ToPath(),
				Source: vo.FieldSource{Ref: &vo.Reference{
					FromNodeKey: f,
					FromPath:    m.FromPath(),
				}},
			})
		}
	}
	for i := range fieldInfos {
		fInfo := fieldInfos[i]
		path := fInfo.Path
		currentSource := fullSource
		var (
			tInfo    *vo.TypeInfo
			lastPath string
		)
		if len(path) > 1 {
			// Multi-segment path: walk/create intermediate SourceInfo levels,
			// tracking the declared TypeInfo alongside.
			tInfo = s.InputTypes[path[0]]
			for j := 0; j < len(path)-1; j++ {
				if j > 0 {
					tInfo = tInfo.Properties[path[j]]
				}
				if current, ok := currentSource[path[j]]; !ok {
					currentSource[path[j]] = &nodes.SourceInfo{
						IsIntermediate: true,
						FieldType:      nodes.FieldNotStream,
						TypeInfo:       tInfo,
						SubSources:     make(map[string]*nodes.SourceInfo),
					}
				} else if !current.IsIntermediate {
					return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
				}
				currentSource = currentSource[path[j]].SubSources
			}
			lastPath = path[len(path)-1]
			tInfo = tInfo.Properties[lastPath]
		} else {
			lastPath = path[0]
			tInfo = s.InputTypes[lastPath]
		}
		// static values or variables
		if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
			currentSource[lastPath] = &nodes.SourceInfo{
				IsIntermediate: false,
				FieldType:      nodes.FieldNotStream,
				TypeInfo:       tInfo,
			}
			continue
		}
		fromNodeKey := fInfo.Source.Ref.FromNodeKey
		var (
			streamType nodes.FieldStreamType
			err        error
		)
		if len(fromNodeKey) > 0 {
			if fromNodeKey == compose.START {
				streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
			} else {
				// Ask the producing node whether this field streams.
				fromNode, ok := allNS[fromNodeKey]
				if !ok {
					return fmt.Errorf("node %s not found", fromNodeKey)
				}
				streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
				if err != nil {
					return err
				}
			}
		}
		currentSource[lastPath] = &nodes.SourceInfo{
			IsIntermediate: false,
			FieldType:      streamType,
			FromNodeKey:    fromNodeKey,
			FromPath:       fInfo.Source.Ref.FromPath,
			TypeInfo:       tInfo,
		}
	}
	s.Configs.(map[string]any)["FullSources"] = fullSource
	return nil
}
// IsStreamingField reports whether the output field addressed by path on this
// node produces a stream (FieldIsStream), a plain value (FieldNotStream), or
// cannot be decided statically (FieldMaybeStream).
// allNS is used to recursively resolve upstream nodes; it is passed as nil
// only where the callee never dereferences it (the sub-workflow exit case).
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
	if s.Type == entity.NodeTypeExit {
		// Exit streams only its single "output" field, and only when the node
		// is configured for streaming mode.
		if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
			if len(path) == 1 && path[0] == "output" {
				return nodes.FieldIsStream, nil
			}
		}
		return nodes.FieldNotStream, nil
	} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
		// Delegate the decision to the sub-workflow's exit node.
		subSC := s.SubWorkflowSchema
		subExit := subSC.GetNode(entity.ExitNodeKey)
		subStreamType, err := subExit.IsStreamingField(path, nil)
		if err != nil {
			return nodes.FieldNotStream, err
		}
		return subStreamType, nil
	} else if s.Type == entity.NodeTypeVariableAggregator {
		if len(path) == 2 { // asking about a specific index within a group
			// Find the input source whose path matches exactly, then resolve
			// the stream type of the node that feeds it.
			for _, fInfo := range s.InputSources {
				if len(fInfo.Path) == len(path) {
					equal := true
					for i := range fInfo.Path {
						if fInfo.Path[i] != path[i] {
							equal = false
							break
						}
					}
					if equal {
						// Static values / variables never stream.
						if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
							return nodes.FieldNotStream, nil
						}
						fromNodeKey := fInfo.Source.Ref.FromNodeKey
						fromNode, ok := allNS[fromNodeKey]
						if !ok {
							return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
						}
						return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
					}
				}
			}
			// NOTE: an unmatched 2-segment path deliberately falls through to
			// the checks below and resolves to FieldNotStream.
		} else if len(path) == 1 { // asking about the entire group
			// The group as a whole streams only if every node-sourced member
			// streams; mixed (or maybe-stream) members make it undecidable.
			var streamCount, notStreamCount int
			for _, fInfo := range s.InputSources {
				if fInfo.Path[0] == path[0] { // belong to the group
					if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
						fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
						if !ok {
							return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
						}
						subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
						if err != nil {
							return nodes.FieldNotStream, err
						}
						if subStreamType == nodes.FieldMaybeStream {
							return nodes.FieldMaybeStream, nil
						} else if subStreamType == nodes.FieldIsStream {
							streamCount++
						} else {
							notStreamCount++
						}
					}
				}
			}
			if streamCount > 0 && notStreamCount == 0 {
				return nodes.FieldIsStream, nil
			}
			if streamCount == 0 && notStreamCount > 0 {
				return nodes.FieldNotStream, nil
			}
			return nodes.FieldMaybeStream, nil
		}
	}
	// Of the remaining node types only LLM can stream, and only a single
	// top-level field.
	if s.Type != entity.NodeTypeLLM {
		return nodes.FieldNotStream, nil
	}
	if len(path) != 1 {
		return nodes.FieldNotStream, nil
	}
	// An LLM output streams only when it is plain text: exactly one
	// non-reasoning string output, optionally accompanied by a string
	// "reasoning_content". Anything else implies structured (JSON) output.
	outputs := s.OutputTypes
	if len(outputs) != 1 && len(outputs) != 2 {
		return nodes.FieldNotStream, nil
	}
	var outputKey string
	for key, output := range outputs {
		if output.Type != vo.DataTypeString {
			return nodes.FieldNotStream, nil
		}
		if key != "reasoning_content" {
			if len(outputKey) > 0 {
				// Two non-reasoning outputs: structured output, not streamed.
				return nodes.FieldNotStream, nil
			}
			outputKey = key
		}
	}
	field := path[0]
	if field == "reasoning_content" || field == outputKey {
		return nodes.FieldIsStream, nil
	}
	return nodes.FieldNotStream, nil
}

View File

@@ -0,0 +1,349 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"fmt"
"testing"
"github.com/cloudwego/eino/compose"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
)
// TestBatch exercises the batch composite node end to end:
//   - inner nodes reading the per-iteration item ("array_1"), the loop index,
//     and a field produced by a node in the parent workflow;
//   - a static (constant) input source and inner-to-inner field passing;
//   - assembling inner outputs back onto the batch node's output arrays;
//   - empty input, input shorter than the concurrency limit, and an error
//     raised by an inner node mid-batch.
func TestBatch(t *testing.T) {
	ctx := context.Background()

	// lambda1 fails when the loop index exceeds 2; otherwise it combines the
	// current array element, a parent-workflow field and the index into one
	// string under "output_1".
	lambda1 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
		if in["index"].(int64) > 2 {
			return nil, fmt.Errorf("index= %d is too large", in["index"].(int64))
		}
		out = make(map[string]any)
		out["output_1"] = fmt.Sprintf("%s_%v_%d", in["array_1"].(string), in["from_parent_wf"].(bool), in["index"].(int64))
		return out, nil
	}

	// lambda2 simply echoes the loop index.
	lambda2 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
		return map[string]any{"index": in["index"]}, nil
	}

	// lambda3 consumes lambda1's output, a batch-provided array element and a
	// static value, logging them and passing the input through unchanged.
	lambda3 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
		t.Log(in["consumer_1"].(string), in["array_2"].(int64), in["static_source"].(string))
		return in, nil
	}

	// Inner node 1: reads the index and item from the batch node, plus a field
	// from a predecessor in the PARENT workflow ("from_parent_wf").
	lambdaNode1 := &compose2.NodeSchema{
		Key:    "lambda",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda1),
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"index"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"index"},
					},
				},
			},
			{
				Path: compose.FieldPath{"array_1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"array_1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"from_parent_wf"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "parent_predecessor_1",
						FromPath:    compose.FieldPath{"success"},
					},
				},
			},
		},
	}

	// Inner node 2: echoes the batch index; feeds "assembled_output_2".
	lambdaNode2 := &compose2.NodeSchema{
		Key:    "index",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda2),
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"index"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"index"},
					},
				},
			},
		},
	}

	// Inner node 3: consumes another inner node's output plus a batch item and
	// a static constant source.
	lambdaNode3 := &compose2.NodeSchema{
		Key:    "consumer",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda3),
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"consumer_1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "lambda",
						FromPath:    compose.FieldPath{"output_1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"array_2"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"array_2"},
					},
				},
			},
			{
				Path: compose.FieldPath{"static_source"},
				Source: vo.FieldSource{
					Val: "this is a const",
				},
			},
		},
	}

	entry := &compose2.NodeSchema{
		Key:  entity.EntryNodeKey,
		Type: entity.NodeTypeEntry,
		Configs: map[string]any{
			"DefaultValues": map[string]any{},
		},
	}

	// The batch node itself: iterates over array_1/array_2 with concurrency 2
	// and a max batch size of 5, assembling inner outputs into output arrays.
	ns := &compose2.NodeSchema{
		Key:  "batch_node_key",
		Type: entity.NodeTypeBatch,
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"array_1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"array_1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"array_2"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"array_2"},
					},
				},
			},
			{
				Path: compose.FieldPath{batch.ConcurrentSizeKey},
				Source: vo.FieldSource{
					Val: int64(2),
				},
			},
			{
				Path: compose.FieldPath{batch.MaxBatchSizeKey},
				Source: vo.FieldSource{
					Val: int64(5),
				},
			},
		},
		InputTypes: map[string]*vo.TypeInfo{
			"array_1": {
				Type: vo.DataTypeArray,
				ElemTypeInfo: &vo.TypeInfo{
					Type: vo.DataTypeString,
				},
			},
			"array_2": {
				Type: vo.DataTypeArray,
				ElemTypeInfo: &vo.TypeInfo{
					Type: vo.DataTypeInteger,
				},
			},
		},
		// Outputs are assembled per-iteration from the inner nodes' outputs.
		OutputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"assembled_output_1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "lambda",
						FromPath:    compose.FieldPath{"output_1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"assembled_output_2"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "index",
						FromPath:    compose.FieldPath{"index"},
					},
				},
			},
		},
	}

	exit := &compose2.NodeSchema{
		Key:  entity.ExitNodeKey,
		Type: entity.NodeTypeExit,
		Configs: map[string]any{
			"TerminalPlan": vo.ReturnVariables,
		},
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"assembled_output_1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"assembled_output_1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"assembled_output_2"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "batch_node_key",
						FromPath:    compose.FieldPath{"assembled_output_2"},
					},
				},
			},
		},
	}

	// Parent-workflow node whose "success" output is referenced by an inner
	// node of the batch (cross-hierarchy field reference).
	parentLambda := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
		return map[string]any{"success": true}, nil
	}
	parentLambdaNode := &compose2.NodeSchema{
		Key:    "parent_predecessor_1",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(parentLambda),
	}

	// Hierarchy marks the three lambdas as children of the batch node; inner
	// connections route from/to the batch node key.
	ws := &compose2.WorkflowSchema{
		Nodes: []*compose2.NodeSchema{
			entry,
			parentLambdaNode,
			ns,
			exit,
			lambdaNode1,
			lambdaNode2,
			lambdaNode3,
		},
		Hierarchy: map[vo.NodeKey]vo.NodeKey{
			"lambda":   "batch_node_key",
			"index":    "batch_node_key",
			"consumer": "batch_node_key",
		},
		Connections: []*compose2.Connection{
			{
				FromNode: entity.EntryNodeKey,
				ToNode:   "parent_predecessor_1",
			},
			{
				FromNode: "parent_predecessor_1",
				ToNode:   "batch_node_key",
			},
			{
				FromNode: "batch_node_key",
				ToNode:   "lambda",
			},
			{
				FromNode: "lambda",
				ToNode:   "index",
			},
			{
				FromNode: "lambda",
				ToNode:   "consumer",
			},
			{
				FromNode: "index",
				ToNode:   "batch_node_key",
			},
			{
				FromNode: "consumer",
				ToNode:   "batch_node_key",
			},
			{
				FromNode: "batch_node_key",
				ToNode:   entity.ExitNodeKey,
			},
		},
	}
	ws.Init()
	wf, err := compose2.NewWorkflow(ctx, ws)
	assert.NoError(t, err)

	// Happy path: batch size is min(len(array_1), len(array_2)) = 3.
	out, err := wf.Runner.Invoke(ctx, map[string]any{
		"array_1": []any{"a", "b", "c"},
		"array_2": []any{int64(1), int64(2), int64(3), int64(4)},
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"assembled_output_1": []any{"a_true_0", "b_true_1", "c_true_2"},
		"assembled_output_2": []any{int64(0), int64(1), int64(2)},
	}, out)

	// input array is empty
	out, err = wf.Runner.Invoke(ctx, map[string]any{
		"array_1": []any{},
		"array_2": []any{int64(1)},
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"assembled_output_1": []any{},
		"assembled_output_2": []any{},
	}, out)

	// less than concurrency
	out, err = wf.Runner.Invoke(ctx, map[string]any{
		"array_1": []any{"a"},
		"array_2": []any{int64(1), int64(2)},
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"assembled_output_1": []any{"a_true_0"},
		"assembled_output_2": []any{int64(0)},
	}, out)

	// err by inner node: index > 2 makes lambda1 fail, and the error must
	// surface from the batch node.
	_, err = wf.Runner.Invoke(ctx, map[string]any{
		"array_1": []any{"a", "b", "c", "d", "e", "f"},
		"array_2": []any{int64(1), int64(2), int64(3), int64(4), int64(5), int64(6), int64(7)},
	})
	assert.ErrorContains(t, err, "is too large")
}

View File

@@ -0,0 +1,679 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"fmt"
"io"
"os"
"strings"
"testing"
"github.com/bytedance/mockey"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/callbacks"
model2 "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/internal/testutil"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
)
// TestLLM exercises the LLM node in four modes: plain-text invoke, JSON
// output, markdown output, and a streaming fan-out/fan-in graph through an
// output emitter. Real models are used when OPENAI_*/DEEPSEEK_* env vars are
// set; otherwise testutil.UTChatModel stand-ins provide canned responses.
func TestLLM(t *testing.T) {
	mockey.PatchConvey("test llm", t, func() {
		accessKey := os.Getenv("OPENAI_API_KEY")
		baseURL := os.Getenv("OPENAI_BASE_URL")
		modelName := os.Getenv("OPENAI_MODEL_NAME")
		var (
			openaiModel, deepSeekModel model2.BaseChatModel
			err                        error
		)
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()
		// Route model.GetManager() to a mock so the workflow resolves models
		// through this test's openaiModel/deepSeekModel variables.
		mockModelManager := mockmodel.NewMockManager(ctrl)
		mockey.Mock(model.GetManager).Return(mockModelManager).Build()
		if len(accessKey) > 0 && len(baseURL) > 0 && len(modelName) > 0 {
			openaiModel, err = openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
				APIKey:  accessKey,
				ByAzure: true,
				BaseURL: baseURL,
				Model:   modelName,
			})
			assert.NoError(t, err)
		}
		deepSeekModelName := os.Getenv("DEEPSEEK_MODEL_NAME")
		if len(accessKey) > 0 && len(baseURL) > 0 && len(deepSeekModelName) > 0 {
			deepSeekModel, err = deepseek.NewChatModel(context.Background(), &deepseek.ChatModelConfig{
				APIKey:  accessKey,
				BaseURL: baseURL,
				Model:   deepSeekModelName,
			})
			assert.NoError(t, err)
		}
		// Dispatch by model name; the subtests below swap in UTChatModel
		// instances when no real credentials are configured.
		mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
			if params.ModelName == modelName {
				return openaiModel, nil, nil
			} else if params.ModelName == deepSeekModelName {
				return deepSeekModel, nil, nil
			} else {
				return nil, nil, fmt.Errorf("invalid model name: %s", params.ModelName)
			}
		}).AnyTimes()
		ctx := ctxcache.Init(context.Background())

		t.Run("plain text output, non-streaming mode", func(t *testing.T) {
			// Fall back to a canned single-message model when no real one is set.
			if openaiModel == nil {
				defer func() {
					openaiModel = nil
				}()
				openaiModel = &testutil.UTChatModel{
					InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
						return &schema.Message{
							Role:    schema.Assistant,
							Content: "I don't know",
						}, nil
					},
				}
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			// Prompts are templates filled from the entry node's fields.
			llmNode := &compose2.NodeSchema{
				Key:  "llm_node_key",
				Type: entity.NodeTypeLLM,
				Configs: map[string]any{
					"SystemPrompt": "{{sys_prompt}}",
					"UserPrompt":   "{{query}}",
					"OutputFormat": llm.FormatText,
					"LLMParams": &model.LLMParams{
						ModelName: modelName,
					},
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"sys_prompt"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"sys_prompt"},
							},
						},
					},
					{
						Path: compose.FieldPath{"query"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"query"},
							},
						},
					},
				},
				InputTypes: map[string]*vo.TypeInfo{
					"sys_prompt": {
						Type: vo.DataTypeString,
					},
					"query": {
						Type: vo.DataTypeString,
					},
				},
				OutputTypes: map[string]*vo.TypeInfo{
					"output": {
						Type: vo.DataTypeString,
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: llmNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					llmNode,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   llmNode.Key,
					},
					{
						FromNode: llmNode.Key,
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(ctx, ws)
			assert.NoError(t, err)
			out, err := wf.Runner.Invoke(ctx, map[string]any{
				"sys_prompt": "you are a helpful assistant",
				"query":      "what's your name",
			})
			assert.NoError(t, err)
			// Only check non-emptiness: the real model's text is unpredictable.
			assert.Greater(t, len(out), 0)
			assert.Greater(t, len(out["output"].(string)), 0)
		})

		t.Run("json output", func(t *testing.T) {
			if openaiModel == nil {
				defer func() {
					openaiModel = nil
				}()
				openaiModel = &testutil.UTChatModel{
					InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
						return &schema.Message{
							Role:    schema.Assistant,
							Content: `{"country_name": "Russia", "area_size": 17075400}`,
						}, nil
					},
				}
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			// FormatJSON with typed outputs; IgnoreException + DefaultOutput
			// cover the case where the model returns non-parseable JSON.
			llmNode := &compose2.NodeSchema{
				Key:  "llm_node_key",
				Type: entity.NodeTypeLLM,
				Configs: map[string]any{
					"SystemPrompt":    "you are a helpful assistant",
					"UserPrompt":      "what's the largest country in the world and it's area size in square kilometers?",
					"OutputFormat":    llm.FormatJSON,
					"IgnoreException": true,
					"DefaultOutput": map[string]any{
						"country_name": "unknown",
						"area_size":    int64(0),
					},
					"LLMParams": &model.LLMParams{
						ModelName: modelName,
					},
				},
				OutputTypes: map[string]*vo.TypeInfo{
					"country_name": {
						Type:     vo.DataTypeString,
						Required: true,
					},
					"area_size": {
						Type:     vo.DataTypeInteger,
						Required: true,
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"country_name"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: llmNode.Key,
								FromPath:    compose.FieldPath{"country_name"},
							},
						},
					},
					{
						Path: compose.FieldPath{"area_size"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: llmNode.Key,
								FromPath:    compose.FieldPath{"area_size"},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					llmNode,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   llmNode.Key,
					},
					{
						FromNode: llmNode.Key,
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(ctx, ws)
			assert.NoError(t, err)
			out, err := wf.Runner.Invoke(ctx, map[string]any{})
			assert.NoError(t, err)
			assert.Equal(t, out["country_name"], "Russia")
			assert.Greater(t, out["area_size"], int64(1000))
		})

		t.Run("markdown output", func(t *testing.T) {
			if openaiModel == nil {
				defer func() {
					openaiModel = nil
				}()
				openaiModel = &testutil.UTChatModel{
					InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
						return &schema.Message{
							Role:    schema.Assistant,
							Content: `#Top 5 Largest Countries in the World ## 1. Russia 2. Canada 3. United States 4. Brazil 5. Japan`,
						}, nil
					},
				}
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			llmNode := &compose2.NodeSchema{
				Key:  "llm_node_key",
				Type: entity.NodeTypeLLM,
				Configs: map[string]any{
					"SystemPrompt": "you are a helpful assistant",
					"UserPrompt":   "list the top 5 largest countries in the world",
					"OutputFormat": llm.FormatMarkdown,
					"LLMParams": &model.LLMParams{
						ModelName: modelName,
					},
				},
				OutputTypes: map[string]*vo.TypeInfo{
					"output": {
						Type: vo.DataTypeString,
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: llmNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					llmNode,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   llmNode.Key,
					},
					{
						FromNode: llmNode.Key,
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(ctx, ws)
			assert.NoError(t, err)
			out, err := wf.Runner.Invoke(ctx, map[string]any{})
			assert.NoError(t, err)
			assert.Greater(t, len(out["output"].(string)), 0)
		})

		t.Run("plain text output, streaming mode", func(t *testing.T) {
			// start -> fan out to openai LLM and deepseek LLM -> fan in to output emitter -> end
			if openaiModel == nil || deepSeekModel == nil {
				if openaiModel == nil {
					defer func() {
						openaiModel = nil
					}()
					openaiModel = &testutil.UTChatModel{
						StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
							sr := schema.StreamReaderFromArray([]*schema.Message{
								{
									Role:    schema.Assistant,
									Content: "I ",
								},
								{
									Role:    schema.Assistant,
									Content: "don't know.",
								},
							})
							return sr, nil
						},
					}
				}
				if deepSeekModel == nil {
					defer func() {
						deepSeekModel = nil
					}()
					deepSeekModel = &testutil.UTChatModel{
						StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
							sr := schema.StreamReaderFromArray([]*schema.Message{
								{
									Role:    schema.Assistant,
									Content: "I ",
								},
								{
									Role:    schema.Assistant,
									Content: "don't know too.",
								},
							})
							return sr, nil
						},
					}
				}
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			openaiNode := &compose2.NodeSchema{
				Key:  "openai_llm_node_key",
				Type: entity.NodeTypeLLM,
				Configs: map[string]any{
					"SystemPrompt": "you are a helpful assistant",
					"UserPrompt":   "plan a 10 day family visit to China.",
					"OutputFormat": llm.FormatText,
					"LLMParams": &model.LLMParams{
						ModelName: modelName,
					},
				},
				OutputTypes: map[string]*vo.TypeInfo{
					"output": {
						Type: vo.DataTypeString,
					},
				},
			}
			// NOTE(review): this node's LLMParams use modelName (openai), not
			// deepSeekModelName, so the mock manager hands it the openai model
			// — reasoning_content then stays empty. Confirm this is intended.
			deepseekNode := &compose2.NodeSchema{
				Key:  "deepseek_llm_node_key",
				Type: entity.NodeTypeLLM,
				Configs: map[string]any{
					"SystemPrompt": "you are a helpful assistant",
					"UserPrompt":   "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
					"OutputFormat": llm.FormatText,
					"LLMParams": &model.LLMParams{
						ModelName: modelName,
					},
				},
				OutputTypes: map[string]*vo.TypeInfo{
					"output": {
						Type: vo.DataTypeString,
					},
					"reasoning_content": {
						Type: vo.DataTypeString,
					},
				},
			}
			// Fan-in: the emitter interleaves static template text with the
			// two LLM streams and entry fields.
			emitterNode := &compose2.NodeSchema{
				Key:  "emitter_node_key",
				Type: entity.NodeTypeOutputEmitter,
				Configs: map[string]any{
					"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
					"Mode":     nodes.Streaming,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"openai_output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: openaiNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
					{
						Path: compose.FieldPath{"deepseek_output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: deepseekNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
					{
						Path: compose.FieldPath{"deepseek_reasoning"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: deepseekNode.Key,
								FromPath:    compose.FieldPath{"reasoning_content"},
							},
						},
					},
					{
						Path: compose.FieldPath{"inputObj"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"inputObj"},
							},
						},
					},
					{
						Path: compose.FieldPath{"input2"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"input2"},
							},
						},
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.UseAnswerContent,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"openai_output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: openaiNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
					{
						Path: compose.FieldPath{"deepseek_output"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: deepseekNode.Key,
								FromPath:    compose.FieldPath{"output"},
							},
						},
					},
					{
						Path: compose.FieldPath{"deepseek_reasoning"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: deepseekNode.Key,
								FromPath:    compose.FieldPath{"reasoning_content"},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					openaiNode,
					deepseekNode,
					emitterNode,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   openaiNode.Key,
					},
					{
						FromNode: openaiNode.Key,
						ToNode:   emitterNode.Key,
					},
					{
						FromNode: entry.Key,
						ToNode:   deepseekNode.Key,
					},
					{
						FromNode: deepseekNode.Key,
						ToNode:   emitterNode.Key,
					},
					{
						FromNode: emitterNode.Key,
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(ctx, ws)
			if err != nil {
				t.Fatal(err)
			}
			// Capture the emitter's streamed output chunks via a callback
			// handler designated to that node only.
			var fullOutput string
			cbHandler := callbacks.NewHandlerBuilder().OnEndWithStreamOutputFn(
				func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
					defer output.Close()
					for {
						chunk, e := output.Recv()
						if e != nil {
							if e == io.EOF {
								break
							}
							assert.NoError(t, e)
						}
						s, ok := chunk.(map[string]any)
						assert.True(t, ok)
						out := s["output"].(string)
						// KeyIsFinished is a sentinel chunk, not real content.
						if out != nodes.KeyIsFinished {
							fmt.Print(s["output"])
							fullOutput += s["output"].(string)
						}
					}
					return ctx
				}).Build()
			outStream, err := wf.Runner.Stream(ctx, map[string]any{
				"inputObj": map[string]any{
					"field1": "field1",
					"field2": 1.1,
				},
				"input2": 23.5,
			}, compose.WithCallbacks(cbHandler).DesignateNode(string(emitterNode.Key)))
			assert.NoError(t, err)
			// The template's static prefix/suffix bracket the model output.
			assert.True(t, strings.HasPrefix(fullOutput, "prefix field1 23.5"))
			assert.True(t, strings.HasSuffix(fullOutput, "1.1 suffix"))
			outStream.Close()
		})
	})
}

View File

@@ -0,0 +1,495 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"testing"
"github.com/cloudwego/eino/compose"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestLoop(t *testing.T) {
t.Run("by iteration", func(t *testing.T) {
// start-> loop_node_key[innerNode->continue] -> end
innerNode := &compose2.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
index := in["index"].(int64)
return map[string]any{"output": index}, nil
}),
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"index"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"index"},
},
},
},
},
}
continueNode := &compose2.NodeSchema{
Key: "continueNode",
Type: entity.NodeTypeContinue,
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
}
loopNode := &compose2.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByIteration,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{loop.Count},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromPath: compose.FieldPath{"count"},
},
},
},
},
OutputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "innerNode",
FromPath: compose.FieldPath{"output"},
},
},
},
},
}
exit := &compose2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"output"},
},
},
},
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
loopNode,
exit,
innerNode,
continueNode,
},
Hierarchy: map[vo.NodeKey]vo.NodeKey{
"innerNode": "loop_node_key",
"continueNode": "loop_node_key",
},
Connections: []*compose2.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
},
{
FromNode: "innerNode",
ToNode: "continueNode",
},
{
FromNode: "continueNode",
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
},
},
}
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
assert.NoError(t, err)
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
"count": int64(3),
})
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"output": []any{int64(0), int64(1), int64(2)},
}, out)
})
t.Run("infinite", func(t *testing.T) {
// start-> loop_node_key[innerNode->break] -> end
innerNode := &compose2.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
index := in["index"].(int64)
return map[string]any{"output": index}, nil
}),
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"index"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"index"},
},
},
},
},
}
breakNode := &compose2.NodeSchema{
Key: "breakNode",
Type: entity.NodeTypeBreak,
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
}
loopNode := &compose2.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.Infinite,
},
OutputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "innerNode",
FromPath: compose.FieldPath{"output"},
},
},
},
},
}
exit := &compose2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"output"},
},
},
},
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
loopNode,
exit,
innerNode,
breakNode,
},
Hierarchy: map[vo.NodeKey]vo.NodeKey{
"innerNode": "loop_node_key",
"breakNode": "loop_node_key",
},
Connections: []*compose2.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
},
{
FromNode: "innerNode",
ToNode: "breakNode",
},
{
FromNode: "breakNode",
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
},
},
}
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
assert.NoError(t, err)
out, err := wf.Runner.Invoke(context.Background(), map[string]any{})
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"output": []any{int64(0)},
}, out)
})
t.Run("by array", func(t *testing.T) {
// start-> loop_node_key[innerNode->variable_assign] -> end
innerNode := &compose2.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
item1 := in["item1"].(string)
item2 := in["item2"].(string)
count := in["count"].(int)
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
}),
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"item1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"items1"},
},
},
},
{
Path: compose.FieldPath{"item2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"items2"},
},
},
},
{
Path: compose.FieldPath{"count"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
},
},
},
}
assigner := &compose2.NodeSchema{
Key: "assigner",
Type: entity.NodeTypeVariableAssignerWithinLoop,
Configs: []*variableassigner.Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"total"},
},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"total"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "innerNode",
FromPath: compose.FieldPath{"total"},
},
},
},
},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
}
exit := &compose2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: "loop_node_key",
FromPath: compose.FieldPath{"output"},
},
},
},
},
}
loopNode := &compose2.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByArray,
"IntermediateVars": map[string]*vo.TypeInfo{
"count": {
Type: vo.DataTypeInteger,
},
},
},
InputTypes: map[string]*vo.TypeInfo{
"items1": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString},
},
"items2": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString},
},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"items1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromPath: compose.FieldPath{"items1"},
},
},
},
{
Path: compose.FieldPath{"items2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromPath: compose.FieldPath{"items2"},
},
},
},
{
Path: compose.FieldPath{"count"},
Source: vo.FieldSource{
Val: 0,
},
},
},
OutputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"output"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
},
},
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
loopNode,
exit,
innerNode,
assigner,
},
Hierarchy: map[vo.NodeKey]vo.NodeKey{
"innerNode": "loop_node_key",
"assigner": "loop_node_key",
},
Connections: []*compose2.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
},
{
FromNode: "innerNode",
ToNode: "assigner",
},
{
FromNode: "assigner",
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
},
},
}
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
assert.NoError(t, err)
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
"items1": []any{"a", "b"},
"items2": []any{"a1", "b1", "c1"},
})
assert.NoError(t, err)
assert.Equal(t, map[string]any{
"output": 6,
}, out)
})
}

View File

@@ -0,0 +1,673 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"errors"
"fmt"
"os"
"strings"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/bytedance/mockey"
"github.com/cloudwego/eino-ext/components/model/openai"
model2 "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/internal/testutil"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestQuestionAnswer exercises the question-answer (QA) node end to end in four
// modes: answering directly, answering from a fixed choice list, answering from
// a dynamic choice list, and answering directly with structured-output
// extraction. Every sub-test follows the same interrupt/resume pattern: the
// first Invoke raises an interrupt carrying the question, then the workflow is
// resumed from the same checkpoint with a state modifier that appends the
// simulated user answer.
func TestQuestionAnswer(t *testing.T) {
	mockey.PatchConvey("test qa", t, func() {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		// Patch the global model manager so QA nodes resolve their chat model
		// through the mock below instead of real infrastructure.
		mockModelManager := mockmodel.NewMockManager(ctrl)
		mockey.Mock(model.GetManager).Return(mockModelManager).Build()

		// When OpenAI credentials are present in the environment, use a real
		// chat model for every sub-test; otherwise the sub-tests that need a
		// model install per-test UTChatModel stubs.
		accessKey := os.Getenv("OPENAI_API_KEY")
		baseURL := os.Getenv("OPENAI_BASE_URL")
		modelName := os.Getenv("OPENAI_MODEL_NAME")
		var (
			chatModel model2.BaseChatModel
			err       error
		)
		if len(accessKey) > 0 && len(baseURL) > 0 && len(modelName) > 0 {
			chatModel, err = openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
				APIKey:  accessKey,
				ByAzure: true,
				BaseURL: baseURL,
				Model:   modelName,
			})
			assert.NoError(t, err)
			mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
		}

		// MySQL-backed workflow repository; the host is rewritten when running
		// inside CI where the database is reachable as "mysql".
		dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
		if os.Getenv("CI_JOB_NAME") != "" {
			dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
		}
		db, err := gorm.Open(mysql.Open(dsn))
		assert.NoError(t, err)

		// In-process redis used as the checkpoint store for interrupt/resume.
		s, err := miniredis.Run()
		if err != nil {
			t.Fatalf("Failed to start miniredis: %v", err)
		}
		defer s.Close()
		redisClient := redis.NewClient(&redis.Options{
			Addr: s.Addr(),
		})

		mockIDGen := mock.NewMockIDGenerator(ctrl)
		mockIDGen.EXPECT().GenID(gomock.Any()).Return(time.Now().UnixNano(), nil).AnyTimes()
		mockTos := storageMock.NewMockStorage(ctrl)
		mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
		repo := repo2.NewRepository(mockIDGen, db, redisClient, mockTos,
			checkpoint.NewRedisStore(redisClient))
		mockey.Mock(workflow.GetRepository).Return(repo).Build()

		t.Run("answer directly, no structured output", func(t *testing.T) {
			// start -> qa_node_key -> end; the QA node forwards the raw user
			// response to the exit node untouched.
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				}}
			ns := &compose2.NodeSchema{
				Key:  "qa_node_key",
				Type: entity.NodeTypeQuestionAnswer,
				Configs: map[string]any{
					"QuestionTpl": "{{input}}",
					"AnswerType":  qa.AnswerDirectly,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"input"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"query"},
							},
						},
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"answer"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.UserResponseKey},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					ns,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   "qa_node_key",
					},
					{
						FromNode: "qa_node_key",
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(context.Background(), ws)
			assert.NoError(t, err)
			checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
			// First run interrupts at the QA node; the rendered question is
			// recorded in the interrupted state.
			_, err = wf.Runner.Invoke(context.Background(), map[string]any{
				"query": "what's your name?",
			}, compose.WithCheckPointID(checkPointID))
			assert.Error(t, err)
			info, existed := compose.ExtractInterruptInfo(err)
			assert.True(t, existed)
			assert.Equal(t, "what's your name?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
			answer := "my name is eino"
			// Resume from the checkpoint with the user's answer injected into
			// the state; the workflow then runs to completion.
			stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
				state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
				return nil
			}
			out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
			assert.NoError(t, err)
			assert.Equal(t, map[string]any{
				"answer": answer,
			}, out)
		})

		t.Run("answer with fixed choices", func(t *testing.T) {
			// The mock model replies "-1"; per the assertion below the node
			// maps this to the "other" option (answer matched no fixed choice
			// — presumably -1 is the model's "no match" signal; confirm
			// against the qa node implementation).
			if chatModel == nil {
				oneChatModel := &testutil.UTChatModel{
					InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
						return &schema.Message{
							Role:    schema.Assistant,
							Content: "-1",
						}, nil
					},
				}
				mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			// Fixed choices are themselves templates rendered from the node's
			// inputs (choice1/choice2).
			ns := &compose2.NodeSchema{
				Key:  "qa_node_key",
				Type: entity.NodeTypeQuestionAnswer,
				Configs: map[string]any{
					"QuestionTpl":  "{{input}}",
					"AnswerType":   qa.AnswerByChoices,
					"ChoiceType":   qa.FixedChoices,
					"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
					"LLMParams":    &model.LLMParams{},
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"input"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"query"},
							},
						},
					},
					{
						Path: compose.FieldPath{"choice1"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"choice1"},
							},
						},
					},
					{
						Path: compose.FieldPath{"choice2"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"choice2"},
							},
						},
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"option_id"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.OptionIDKey},
							},
						},
					},
					{
						Path: compose.FieldPath{"option_content"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.OptionContentKey},
							},
						},
					},
				},
			}
			lambda := &compose2.NodeSchema{
				Key:  "lambda",
				Type: entity.NodeTypeLambda,
				Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
					return out, nil
				}),
			}
			// Each fixed choice gets its own output port (branch_0/branch_1);
			// "default" is taken when the answer matches neither choice.
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					ns,
					exit,
					lambda,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   "qa_node_key",
					},
					{
						FromNode: "qa_node_key",
						ToNode:   exit.Key,
						FromPort: ptr.Of("branch_0"),
					},
					{
						FromNode: "qa_node_key",
						ToNode:   exit.Key,
						FromPort: ptr.Of("branch_1"),
					},
					{
						FromNode: "qa_node_key",
						ToNode:   "lambda",
						FromPort: ptr.Of("default"),
					},
					{
						FromNode: "lambda",
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(context.Background(), ws)
			assert.NoError(t, err)
			checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
			_, err = wf.Runner.Invoke(context.Background(), map[string]any{
				"query":   "what's would you make in Coze?",
				"choice1": "make agent",
				"choice2": "make workflow",
			}, compose.WithCheckPointID(checkPointID))
			assert.Error(t, err)
			info, existed := compose.ExtractInterruptInfo(err)
			assert.True(t, existed)
			// The interrupted state carries both the question and the rendered
			// choice list.
			assert.Equal(t, "what's would you make in Coze?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
			assert.Equal(t, "make agent", info.State.(*compose2.State).Questions[ns.Key][0].Choices[0])
			assert.Equal(t, "make workflow", info.State.(*compose2.State).Questions[ns.Key][0].Choices[1])
			chosenContent := "I would make all kinds of stuff"
			stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
				state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], chosenContent)
				return nil
			}
			out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
			assert.NoError(t, err)
			// A free-form answer that matches no fixed choice resolves to the
			// "other" option, with the raw answer as its content.
			assert.Equal(t, map[string]any{
				"option_id":      "other",
				"option_content": chosenContent,
			}, out)
		})

		t.Run("answer with dynamic choices", func(t *testing.T) {
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			// Choices arrive at runtime through the DynamicChoicesKey input
			// instead of being configured statically.
			ns := &compose2.NodeSchema{
				Key:  "qa_node_key",
				Type: entity.NodeTypeQuestionAnswer,
				Configs: map[string]any{
					"QuestionTpl": "{{input}}",
					"AnswerType":  qa.AnswerByChoices,
					"ChoiceType":  qa.DynamicChoices,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"input"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"query"},
							},
						},
					},
					{
						Path: compose.FieldPath{qa.DynamicChoicesKey},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"choices"},
							},
						},
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"option_id"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.OptionIDKey},
							},
						},
					},
					{
						Path: compose.FieldPath{"option_content"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.OptionContentKey},
							},
						},
					},
				},
			}
			lambda := &compose2.NodeSchema{
				Key:  "lambda",
				Type: entity.NodeTypeLambda,
				Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
					return out, nil
				}),
			}
			// Dynamic choices expose a single branch_0 port plus default.
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					ns,
					exit,
					lambda,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   "qa_node_key",
					},
					{
						FromNode: "qa_node_key",
						ToNode:   exit.Key,
						FromPort: ptr.Of("branch_0"),
					},
					{
						FromNode: "lambda",
						ToNode:   exit.Key,
					},
					{
						FromNode: "qa_node_key",
						ToNode:   "lambda",
						FromPort: ptr.Of("default"),
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(context.Background(), ws)
			assert.NoError(t, err)
			checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
			_, err = wf.Runner.Invoke(context.Background(), map[string]any{
				"query":   "what's the capital city of China?",
				"choices": []any{"beijing", "shanghai"},
			}, compose.WithCheckPointID(checkPointID))
			assert.Error(t, err)
			info, existed := compose.ExtractInterruptInfo(err)
			assert.True(t, existed)
			assert.Equal(t, "what's the capital city of China?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
			assert.Equal(t, "beijing", info.State.(*compose2.State).Questions[ns.Key][0].Choices[0])
			assert.Equal(t, "shanghai", info.State.(*compose2.State).Questions[ns.Key][0].Choices[1])
			chosenContent := "beijing"
			stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
				state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], chosenContent)
				return nil
			}
			out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
			assert.NoError(t, err)
			// An exact match on the first choice resolves to option id "A"
			// (choices are identified by letter per the assertion).
			assert.Equal(t, map[string]any{
				"option_id":      "A",
				"option_content": chosenContent,
			}, out)
		})

		t.Run("answer directly, extract structured output", func(t *testing.T) {
			ctx := context.Background()
			// qaCount drives the scripted model replies below: after the first
			// answer the model asks a follow-up question; after the second it
			// emits the extracted fields.
			qaCount := 0
			if chatModel == nil {
				defer func() {
					chatModel = nil
				}()
				chatModel = &testutil.UTChatModel{
					InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
						if qaCount == 1 {
							return &schema.Message{
								Role:    schema.Assistant,
								Content: `{"question": "what's your age?"}`,
							}, nil
						} else if qaCount == 2 {
							return &schema.Message{
								Role:    schema.Assistant,
								Content: `{"fields": {"name": "eino", "age": 1}}`,
							}, nil
						}
						return nil, errors.New("not found")
					},
				}
				mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
			}
			entry := &compose2.NodeSchema{
				Key:  entity.EntryNodeKey,
				Type: entity.NodeTypeEntry,
				Configs: map[string]any{
					"DefaultValues": map[string]any{},
				},
			}
			// ExtractFromAnswer enables structured extraction; MaxAnswerCount
			// caps how many question rounds the node may ask.
			ns := &compose2.NodeSchema{
				Key:  "qa_node_key",
				Type: entity.NodeTypeQuestionAnswer,
				Configs: map[string]any{
					"QuestionTpl":               "{{input}}",
					"AnswerType":                qa.AnswerDirectly,
					"ExtractFromAnswer":         true,
					"AdditionalSystemPromptTpl": "{{prompt}}",
					"MaxAnswerCount":            2,
					"LLMParams":                 &model.LLMParams{},
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"input"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"query"},
							},
						},
					},
					{
						Path: compose.FieldPath{"prompt"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: entry.Key,
								FromPath:    compose.FieldPath{"prompt"},
							},
						},
					},
				},
				// The structured fields the node must extract from answers.
				OutputTypes: map[string]*vo.TypeInfo{
					"name": {
						Type:     vo.DataTypeString,
						Required: true,
					},
					"age": {
						Type:     vo.DataTypeInteger,
						Required: true,
					},
				},
			}
			exit := &compose2.NodeSchema{
				Key:  entity.ExitNodeKey,
				Type: entity.NodeTypeExit,
				Configs: map[string]any{
					"TerminalPlan": vo.ReturnVariables,
				},
				InputSources: []*vo.FieldInfo{
					{
						Path: compose.FieldPath{"name"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{"name"},
							},
						},
					},
					{
						Path: compose.FieldPath{"age"},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{"age"},
							},
						},
					},
					{
						Path: compose.FieldPath{qa.UserResponseKey},
						Source: vo.FieldSource{
							Ref: &vo.Reference{
								FromNodeKey: "qa_node_key",
								FromPath:    compose.FieldPath{qa.UserResponseKey},
							},
						},
					},
				},
			}
			ws := &compose2.WorkflowSchema{
				Nodes: []*compose2.NodeSchema{
					entry,
					ns,
					exit,
				},
				Connections: []*compose2.Connection{
					{
						FromNode: entry.Key,
						ToNode:   "qa_node_key",
					},
					{
						FromNode: "qa_node_key",
						ToNode:   exit.Key,
					},
				},
			}
			ws.Init()
			wf, err := compose2.NewWorkflow(context.Background(), ws)
			assert.NoError(t, err)
			checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
			// Round 1: initial question, interrupt.
			_, err = wf.Runner.Invoke(ctx, map[string]any{
				"query":  "what's your name?",
				"prompt": "You are a helpful assistant.",
			}, compose.WithCheckPointID(checkPointID))
			assert.Error(t, err)
			info, existed := compose.ExtractInterruptInfo(err)
			assert.True(t, existed)
			assert.Equal(t, "what's your name?", info.State.(*compose2.State).Questions["qa_node_key"][0].Question)
			qaCount++
			answer := "my name is eino"
			stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
				state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
				return nil
			}
			// Round 2: the model asks a follow-up ("what's your age?"), so the
			// workflow interrupts again.
			_, err = wf.Runner.Invoke(ctx, map[string]any{}, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
			assert.Error(t, err)
			info, existed = compose.ExtractInterruptInfo(err)
			assert.True(t, existed)
			qaCount++
			answer = "my age is 1 years old"
			stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
				state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
				return nil
			}
			// Round 3: both fields can now be extracted; the workflow completes
			// with the structured fields plus the last raw user response.
			out, err := wf.Runner.Invoke(ctx, map[string]any{}, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
			assert.NoError(t, err)
			assert.Equal(t, map[string]any{
				qa.UserResponseKey: answer,
				"name":             "eino",
				"age":              int64(1),
			}, out)
		})
	})
}

View File

@@ -0,0 +1,634 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test
import (
"context"
"testing"
"github.com/cloudwego/eino/compose"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
// TestAddSelector wires a selector node with two clauses into a workflow:
// start -> selector; selector.condition1 -> lambda1 -> end;
// selector.condition2 -> [lambda2, lambda3] -> end; selector.default -> end.
// It invokes the workflow three times to drive each port in turn: branch_0
// (key1 == "value1"), branch_1 (key2 > key3 AND key4 is true), and the
// default port (neither clause holds).
func TestAddSelector(t *testing.T) {
	// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
	entry := &compose2.NodeSchema{
		Key:  entity.EntryNodeKey,
		Type: entity.NodeTypeEntry,
		Configs: map[string]any{
			"DefaultValues": map[string]any{},
		}}
	// The exit node pulls an output from every lambda; only the branch that
	// actually ran contributes a value.
	exit := &compose2.NodeSchema{
		Key:  entity.ExitNodeKey,
		Type: entity.NodeTypeExit,
		Configs: map[string]any{
			"TerminalPlan": vo.ReturnVariables,
		},
		InputSources: []*vo.FieldInfo{
			{
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "lambda1",
						FromPath:    compose.FieldPath{"lambda1"},
					},
				},
				Path: compose.FieldPath{"lambda1"},
			},
			{
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "lambda2",
						FromPath:    compose.FieldPath{"lambda2"},
					},
				},
				Path: compose.FieldPath{"lambda2"},
			},
			{
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: "lambda3",
						FromPath:    compose.FieldPath{"lambda3"},
					},
				},
				Path: compose.FieldPath{"lambda3"},
			},
		},
	}
	lambda1 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
		return map[string]any{
			"lambda1": "v1",
		}, nil
	}
	lambdaNode1 := &compose2.NodeSchema{
		Key:    "lambda1",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda1),
	}
	lambda2 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
		return map[string]any{
			"lambda2": "v2",
		}, nil
	}
	// renamed from LambdaNode2 for naming consistency with lambdaNode1/lambdaNode3
	lambdaNode2 := &compose2.NodeSchema{
		Key:    "lambda2",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda2),
	}
	lambda3 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
		return map[string]any{
			"lambda3": "v3",
		}, nil
	}
	lambdaNode3 := &compose2.NodeSchema{
		Key:    "lambda3",
		Type:   entity.NodeTypeLambda,
		Lambda: compose.InvokableLambda(lambda3),
	}
	// Clause 0: a single equality check. Clause 1: a multi-clause AND of
	// "greater than" and "is true".
	ns := &compose2.NodeSchema{
		Key:  "selector",
		Type: entity.NodeTypeSelector,
		Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
			{
				Single: ptr.Of(selector.OperatorEqual),
			},
			{
				Multi: &selector.MultiClauseSchema{
					Clauses: []*selector.Operator{
						ptr.Of(selector.OperatorGreater),
						ptr.Of(selector.OperatorIsTrue),
					},
					Relation: selector.ClauseRelationAND,
				},
			},
		}},
		// Clause operands are addressed positionally: "0" is the first clause,
		// "1.0"/"1.1" are the sub-clauses of the second.
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"0", selector.LeftKey},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"key1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"0", selector.RightKey},
				Source: vo.FieldSource{
					Val: "value1",
				},
			},
			{
				Path: compose.FieldPath{"1", "0", selector.LeftKey},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"key2"},
					},
				},
			},
			{
				Path: compose.FieldPath{"1", "0", selector.RightKey},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"key3"},
					},
				},
			},
			{
				Path: compose.FieldPath{"1", "1", selector.LeftKey},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: entry.Key,
						FromPath:    compose.FieldPath{"key4"},
					},
				},
			},
		},
		InputTypes: map[string]*vo.TypeInfo{
			"0": {
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					selector.LeftKey: {
						Type: vo.DataTypeString,
					},
					selector.RightKey: {
						Type: vo.DataTypeString,
					},
				},
			},
			"1": {
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					"0": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							selector.LeftKey: {
								Type: vo.DataTypeInteger,
							},
							selector.RightKey: {
								Type: vo.DataTypeInteger,
							},
						},
					},
					"1": {
						Type: vo.DataTypeObject,
						Properties: map[string]*vo.TypeInfo{
							selector.LeftKey: {
								Type: vo.DataTypeBoolean,
							},
						},
					},
				},
			},
		},
	}
	ws := &compose2.WorkflowSchema{
		Nodes: []*compose2.NodeSchema{
			entry,
			ns,
			lambdaNode1,
			lambdaNode2,
			lambdaNode3,
			exit,
		},
		Connections: []*compose2.Connection{
			{
				FromNode: entry.Key,
				ToNode:   "selector",
			},
			{
				FromNode: "selector",
				ToNode:   "lambda1",
				FromPort: ptr.Of("branch_0"),
			},
			{
				FromNode: "selector",
				ToNode:   "lambda2",
				FromPort: ptr.Of("branch_1"),
			},
			{
				FromNode: "selector",
				ToNode:   "lambda3",
				FromPort: ptr.Of("branch_1"),
			},
			{
				FromNode: "selector",
				ToNode:   exit.Key,
				FromPort: ptr.Of("default"),
			},
			{
				FromNode: "lambda1",
				ToNode:   exit.Key,
			},
			{
				FromNode: "lambda2",
				ToNode:   exit.Key,
			},
			{
				FromNode: "lambda3",
				ToNode:   exit.Key,
			},
		},
	}
	ws.Init()
	ctx := context.Background()
	wf, err := compose2.NewWorkflow(ctx, ws)
	assert.NoError(t, err)
	// Clause 0 matches (key1 == "value1") -> branch_0 -> lambda1 only.
	out, err := wf.Runner.Invoke(ctx, map[string]any{
		"key1": "value1",
		"key2": int64(2),
		"key3": int64(3),
		"key4": true,
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"lambda1": "v1",
	}, out)
	// Clause 1 matches (key2 > key3 AND key4) -> branch_1 -> lambda2 and lambda3.
	out, err = wf.Runner.Invoke(ctx, map[string]any{
		"key1": "value2",
		"key2": int64(3),
		"key3": int64(2),
		"key4": true,
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"lambda2": "v2",
		"lambda3": "v3",
	}, out)
	// Neither clause matches -> default port straight to exit: empty output.
	out, err = wf.Runner.Invoke(ctx, map[string]any{
		"key1": "value2",
		"key2": int64(2),
		"key3": int64(3),
		"key4": true,
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{}, out)
}
// TestVariableAggregator verifies the variable-aggregator node under the
// first-not-null merge strategy. Two single-member groups are fed from the
// entry node; each group resolves to its member's value, and a nil member
// yields a nil group result.
func TestVariableAggregator(t *testing.T) {
	startNode := &compose2.NodeSchema{
		Key:  entity.EntryNodeKey,
		Type: entity.NodeTypeEntry,
		Configs: map[string]any{
			"DefaultValues": map[string]any{},
		},
	}

	// Aggregator: two ordered groups, one input slot each.
	aggregatorNode := &compose2.NodeSchema{
		Key:  "va",
		Type: entity.NodeTypeVariableAggregator,
		Configs: map[string]any{
			"MergeStrategy": variableaggregator.FirstNotNullValue,
			"GroupToLen": map[string]int{
				"Group1": 1,
				"Group2": 1,
			},
			"GroupOrder": []string{
				"Group1",
				"Group2",
			},
		},
		// Slot 0 of Group1 takes Str1; slot 0 of Group2 takes Int1.
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"Group1", "0"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: startNode.Key,
						FromPath:    compose.FieldPath{"Str1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"Group2", "0"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: startNode.Key,
						FromPath:    compose.FieldPath{"Int1"},
					},
				},
			},
		},
		InputTypes: map[string]*vo.TypeInfo{
			"Group1": {
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					"0": {
						Type: vo.DataTypeString,
					},
				},
			},
			"Group2": {
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					"0": {
						Type: vo.DataTypeInteger,
					},
				},
			},
		},
		// Each group collapses to a single value of its member's type.
		OutputTypes: map[string]*vo.TypeInfo{
			"Group1": {
				Type: vo.DataTypeString,
			},
			"Group2": {
				Type: vo.DataTypeInteger,
			},
		},
	}

	// Exit surfaces both merged group values.
	endNode := &compose2.NodeSchema{
		Key:  entity.ExitNodeKey,
		Type: entity.NodeTypeExit,
		Configs: map[string]any{
			"TerminalPlan": vo.ReturnVariables,
		},
		InputSources: []*vo.FieldInfo{
			{
				Path: compose.FieldPath{"Group1"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: aggregatorNode.Key,
						FromPath:    compose.FieldPath{"Group1"},
					},
				},
			},
			{
				Path: compose.FieldPath{"Group2"},
				Source: vo.FieldSource{
					Ref: &vo.Reference{
						FromNodeKey: aggregatorNode.Key,
						FromPath:    compose.FieldPath{"Group2"},
					},
				},
			},
		},
	}

	wfSchema := &compose2.WorkflowSchema{
		Nodes: []*compose2.NodeSchema{
			startNode,
			aggregatorNode,
			endNode,
		},
		Connections: []*compose2.Connection{
			{
				FromNode: startNode.Key,
				ToNode:   aggregatorNode.Key,
			},
			{
				FromNode: aggregatorNode.Key,
				ToNode:   endNode.Key,
			},
		},
	}
	wfSchema.Init()

	ctx := context.Background()
	wf, err := compose2.NewWorkflow(ctx, wfSchema)
	assert.NoError(t, err)

	// Both members present: each group returns its (only) value.
	got, err := wf.Runner.Invoke(ctx, map[string]any{
		"Str1": "str_v1",
		"Int1": int64(1),
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"Group1": "str_v1",
		"Group2": int64(1),
	}, got)

	// Int1 is nil: Group2 has no non-nil candidate, so it resolves to nil.
	got, err = wf.Runner.Invoke(ctx, map[string]any{
		"Str1": "str_v1",
		"Int1": nil,
	})
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{
		"Group1": "str_v1",
		"Group2": nil,
	}, got)
}
// TestTextProcessor covers both modes of the text-processor node: splitting a
// string on a separator and concatenating values through a template.
func TestTextProcessor(t *testing.T) {
	t.Run("split", func(t *testing.T) {
		entry := &compose2.NodeSchema{
			Key:  entity.EntryNodeKey,
			Type: entity.NodeTypeEntry,
			Configs: map[string]any{
				"DefaultValues": map[string]any{},
			},
		}
		exit := &compose2.NodeSchema{
			Key:  entity.ExitNodeKey,
			Type: entity.NodeTypeExit,
			Configs: map[string]any{
				"TerminalPlan": vo.ReturnVariables,
			},
			InputSources: []*vo.FieldInfo{
				{
					Path: compose.FieldPath{"output"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: "tp",
							FromPath:    compose.FieldPath{"output"},
						},
					},
				},
			},
		}
		// Split the "String" input on "|".
		ns := &compose2.NodeSchema{
			Key:  "tp",
			Type: entity.NodeTypeTextProcessor,
			Configs: map[string]any{
				"Type":       textprocessor.SplitText,
				"Separators": []string{"|"},
			},
			InputSources: []*vo.FieldInfo{
				{
					Path: compose.FieldPath{"String"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: entry.Key,
							FromPath:    compose.FieldPath{"Str"},
						},
					},
				},
			},
		}
		ws := &compose2.WorkflowSchema{
			Nodes: []*compose2.NodeSchema{
				ns,
				entry,
				exit,
			},
			Connections: []*compose2.Connection{
				{
					FromNode: entry.Key,
					ToNode:   "tp",
				},
				{
					FromNode: "tp",
					ToNode:   exit.Key,
				},
			},
		}
		ws.Init()
		wf, err := compose2.NewWorkflow(context.Background(), ws)
		// Fix: this error was previously unchecked and silently overwritten by
		// the next assignment; a failed NewWorkflow would have panicked below.
		assert.NoError(t, err)
		out, err := wf.Runner.Invoke(context.Background(), map[string]any{
			"Str": "a|b|c",
		})
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"output": []any{"a", "b", "c"},
		}, out)
	})
	t.Run("concat", func(t *testing.T) {
		entry := &compose2.NodeSchema{
			Key:  entity.EntryNodeKey,
			Type: entity.NodeTypeEntry,
			Configs: map[string]any{
				"DefaultValues": map[string]any{},
			},
		}
		exit := &compose2.NodeSchema{
			Key:  entity.ExitNodeKey,
			Type: entity.NodeTypeExit,
			Configs: map[string]any{
				"TerminalPlan": vo.ReturnVariables,
			},
			InputSources: []*vo.FieldInfo{
				{
					Path: compose.FieldPath{"output"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: "tp",
							FromPath:    compose.FieldPath{"output"},
						},
					},
				},
			},
		}
		// The template renders scalars, object fields, and array elements
		// (String3.f2[1]); ConcatChar joins array values when rendered whole.
		ns := &compose2.NodeSchema{
			Key:  "tp",
			Type: entity.NodeTypeTextProcessor,
			Configs: map[string]any{
				"Type":       textprocessor.ConcatText,
				"Tpl":        "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
				"ConcatChar": "\t",
			},
			InputSources: []*vo.FieldInfo{
				{
					Path: compose.FieldPath{"String1"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: entry.Key,
							FromPath:    compose.FieldPath{"Str1"},
						},
					},
				},
				{
					Path: compose.FieldPath{"String2"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: entry.Key,
							FromPath:    compose.FieldPath{"Str2"},
						},
					},
				},
				{
					Path: compose.FieldPath{"String3"},
					Source: vo.FieldSource{
						Ref: &vo.Reference{
							FromNodeKey: entry.Key,
							FromPath:    compose.FieldPath{"Str3"},
						},
					},
				},
			},
		}
		ws := &compose2.WorkflowSchema{
			Nodes: []*compose2.NodeSchema{
				ns,
				entry,
				exit,
			},
			Connections: []*compose2.Connection{
				{
					FromNode: entry.Key,
					ToNode:   "tp",
				},
				{
					FromNode: "tp",
					ToNode:   exit.Key,
				},
			},
		}
		ws.Init()
		ctx := context.Background()
		wf, err := compose2.NewWorkflow(ctx, ws)
		assert.NoError(t, err)
		// Non-string inputs are stringified: bool true -> "true", float 1.0
		// -> "1", and the second element of f2 is "a".
		out, err := wf.Runner.Invoke(context.Background(), map[string]any{
			"Str1": true,
			"Str2": map[string]any{
				"f1": 1.0,
			},
			"Str3": map[string]any{
				"f2": []any{1, "a"},
			},
		})
		assert.NoError(t, err)
		assert.Equal(t, map[string]any{
			"output": "true_1_a",
		}, out)
	})
}

View File

@@ -0,0 +1,667 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strconv"
"time"
einomodel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// ToEntryConfig assembles the entry node's configuration from the schema's
// Configs map ("DefaultValues") and its declared output types.
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
	cfg := &entry.Config{
		DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
		OutputTypes:   s.OutputTypes,
	}
	return cfg, nil
}
// ToLLMConfig assembles the LLM node's configuration: prompts, the chat model
// (with an optional fallback model when retries are configured), and any
// function-call capabilities declared in FCParam — workflow tools, plugin
// tools, and knowledge recall.
//
// Fixes relative to the previous revision:
//   - the ApiId parse failure reported the plugin ID in its error message;
//   - plugin tools overwrote (instead of appended to) workflow tools already
//     collected in llmConf.Tools.
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
	llmConf := &llm.Config{
		SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
		UserPrompt:   getKeyOrZero[string]("UserPrompt", s.Configs),
		OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
		InputFields:  s.InputTypes,
		OutputFields: s.OutputTypes,
		FullSources:  getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}
	llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
	if llmParams == nil {
		return nil, fmt.Errorf("llm node llmParams is required")
	}
	var (
		err                  error
		chatModel, fallbackM einomodel.BaseChatModel
		info, fallbackI      *crossmodelmgr.Model
		modelWithInfo        llm.ModelWithInfo
	)
	chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
	if err != nil {
		return nil, err
	}
	// A backup model is only resolved when retries are configured on the node.
	metaConfigs := s.ExceptionConfigs
	if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
		backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
		if backupModelParams != nil {
			fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
			if err != nil {
				return nil, err
			}
		}
	}
	if fallbackM == nil {
		modelWithInfo = llm.NewModel(chatModel, info)
	} else {
		modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
	}
	llmConf.ChatModel = modelWithInfo
	fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
	if fcParams != nil {
		// Workflows exposed to the LLM as callable tools.
		if fcParams.WorkflowFCParam != nil {
			for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
				wfIDStr := wf.WorkflowID
				wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
				}
				workflowToolConfig := vo.WorkflowToolConfig{}
				if wf.FCSetting != nil {
					workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
					workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
				}
				// A pinned version selects that specific version; otherwise the draft is used.
				locator := vo.FromDraft
				if wf.WorkflowVersion != "" {
					locator = vo.FromSpecificVersion
				}
				wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
					ID:      wfID,
					QType:   locator,
					Version: wf.WorkflowVersion,
				}, workflowToolConfig)
				if err != nil {
					return nil, err
				}
				llmConf.Tools = append(llmConf.Tools, wfTool)
				// Tools whose workflow terminates by answering directly have their
				// output returned as-is instead of being fed back to the LLM.
				if wfTool.TerminatePlan() == vo.UseAnswerContent {
					if llmConf.ToolsReturnDirectly == nil {
						llmConf.ToolsReturnDirectly = make(map[string]bool)
					}
					toolInfo, err := wfTool.Info(ctx)
					if err != nil {
						return nil, err
					}
					llmConf.ToolsReturnDirectly[toolInfo.Name] = true
				}
			}
		}
		// Plugin tools exposed to the LLM, with requests grouped per plugin ID.
		if fcParams.PluginFCParam != nil {
			pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
			for _, p := range fcParams.PluginFCParam.PluginList {
				pid, err := strconv.ParseInt(p.PluginID, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
				}
				toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
				if err != nil {
					// Report the value that actually failed to parse (was: plugin id).
					return nil, fmt.Errorf("invalid api id: %s", p.ApiId)
				}
				var (
					requestParameters  []*workflow3.APIParameter
					responseParameters []*workflow3.APIParameter
				)
				if p.FCSetting != nil {
					requestParameters = p.FCSetting.RequestParameters
					responseParameters = p.FCSetting.ResponseParameters
				}
				if req, ok := pluginToolsInvokableReq[pid]; ok {
					req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
						ToolID:                      toolID,
						RequestAPIParametersConfig:  requestParameters,
						ResponseAPIParametersConfig: responseParameters,
					}
				} else {
					pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
						PluginEntity: crossplugin.Entity{
							PluginID:      pid,
							PluginVersion: ptr.Of(p.PluginVersion),
						},
						ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
							toolID: {
								ToolID:                      toolID,
								RequestAPIParametersConfig:  requestParameters,
								ResponseAPIParametersConfig: responseParameters,
							},
						},
						IsDraft: p.IsDraft,
					}
					pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
				}
			}
			inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
			for _, req := range pluginToolsInvokableReq {
				toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
				if err != nil {
					return nil, err
				}
				for _, t := range toolMap {
					inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
				}
			}
			if len(inInvokableTools) > 0 {
				// Append rather than assign: workflow tools collected above must
				// not be discarded when plugin tools are also configured.
				llmConf.Tools = append(llmConf.Tools, inInvokableTools...)
			}
		}
		// Knowledge recall: a dedicated chat model plus retrieval settings.
		if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
			kwChatModel, err := knowledgeRecallChatModel(ctx)
			if err != nil {
				return nil, err
			}
			knowledgeOperator := crossknowledge.GetKnowledgeOperator()
			setting := fcParams.KnowledgeFCParam.GlobalSetting
			cfg := &llm.KnowledgeRecallConfig{
				ChatModel: kwChatModel,
				Retriever: knowledgeOperator,
			}
			searchType, err := totRetrievalSearchType(setting.SearchMode)
			if err != nil {
				return nil, err
			}
			cfg.RetrievalStrategy = &llm.RetrievalStrategy{
				RetrievalStrategy: &crossknowledge.RetrievalStrategy{
					TopK:               ptr.Of(setting.TopK),
					MinScore:           ptr.Of(setting.MinScore),
					SearchType:         searchType,
					EnableNL2SQL:       setting.UseNL2SQL,
					EnableQueryRewrite: setting.UseRewrite,
					EnableRerank:       setting.UseRerank,
				},
				NoReCallReplyMode:            llm.NoReCallReplyMode(setting.NoRecallReplyMode),
				NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
			}
			knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
			for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
				kid, err := strconv.ParseInt(kw.ID, 10, 64)
				if err != nil {
					return nil, err
				}
				knowledgeIDs = append(knowledgeIDs, kid)
			}
			detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
				KnowledgeIDs: knowledgeIDs,
			})
			if err != nil {
				return nil, err
			}
			cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
			llmConf.KnowledgeRecallConfig = cfg
		}
	}
	return llmConf, nil
}
// ToSelectorConfig builds the selector node's configuration from the
// required "Clauses" key.
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
	clauses := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
	return &selector.Config{Clauses: clauses}
}
// SelectorInputConverter transforms the flat input map into the ordered
// operant list the selector node evaluates, producing one entry per clause.
// The right operant is optional for both single and multi clauses.
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
	clauses := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
	for ci, clause := range clauses {
		switch {
		case clause.Single != nil:
			left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(ci), selector.LeftKey})
			if !ok {
				return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, ci)
			}
			operant := selector.Operants{Left: left}
			if right, found := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(ci), selector.RightKey}); found {
				operant.Right = right
			}
			out = append(out, operant)
		case clause.Multi != nil:
			subOperants := make([]*selector.Operants, 0)
			for si := range clause.Multi.Clauses {
				left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(ci), strconv.Itoa(si), selector.LeftKey})
				if !ok {
					return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, ci, si)
				}
				sub := &selector.Operants{Left: left}
				if right, found := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(ci), strconv.Itoa(si), selector.RightKey}); found {
					sub.Right = right
				}
				subOperants = append(subOperants, sub)
			}
			out = append(out, selector.Operants{Multi: subOperants})
		default:
			return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", clause)
		}
	}
	return out, nil
}
// ToBatchConfig builds the batch node's configuration, wiring in the compiled
// inner workflow and registering every array-typed input as a batch array.
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
	cfg := &batch.Config{
		BatchNodeKey:  s.Key,
		InnerWorkflow: inner,
		Outputs:       s.OutputSources,
	}
	for name, info := range s.InputTypes {
		if info.Type == vo.DataTypeArray {
			cfg.InputArrays = append(cfg.InputArrays, name)
		}
	}
	return cfg, nil
}
// ToVariableAggregatorConfig builds the variable-aggregator node's
// configuration. Required keys are read via mustGetKey so that a missing or
// mistyped key panics with a descriptive message instead of the bare
// type-assertion panic the previous direct asserts produced — consistent
// with every other To*Config method in this file.
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
	return &variableaggregator.Config{
		MergeStrategy: mustGetKey[variableaggregator.MergeStrategy]("MergeStrategy", s.Configs),
		GroupLen:      mustGetKey[map[string]int]("GroupToLen", s.Configs),
		FullSources:   getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
		NodeKey:       s.Key,
		InputSources:  s.InputSources,
		GroupOrder:    mustGetKey[[]string]("GroupOrder", s.Configs),
	}, nil
}
// variableAggregatorInputConverter regroups the two-level input map
// (group name -> stringified index -> value) into group name -> integer
// index -> value. It panics when a group's value is not a map[string]any
// or an index key does not parse as an integer.
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
	converted = make(map[string]map[int]any)
	for group, raw := range in {
		byIndex, ok := raw.(map[string]any)
		if !ok {
			panic(errors.New("value is not a map[string]any"))
		}
		indexed := make(map[int]any, len(byIndex))
		for key, value := range byIndex {
			idx, convErr := strconv.Atoi(key)
			if convErr != nil {
				panic(fmt.Errorf(" converting %s to int failed, err=%v", key, convErr))
			}
			indexed[idx] = value
		}
		converted[group] = indexed
	}
	return converted
}
// variableAggregatorStreamInputConverter lifts variableAggregatorInputConverter
// onto a stream: each chunk of the incoming map stream is regrouped on the fly.
// The underlying converter panics on malformed input, so the recover here
// translates that panic into a per-chunk conversion error instead of
// crashing the stream consumer.
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
	converter := func(input map[string]any) (output map[string]map[int]any, err error) {
		defer func() {
			if r := recover(); r != nil {
				err = safego.NewPanicErr(r, debug.Stack())
			}
		}()
		return s.variableAggregatorInputConverter(input), nil
	}
	return schema.StreamReaderWithConvert(in, converter)
}
// ToTextProcessorConfig builds the text-processor node's configuration.
// "Type" is required and now read via mustGetKey (descriptive panic on a
// missing/mistyped key, consistent with the rest of this file); optional
// keys pass s.Configs directly since getKeyOrZero performs the
// map[string]any assertion itself.
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
	return &textprocessor.Config{
		Type:        mustGetKey[textprocessor.Type]("Type", s.Configs),
		Tpl:         getKeyOrZero[string]("Tpl", s.Configs),
		ConcatChar:  getKeyOrZero[string]("ConcatChar", s.Configs),
		Separators:  getKeyOrZero[[]string]("Separators", s.Configs),
		FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}, nil
}
// ToJsonSerializationConfig builds the JSON-serialization node's config
// from the node's declared input types.
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
	cfg := &json.SerializationConfig{InputTypes: s.InputTypes}
	return cfg, nil
}
// ToJsonDeserializationConfig builds the JSON-deserialization node's config
// from the node's declared output types.
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
	cfg := &json.DeserializationConfig{OutputFields: s.OutputTypes}
	return cfg, nil
}
// ToHTTPRequesterConfig builds the HTTP-requester node's configuration.
// All keys except AuthConfig are required.
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
	cfg := &httprequester.Config{
		Method:          mustGetKey[string]("Method", s.Configs),
		URLConfig:       mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
		BodyConfig:      mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
		AuthConfig:      getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
		Timeout:         mustGetKey[time.Duration]("Timeout", s.Configs),
		RetryTimes:      mustGetKey[uint64]("RetryTimes", s.Configs),
		MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
	}
	return cfg, nil
}
// ToVariableAssignerConfig builds the variable-assigner node's config.
// NOTE: for this node type, Configs holds a []*variableassigner.Pair
// directly rather than the usual map[string]any; a different shape panics.
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
	pairs := s.Configs.([]*variableassigner.Pair)
	return &variableassigner.Config{
		Pairs:   pairs,
		Handler: handler,
	}, nil
}
// ToVariableAssignerInLoopConfig builds the loop-scoped variable-assigner
// config. Unlike ToVariableAssignerConfig, no variable handler is attached.
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
	pairs := s.Configs.([]*variableassigner.Pair)
	return &variableassigner.Config{Pairs: pairs}, nil
}
// ToLoopConfig builds the loop node's configuration, wiring in the compiled
// inner workflow and registering array-typed inputs as iteration arrays.
// Arrays that are declared as intermediate variables are excluded.
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
	cfg := &loop.Config{
		LoopNodeKey:      s.Key,
		LoopType:         mustGetKey[loop.Type]("LoopType", s.Configs),
		IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
		Outputs:          s.OutputSources,
		Inner:            inner,
	}
	for name, info := range s.InputTypes {
		if info.Type != vo.DataTypeArray {
			continue
		}
		if _, isIntermediate := cfg.IntermediateVars[name]; isIntermediate {
			// exclude arrays in intermediate vars
			continue
		}
		cfg.InputArrays = append(cfg.InputArrays, name)
	}
	return cfg, nil
}
// ToQAConfig builds the question-answer node's configuration. When
// "LLMParams" is present, the corresponding chat model is resolved and
// attached; otherwise the node runs without a model.
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
	cfg := &qa.Config{
		QuestionTpl:               mustGetKey[string]("QuestionTpl", s.Configs),
		AnswerType:                mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
		ChoiceType:                getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
		FixedChoices:              getKeyOrZero[[]string]("FixedChoices", s.Configs),
		ExtractFromAnswer:         getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
		MaxAnswerCount:            getKeyOrZero[int]("MaxAnswerCount", s.Configs),
		AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
		OutputFields:              s.OutputTypes,
		NodeKey:                   s.Key,
	}
	if params := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs); params != nil {
		chatModel, _, err := model.GetManager().GetModel(ctx, params)
		if err != nil {
			return nil, err
		}
		cfg.Model = chatModel
	}
	return cfg, nil
}
// ToInputReceiverConfig builds the input-receiver node's configuration;
// "OutputSchema" is required.
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
	cfg := &receiver.Config{
		NodeKey:      s.Key,
		OutputTypes:  s.OutputTypes,
		OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
	}
	return cfg, nil
}
// ToOutputEmitterConfig builds the output-emitter node's configuration.
// NOTE(review): the *WorkflowSchema argument is not used here — presumably
// kept for signature parity with other constructors; confirm with callers.
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
	cfg := &emitter.Config{
		Template:    getKeyOrZero[string]("Template", s.Configs),
		FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}
	return cfg, nil
}
// ToDatabaseCustomSQLConfig builds the custom-SQL database node's config.
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
	cfg := &database.CustomSQLConfig{
		DatabaseInfoID:    mustGetKey[int64]("DatabaseInfoID", s.Configs),
		SQLTemplate:       mustGetKey[string]("SQLTemplate", s.Configs),
		OutputConfig:      s.OutputTypes,
		CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseQueryConfig builds the database-query node's config.
// QueryFields, OrderClauses and ClauseGroup are optional; Limit is required.
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
	cfg := &database.QueryConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		QueryFields:    getKeyOrZero[[]string]("QueryFields", s.Configs),
		OrderClauses:   getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
		ClauseGroup:    getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		Limit:          mustGetKey[int64]("Limit", s.Configs),
		OutputConfig:   s.OutputTypes,
		Op:             crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseInsertConfig builds the database-insert node's config.
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
	cfg := &database.InsertConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		OutputConfig:   s.OutputTypes,
		Inserter:       crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseDeleteConfig builds the database-delete node's config;
// the where-clause group is required.
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
	cfg := &database.DeleteConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		ClauseGroup:    mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		OutputConfig:   s.OutputTypes,
		Deleter:        crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseUpdateConfig builds the database-update node's config;
// the where-clause group is required.
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
	cfg := &database.UpdateConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		ClauseGroup:    mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		OutputConfig:   s.OutputTypes,
		Updater:        crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToKnowledgeIndexerConfig builds the knowledge-indexer node's config;
// parsing and chunking strategies are both required.
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
	cfg := &knowledge.IndexerConfig{
		KnowledgeID:      mustGetKey[int64]("KnowledgeID", s.Configs),
		ParsingStrategy:  mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
		ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
		KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToKnowledgeRetrieveConfig builds the knowledge-retrieve node's config.
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
	cfg := &knowledge.RetrieveConfig{
		KnowledgeIDs:      mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
		RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
		Retriever:         crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToKnowledgeDeleterConfig builds the knowledge-deleter node's config.
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
	cfg := &knowledge.DeleterConfig{
		KnowledgeID:      mustGetKey[int64]("KnowledgeID", s.Configs),
		KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToPluginConfig builds the plugin node's config; plugin ID, tool ID and
// plugin version are all required.
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
	cfg := &plugin.Config{
		PluginID:      mustGetKey[int64]("PluginID", s.Configs),
		ToolID:        mustGetKey[int64]("ToolID", s.Configs),
		PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
		PluginService: crossplugin.GetPluginService(),
	}
	return cfg, nil
}
// ToCodeRunnerConfig builds the code-runner node's config; source code and
// language are required.
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
	cfg := &code.Config{
		Code:         mustGetKey[string]("Code", s.Configs),
		Language:     mustGetKey[crosscode.Language]("Language", s.Configs),
		OutputConfig: s.OutputTypes,
		Runner:       crosscode.GetCodeRunner(),
	}
	return cfg, nil
}
// ToCreateConversationConfig wires the conversation-create node to the
// shared conversation manager implementation.
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
	cfg := &conversation.CreateConversationConfig{Creator: crossconversation.ConversationManagerImpl}
	return cfg, nil
}
// ToClearMessageConfig wires the clear-message node to the shared
// conversation manager implementation.
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
	cfg := &conversation.ClearMessageConfig{Clearer: crossconversation.ConversationManagerImpl}
	return cfg, nil
}
// ToMessageListConfig wires the message-list node to the shared
// conversation manager implementation.
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
	cfg := &conversation.MessageListConfig{Lister: crossconversation.ConversationManagerImpl}
	return cfg, nil
}
// ToIntentDetectorConfig builds the intent-detector node's configuration and
// resolves its required chat model.
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
	cfg := &intentdetector.Config{
		Intents:      mustGetKey[[]string]("Intents", s.Configs),
		SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
		IsFastMode:   getKeyOrZero[bool]("IsFastMode", s.Configs),
	}
	params := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
	chatModel, _, err := model.GetManager().GetModel(ctx, params)
	if err != nil {
		return nil, err
	}
	cfg.ChatModel = chatModel
	return cfg, nil
}
// ToSubWorkflowConfig compiles the node's sub-workflow schema into a runner.
// Checkpointing is propagated from the parent when requireCheckpoint is set,
// and the global per-workflow node-count limit is applied when configured.
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
	var opts []WorkflowOption
	opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
	if requireCheckpoint {
		opts = append(opts, WithParentRequireCheckpoint())
	}
	// Use a distinct name for the static config — the previous revision
	// shadowed the receiver 's' here, which was confusing next to the
	// s.SubWorkflowSchema use below.
	if static := execute.GetStaticConfig(); static != nil && static.MaxNodeCountPerWorkflow > 0 {
		opts = append(opts, WithMaxNodeCount(static.MaxNodeCountPerWorkflow))
	}
	wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
	if err != nil {
		return nil, err
	}
	return &subworkflow.Config{
		Runner: wf.Runner,
	}, nil
}
// totRetrievalSearchType maps the wire-level search-mode integer to the
// knowledge domain's SearchType (0 semantic, 1 hybrid, 20 full-text);
// any other value is an error.
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
	var searchType crossknowledge.SearchType
	switch s {
	case 0:
		searchType = crossknowledge.SearchTypeSemantic
	case 1:
		searchType = crossknowledge.SearchTypeHybrid
	case 20:
		searchType = crossknowledge.SearchTypeFullText
	default:
		return "", fmt.Errorf("invalid retrieval search type %v", s)
	}
	return searchType, nil
}
// knowledgeRecallChatModel returns the chat model used by the knowledge base
// recall in the LLM node, not the user-configured model. The parameters are
// a fixed default; only the model handle is returned, its metadata is dropped.
func knowledgeRecallChatModel(ctx context.Context) (einomodel.BaseChatModel, error) {
	defaultParams := &model.LLMParams{
		ModelName:   "豆包·1.5·Pro·32k",
		ModelType:   1,
		Temperature: ptr.Of(0.5),
		MaxTokens:   4096,
	}
	chatModel, _, err := model.GetManager().GetModel(ctx, defaultParams)
	return chatModel, err
}

View File

@@ -0,0 +1,107 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"reflect"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// getKeyOrZero returns cfg[key] as a T, or T's zero value when cfg is nil,
// the map is empty, or the key is absent. cfg must be nil or a
// map[string]any; any other type panics. A present key holding a value of
// the wrong type now panics with a descriptive message (previously a bare
// type-assertion panic with no key context).
func getKeyOrZero[T any](key string, cfg any) T {
	var zero T
	if cfg == nil {
		return zero
	}
	m, ok := cfg.(map[string]any)
	if !ok {
		panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
	}
	if len(m) == 0 {
		return zero
	}
	v, ok := m[key]
	if !ok {
		return zero
	}
	t, ok := v.(T)
	if !ok {
		panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf((*T)(nil)).Elem(), reflect.TypeOf(v)))
	}
	return t
}
// mustGetKey returns cfg[key] as a T, panicking with a descriptive message
// when cfg is nil, not a map[string]any, the key is absent, or the value has
// the wrong type. Compared to the previous revision, the map is consulted
// once instead of twice, and the expected type in the mismatch message is
// derived from the type parameter itself — reflect.TypeOf on the zero value
// reports nil for interface T and was misleading.
func mustGetKey[T any](key string, cfg any) T {
	if cfg == nil {
		panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
	}
	m, ok := cfg.(map[string]any)
	if !ok {
		panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
	}
	raw, ok := m[key]
	if !ok {
		panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
	}
	v, ok := raw.(T)
	if !ok {
		panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf((*T)(nil)).Elem(), reflect.TypeOf(raw)))
	}
	return v
}
// SetConfigKV stores value under key in the node's Configs, lazily
// allocating the map on first use. Panics if Configs holds a non-map value
// (some node types store a slice in Configs instead).
func (s *NodeSchema) SetConfigKV(key string, value any) {
	if s.Configs == nil {
		s.Configs = map[string]any{}
	}
	s.Configs.(map[string]any)[key] = value
}
// SetInputType records the type info for one input field, lazily
// allocating the InputTypes map.
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
	if s.InputTypes == nil {
		s.InputTypes = map[string]*vo.TypeInfo{key: t}
		return
	}
	s.InputTypes[key] = t
}
// AddInputSource appends one or more field-source descriptors to the
// node's input sources.
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
	s.InputSources = append(s.InputSources, info...)
}
// SetOutputType records the type info for one output field, lazily
// allocating the OutputTypes map.
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
	if s.OutputTypes == nil {
		s.OutputTypes = map[string]*vo.TypeInfo{key: t}
		return
	}
	s.OutputTypes[key] = t
}
// AddOutputSource appends one or more field-source descriptors to the
// node's output sources.
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
	s.OutputSources = append(s.OutputSources, info...)
}
// GetSubWorkflowIdentity reports the referenced sub-workflow's ID and
// version; the boolean is false when this node is not a sub-workflow node.
// Panics if a sub-workflow node is missing "WorkflowID" or "WorkflowVersion".
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
	if s.Type != entity.NodeTypeSubWorkflow {
		return 0, "", false
	}
	id := mustGetKey[int64]("WorkflowID", s.Configs)
	version := mustGetKey[string]("WorkflowVersion", s.Configs)
	return id, version, true
}

View File

@@ -0,0 +1,933 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"github.com/cloudwego/eino/compose"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// workflow aliases the eino workflow specialized to map-in/map-out payloads.
type workflow = compose.Workflow[map[string]any, map[string]any]

// Workflow wraps the composed eino graph together with schema-level metadata
// gathered while building it (hierarchy, connections, IO types, terminate plan).
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
	*workflow
	hierarchy         map[vo.NodeKey]vo.NodeKey // node-key hierarchy from the schema — presumably child->parent; confirm against WorkflowSchema.Hierarchy
	connections       []*Connection
	requireCheckpoint bool // compile with a checkpoint store when true
	entry             *compose.WorkflowNode
	inner             bool
	fromNode          bool // this workflow is constructed from a single node, without Entry or Exit nodes
	streamRun         bool // run via Stream instead of Invoke
	Runner            compose.Runnable[map[string]any, map[string]any] // TODO: this will be unexported eventually
	input             map[string]*vo.TypeInfo // entry node's output types
	output            map[string]*vo.TypeInfo // exit node's input types
	terminatePlan     vo.TerminatePlan
	schema            *WorkflowSchema
}

// workflowOptions collects the knobs applied by WorkflowOption functions.
type workflowOptions struct {
	wfID                    int64
	idAsName                bool // name the compiled graph after wfID
	parentRequireCheckpoint bool
	maxNodeCount            int // 0 means unlimited
}

// WorkflowOption mutates workflowOptions; see WithIDAsName and friends.
type WorkflowOption func(*workflowOptions)
// WithIDAsName names the compiled graph after the workflow's numeric ID.
func WithIDAsName(id int64) WorkflowOption {
	return func(o *workflowOptions) {
		o.wfID = id
		o.idAsName = true
	}
}
// WithParentRequireCheckpoint marks that the parent workflow requires
// checkpointing, forcing this workflow to compile with a checkpoint store.
func WithParentRequireCheckpoint() WorkflowOption {
	return func(o *workflowOptions) {
		o.parentRequireCheckpoint = true
	}
}
// WithMaxNodeCount caps how many nodes the workflow schema may contain.
func WithMaxNodeCount(c int) WorkflowOption {
	return func(o *workflowOptions) {
		o.maxNodeCount = c
	}
}
// NewWorkflow builds and compiles a runnable workflow from the given schema.
// Composite nodes (and their children, via their inner workflows) are added
// first, then every remaining node. Options control graph naming,
// checkpointing, and an optional node-count limit enforced before any node
// is added.
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
	sc.Init()
	wf := &Workflow{
		workflow:    compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
		hierarchy:   sc.Hierarchy,
		connections: sc.Connections,
		schema:      sc,
	}
	wf.streamRun = sc.requireStreaming
	wf.requireCheckpoint = sc.requireCheckPoint
	wfOpts := &workflowOptions{}
	for _, opt := range opts {
		opt(wfOpts)
	}
	if wfOpts.maxNodeCount > 0 {
		if sc.NodeCount() > int32(wfOpts.maxNodeCount) {
			return nil, fmt.Errorf("node count %d exceeds the limit: %d", sc.NodeCount(), wfOpts.maxNodeCount)
		}
	}
	// A parent workflow's checkpoint requirement propagates downward.
	if wfOpts.parentRequireCheckpoint {
		wf.requireCheckpoint = true
	}
	wf.input = sc.GetNode(entity.EntryNodeKey).OutputTypes
	// even if the terminate plan is use answer content, this still will be 'input types' of exit node
	wf.output = sc.GetNode(entity.ExitNodeKey).InputTypes
	// add all composite nodes with their inner workflow
	compositeNodes := sc.GetCompositeNodes()
	processedNodeKey := make(map[vo.NodeKey]struct{})
	for i := range compositeNodes {
		cNode := compositeNodes[i]
		if err := wf.AddCompositeNode(ctx, cNode); err != nil {
			return nil, err
		}
		// Record the parent and all children so they are not added again below.
		processedNodeKey[cNode.Parent.Key] = struct{}{}
		for _, child := range cNode.Children {
			processedNodeKey[child.Key] = struct{}{}
		}
	}
	// add all nodes other than composite nodes and their children
	for _, ns := range sc.Nodes {
		if _, ok := processedNodeKey[ns.Key]; !ok {
			if err := wf.AddNode(ctx, ns); err != nil {
				return nil, err
			}
		}
		// The exit node also fixes the workflow's terminate plan.
		if ns.Type == entity.NodeTypeExit {
			wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
		}
	}
	var compileOpts []compose.GraphCompileOption
	if wf.requireCheckpoint {
		compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow2.GetRepository()))
	}
	if wfOpts.idAsName {
		compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
	}
	// Schema-derived fan-in merge configs are handed to the compiler when present.
	fanInConfigs := sc.fanInMergeConfigs()
	if len(fanInConfigs) > 0 {
		compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
	}
	r, err := wf.Compile(ctx, compileOpts...)
	if err != nil {
		return nil, err
	}
	wf.Runner = r
	return wf, nil
}
// AsyncRun starts the workflow in a background goroutine, discarding both
// result and error. Streaming workflows are started via Stream, all others
// via Invoke.
func (w *Workflow) AsyncRun(ctx context.Context, in map[string]any, opts ...compose.Option) {
	run := func() {
		_, _ = w.Runner.Invoke(ctx, in, opts...)
	}
	if w.streamRun {
		run = func() {
			_, _ = w.Runner.Stream(ctx, in, opts...)
		}
	}
	safego.Go(ctx, run)
}
// SyncRun executes the workflow synchronously and returns its final output map.
func (w *Workflow) SyncRun(ctx context.Context, in map[string]any, opts ...compose.Option) (map[string]any, error) {
	return w.Runner.Invoke(ctx, in, opts...)
}

// Inputs returns the workflow's input type declarations (the entry node's output types).
func (w *Workflow) Inputs() map[string]*vo.TypeInfo {
	return w.input
}

// Outputs returns the workflow's output type declarations (the exit node's input types).
func (w *Workflow) Outputs() map[string]*vo.TypeInfo {
	return w.output
}

// StreamRun reports whether this workflow must be run in streaming mode.
func (w *Workflow) StreamRun() bool {
	return w.streamRun
}

// TerminatePlan returns the terminate plan taken from the exit node's config.
func (w *Workflow) TerminatePlan() vo.TerminatePlan {
	return w.terminatePlan
}
// innerWorkflowInfo bundles a composite node's compiled inner workflow with
// the field mappings it carries over into the parent workflow's scope.
type innerWorkflowInfo struct {
	inner      compose.Runnable[map[string]any, map[string]any]
	carryOvers map[vo.NodeKey][]*compose.FieldMapping
}
// AddNode adds a single, non-composite node to the workflow from its schema.
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
	_, err := w.addNodeInternal(ctx, ns, nil)
	return err
}
// AddCompositeNode compiles the composite node's inner workflow and then adds
// the parent node, merging in the inner workflow's carry-over mappings.
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
	inner, err := w.getInnerWorkflow(ctx, cNode)
	if err != nil {
		return err
	}
	if _, err = w.addNodeInternal(ctx, cNode.Parent, inner); err != nil {
		return err
	}
	return nil
}
// addInnerNode adds a node belonging to a composite node's inner workflow and
// returns the field mappings destined for the parent workflow.
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
	return w.addNodeInternal(ctx, cNode, nil)
}
// addNodeInternal resolves the node's dependencies, instantiates it, registers
// it on the graph (with optional state handlers and, for multi-port nodes, a
// branch), and returns the field mappings that belong to the parent workflow.
// inner is non-nil only for composite nodes and supplies the compiled inner
// workflow plus its carry-over mappings.
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
	key := ns.Key
	var deps *dependencyInfo
	deps, err := w.resolveDependencies(key, ns.InputSources)
	if err != nil {
		return nil, err
	}
	// Composite nodes contribute extra mappings carried over from their inner workflow.
	if inner != nil {
		if err = deps.merge(inner.carryOvers); err != nil {
			return nil, err
		}
	}
	var innerWorkflow compose.Runnable[map[string]any, map[string]any]
	if inner != nil {
		innerWorkflow = inner.inner
	}
	ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
	if err != nil {
		return nil, err
	}
	var opts []compose.GraphAddNodeOpt
	opts = append(opts, compose.WithNodeName(string(ns.Key)))
	// Optional stream-aware state pre/post handlers.
	preHandler := ns.StatePreHandler(w.streamRun)
	if preHandler != nil {
		opts = append(opts, preHandler)
	}
	postHandler := ns.StatePostHandler(w.streamRun)
	if postHandler != nil {
		opts = append(opts, postHandler)
	}
	var wNode *compose.WorkflowNode
	if ins.Lambda != nil {
		wNode = w.AddLambdaNode(string(key), ins.Lambda, opts...)
	} else {
		return nil, fmt.Errorf("node instance has no Lambda: %s", key)
	}
	if err = deps.arrayDrillDown(w.schema.GetAllNodes()); err != nil {
		return nil, err
	}
	// Wire the four dependency flavors: full-map inputs, field-mapped inputs,
	// and the same two without a direct execution dependency.
	for fromNodeKey := range deps.inputsFull {
		wNode.AddInput(string(fromNodeKey))
	}
	for fromNodeKey, fieldMappings := range deps.inputs {
		wNode.AddInput(string(fromNodeKey), fieldMappings...)
	}
	for fromNodeKey := range deps.inputsNoDirectDependencyFull {
		wNode.AddInputWithOptions(string(fromNodeKey), nil, compose.WithNoDirectDependency())
	}
	for fromNodeKey, fieldMappings := range deps.inputsNoDirectDependency {
		wNode.AddInputWithOptions(string(fromNodeKey), fieldMappings, compose.WithNoDirectDependency())
	}
	// Pure control-flow dependencies (no data) and compile-time static values.
	for i := range deps.dependencies {
		wNode.AddDependency(string(deps.dependencies[i]))
	}
	for i := range deps.staticValues {
		wNode.SetStaticValue(deps.staticValues[i].path, deps.staticValues[i].val)
	}
	// At most one entry node per workflow.
	if ns.Type == entity.NodeTypeEntry {
		if w.entry != nil {
			return nil, errors.New("entry node already set")
		}
		w.entry = wNode
	}
	// Nodes with multiple output ports, or an exception port, need a branch.
	outputPortCount, hasExceptionPort := ns.OutputPortCount()
	if outputPortCount > 1 || hasExceptionPort {
		bMapping, err := w.resolveBranch(key, outputPortCount)
		if err != nil {
			return nil, err
		}
		branch, err := ns.GetBranch(bMapping)
		if err != nil {
			return nil, err
		}
		_ = w.AddBranch(string(key), branch)
	}
	return deps.inputsForParent, nil
}
// Compile finalizes the graph and produces the runnable workflow. For a
// top-level workflow (neither an inner workflow nor a single-node run), the
// entry node must have been set; it is wired to START and the END node is fed
// by the exit node before compiling.
func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOption) (compose.Runnable[map[string]any, map[string]any], error) {
	isTopLevel := !w.inner && !w.fromNode
	if isTopLevel {
		if w.entry == nil {
			return nil, fmt.Errorf("entry node is not set")
		}
		w.entry.AddInput(compose.START)
		w.End().AddInput(entity.ExitNodeKey)
	}
	return w.workflow.Compile(ctx, opts...)
}
// getInnerWorkflow builds and compiles the inner workflow of a composite node
// from its children, and collects the carry-over field mappings: mappings from
// parent-workflow nodes that the inner nodes need, which the composite parent
// node must proxy.
// Fix vs. previous version: when a non-duplicate mapping was found, the whole
// fieldMappings slice was appended instead of the single mapping fm, which
// could re-introduce duplicates into carryOvers.
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
	innerNodes := make(map[vo.NodeKey]*NodeSchema)
	for _, n := range cNode.Children {
		innerNodes[n.Key] = n
	}
	// trim the connections, only keep the connections that are related to the inner workflow
	// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
	innerConnections := make([]*Connection, 0)
	for i := range w.schema.Connections {
		conn := w.schema.Connections[i]
		if _, ok := innerNodes[conn.FromNode]; ok {
			innerConnections = append(innerConnections, conn)
		} else if _, ok := innerNodes[conn.ToNode]; ok {
			innerConnections = append(innerConnections, conn)
		}
	}
	inner := &Workflow{
		workflow:          compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
		hierarchy:         w.hierarchy, // we keep the entire hierarchy because inner workflow nodes can refer to parent nodes' outputs
		connections:       innerConnections,
		inner:             true,
		requireCheckpoint: w.requireCheckpoint,
		schema:            w.schema,
	}
	carryOvers := make(map[vo.NodeKey][]*compose.FieldMapping)
	for key := range innerNodes {
		inputsForParent, err := inner.addInnerNode(ctx, innerNodes[key])
		if err != nil {
			return nil, err
		}
		for fromNodeKey, fieldMappings := range inputsForParent {
			if fromNodeKey == cNode.Parent.Key { // refer to parent itself, no need to carry over
				continue
			}
			if _, ok := carryOvers[fromNodeKey]; !ok {
				carryOvers[fromNodeKey] = make([]*compose.FieldMapping, 0)
			}
			for _, fm := range fieldMappings {
				duplicate := false
				for _, existing := range carryOvers[fromNodeKey] {
					if fm.Equals(existing) {
						duplicate = true
						break
					}
				}
				if !duplicate {
					// append only the single non-duplicate mapping (was: fieldMappings...)
					carryOvers[fromNodeKey] = append(carryOvers[fromNodeKey], fm)
				}
			}
		}
	}
	// wire the composite parent's output sources into the inner workflow's END node
	endDeps, err := inner.resolveDependenciesAsParent(cNode.Parent.Key, cNode.Parent.OutputSources)
	if err != nil {
		return nil, fmt.Errorf("resolve dependencies of parent node: %s failed: %w", cNode.Parent.Key, err)
	}
	n := inner.End()
	for fromNodeKey := range endDeps.inputsFull {
		n.AddInput(string(fromNodeKey))
	}
	for fromNodeKey, fieldMappings := range endDeps.inputs {
		n.AddInput(string(fromNodeKey), fieldMappings...)
	}
	for fromNodeKey := range endDeps.inputsNoDirectDependencyFull {
		n.AddInputWithOptions(string(fromNodeKey), nil, compose.WithNoDirectDependency())
	}
	for fromNodeKey, fieldMappings := range endDeps.inputsNoDirectDependency {
		n.AddInputWithOptions(string(fromNodeKey), fieldMappings, compose.WithNoDirectDependency())
	}
	for i := range endDeps.dependencies {
		n.AddDependency(string(endDeps.dependencies[i]))
	}
	for i := range endDeps.staticValues {
		n.SetStaticValue(endDeps.staticValues[i].path, endDeps.staticValues[i].val)
	}
	var opts []compose.GraphCompileOption
	if inner.requireCheckpoint {
		opts = append(opts, compose.WithCheckPointStore(workflow2.GetRepository()))
	}
	r, err := inner.Compile(ctx, opts...)
	if err != nil {
		return nil, err
	}
	return &innerWorkflowInfo{
		inner:      r,
		carryOvers: carryOvers,
	}, nil
}
// dependencyInfo is the resolved dependency set of a single node: which
// upstream nodes feed it data (with or without a direct control edge), which
// are pure control dependencies, which input paths take static literals, and
// which come from variables. inputsForParent is populated only for nodes inside
// an inner workflow and lists mappings the parent workflow must carry over.
type dependencyInfo struct {
	inputs                       map[vo.NodeKey][]*compose.FieldMapping // field mappings from direct predecessors
	inputsFull                   map[vo.NodeKey]struct{}                // direct predecessors whose entire output is mapped
	dependencies                 []vo.NodeKey                           // pure control dependencies (no data mapping)
	inputsNoDirectDependency     map[vo.NodeKey][]*compose.FieldMapping // field mappings from indirect predecessors
	inputsNoDirectDependencyFull map[vo.NodeKey]struct{}                // indirect predecessors whose entire output is mapped
	staticValues                 []*staticValue                         // literal values assigned to input paths
	variableInfos                []*variableInfo                        // variable references, handled by state handlers
	inputsForParent              map[vo.NodeKey][]*compose.FieldMapping // mappings the parent workflow must provide
}
// merge folds the given field mappings (typically a composite node's carry-over
// mappings) into this dependencyInfo. Mappings for a node that already has data
// inputs are appended with deduplication; a node that was only a control
// dependency is promoted to a direct data dependency; otherwise the mappings
// are recorded as indirect data inputs. A full-output mapping for the same node
// is an error.
// Improvements vs. previous version: the duplicate scan stops at the first
// match (it previously kept scanning), and duplicates within the incoming
// batch itself are also rejected (previously only pre-existing mappings were
// checked).
func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping) error {
	// appendUnique appends each mapping in fms to dst unless an equal mapping is
	// already present (including ones appended earlier in this same call).
	appendUnique := func(dst, fms []*compose.FieldMapping) []*compose.FieldMapping {
		for _, fm := range fms {
			duplicate := false
			for _, existing := range dst {
				if fm.Equals(existing) {
					duplicate = true
					break
				}
			}
			if !duplicate {
				dst = append(dst, fm)
			}
		}
		return dst
	}
	for nKey, fms := range mappings {
		// a whole-output mapping for the same source node cannot coexist with field-level mappings
		if _, ok := d.inputsFull[nKey]; ok {
			return fmt.Errorf("duplicate input for node: %s", nKey)
		}
		if _, ok := d.inputsNoDirectDependencyFull[nKey]; ok {
			return fmt.Errorf("duplicate input for node: %s", nKey)
		}
		if _, ok := d.inputs[nKey]; ok {
			d.inputs[nKey] = appendUnique(d.inputs[nKey], fms)
		} else if _, ok := d.inputsNoDirectDependency[nKey]; ok {
			d.inputsNoDirectDependency[nKey] = appendUnique(d.inputsNoDirectDependency[nKey], fms)
		} else {
			// the node currently has no data mappings; if it is a pure control
			// dependency, promote it to a direct data dependency.
			currentDependency := -1
			for i, depKey := range d.dependencies {
				if depKey == nKey {
					currentDependency = i
					break
				}
			}
			if currentDependency >= 0 {
				d.dependencies = append(d.dependencies[:currentDependency], d.dependencies[currentDependency+1:]...)
				d.inputs[nKey] = append(d.inputs[nKey], fms...)
			} else {
				d.inputsNoDirectDependency[nKey] = append(d.inputsNoDirectDependency[nKey], fms...)
			}
		}
	}
	return nil
}
// arrayDrillDown happens when the 'mapping from path' is taking fields from elements within arrays.
// when this happens, we automatically takes the first element from any arrays along the 'from path'.
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
// we will not try to drill, instead, just take value using a.b.c.
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
	// drill rewrites every mapping in the given map via the package-level
	// arrayDrillDown; extracted to remove the duplication between the two
	// previously identical loops over inputs and inputsNoDirectDependency.
	drill := func(mappings map[vo.NodeKey][]*compose.FieldMapping) error {
		for nKey, fms := range mappings {
			if nKey == compose.START { // reference to START node would NEVER need to do array drill down
				continue
			}
			ns, ok := allNS[nKey]
			if !ok {
				return fmt.Errorf("node not found: %s", nKey)
			}
			for i := range fms {
				newFM, err := arrayDrillDown(nKey, fms[i], ns.OutputTypes)
				if err != nil {
					return err
				}
				fms[i] = newFM
			}
		}
		return nil
	}
	if err := drill(d.inputs); err != nil {
		return err
	}
	return drill(d.inputsNoDirectDependency)
}
// arrayDrillDown rewrites fm so that whenever its 'from path' traverses an
// array-typed segment, value extraction automatically takes element [0] of that
// array. types is the output type tree of the source node nKey. If no array
// lies along the path, fm is returned unchanged.
func arrayDrillDown(nKey vo.NodeKey, fm *compose.FieldMapping, types map[string]*vo.TypeInfo) (*compose.FieldMapping, error) {
	fromPath := fm.FromPath()
	if len(fromPath) <= 1 { // no need to drill down
		return fm, nil
	}
	ct := types
	var arraySegIndexes []int // indexes of path segments whose type is array
	// walk all but the last segment, recording where arrays occur
	for j := 0; j < len(fromPath)-1; j++ {
		p := fromPath[j]
		t, ok := ct[p]
		if !ok {
			return nil, fmt.Errorf("type info not found for path: %s", fm.FromPath()[:j+1])
		}
		if t.Type == vo.DataTypeArray {
			arraySegIndexes = append(arraySegIndexes, j)
			if t.ElemTypeInfo.Type == vo.DataTypeObject {
				ct = t.ElemTypeInfo.Properties
			} else if j != len(fromPath)-1 {
				// NOTE(review): j < len(fromPath)-1 always holds inside this loop, so this
				// condition is always true here — presumably intended, since a non-object
				// array element cannot satisfy the remaining path segments.
				return nil, fmt.Errorf("[arrayDrillDown] already found array of none obj, but still not last segment of path: %v",
					fromPath[:j+1])
			}
		} else if t.Type == vo.DataTypeObject {
			ct = t.Properties
		} else if j != len(fromPath)-1 {
			// NOTE(review): always true inside this loop, same as above.
			return nil, fmt.Errorf("[arrayDrillDown] found non-array, non-obj type: %v, but still not last segment of path: %v",
				t.Type, fromPath[:j+1])
		}
	}
	if len(arraySegIndexes) == 0 { // no arrays along from path
		return fm, nil
	}
	// extractor re-implements path extraction at runtime, inserting [0] at each
	// recorded array segment.
	extractor := func(a any) (any, error) {
		for j := range fromPath {
			p := fromPath[j]
			m, ok := a.(map[string]any)
			if !ok {
				return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from a non-map type:%T of path %s, "+
					"from node key: %v", a, fromPath[:j+1], nKey)
			}
			a, ok = m[p]
			if !ok {
				return nil, fmt.Errorf("[arrayDrillDown] field %s not found along from path: %s, "+
					"from node key: %v", p, fromPath[:j+1], nKey)
			}
			if slices.Contains(arraySegIndexes, j) { // this is an array needs drilling down
				arr, ok := a.([]any)
				if !ok {
					return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from a non-array type:%T of path %s, "+
						"from node key: %v", a, fromPath[:j+1], nKey)
				}
				if len(arr) == 0 {
					return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from an array of length 0: %s, "+
						"from node key: %v", fromPath[:j+1], nKey)
				}
				a = arr[0]
			}
		}
		return a, nil
	}
	newFM := compose.ToFieldPath(fm.ToPath(), compose.WithCustomExtractor(extractor))
	return newFM, nil
}
// staticValue is a literal value assigned to an input field path of a node.
type staticValue struct {
	val  any
	path compose.FieldPath
}

// variableInfo describes an input sourced from a variable rather than another
// node's output; such inputs are resolved in state handlers, not via mappings.
type variableInfo struct {
	varType  vo.GlobalVarType
	fromPath compose.FieldPath
	toPath   compose.FieldPath
}
// resolveBranch collects, for node n with portCount normal output ports, the
// successor-node sets per port plus the optional exception-port successors.
// Recognized port names: "branch_<i>" for port i, "default" for the last
// (fallback) port, and "branch_error" for the exception port.
// Improvement vs. previous version: the numeric index is obtained with
// strings.TrimPrefix instead of the magic slice offset 7 (== len("branch_")).
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
	m := make([]map[string]bool, portCount)
	var exception map[string]bool
	for _, conn := range w.connections {
		if conn.FromNode != n {
			continue
		}
		if conn.FromPort == nil { // portless connections do not participate in branching
			continue
		}
		port := *conn.FromPort
		switch {
		case port == "default": // default condition maps to the last port
			if m[portCount-1] == nil {
				m[portCount-1] = make(map[string]bool)
			}
			m[portCount-1][string(conn.ToNode)] = true
		case port == "branch_error": // exception branch (checked before the generic "branch_" prefix)
			if exception == nil {
				exception = make(map[string]bool)
			}
			exception[string(conn.ToNode)] = true
		default:
			if !strings.HasPrefix(port, "branch_") {
				return nil, fmt.Errorf("outgoing connections has invalid port= %s", port)
			}
			i, err := strconv.Atoi(strings.TrimPrefix(port, "branch_"))
			if err != nil {
				return nil, fmt.Errorf("outgoing connections has invalid port index= %s", port)
			}
			if i < 0 || i >= portCount {
				return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
			}
			if m[i] == nil {
				m[i] = make(map[string]bool)
			}
			m[i][string(conn.ToNode)] = true
		}
	}
	return &BranchMapping{
		Normal:    m,
		Exception: exception,
	}, nil
}
// resolveDependencies classifies every input source of node n into static
// values, variable references, direct-edge or indirect field mappings, parent
// carry-overs (when n lives inside an inner workflow), and pure control
// dependencies derived from portless incoming connections.
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
	var (
		inputs                       = make(map[vo.NodeKey][]*compose.FieldMapping)
		inputFull                    map[vo.NodeKey]struct{}
		dependencies                 []vo.NodeKey
		inputsNoDirectDependency     = make(map[vo.NodeKey][]*compose.FieldMapping)
		inputsNoDirectDependencyFull map[vo.NodeKey]struct{}
		staticValues                 []*staticValue
		variableInfos                []*variableInfo
		// inputsForParent contains all the field mappings from any nodes of the parent workflow
		inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
	)
	// incoming connections of n, keyed by source node
	connMap := make(map[vo.NodeKey]Connection)
	for _, conn := range w.connections {
		if conn.ToNode != n {
			continue
		}
		connMap[conn.FromNode] = *conn
	}
	for _, swp := range sourceWithPaths {
		if swp.Source.Val != nil { // static literal value
			staticValues = append(staticValues, &staticValue{
				val:  swp.Source.Val,
				path: swp.Path,
			})
		} else if swp.Source.Ref != nil {
			fromNode := swp.Source.Ref.FromNodeKey
			if fromNode == n {
				return nil, fmt.Errorf("node %s cannot refer to itself, fromPath: %v, toPath: %v", n,
					swp.Source.Ref.FromPath, swp.Path)
			}
			if swp.Source.Ref.VariableType != nil {
				// skip all variables, they are handled in state pre handler
				variableInfos = append(variableInfos, &variableInfo{
					varType:  *swp.Source.Ref.VariableType,
					fromPath: swp.Source.Ref.FromPath,
					toPath:   swp.Path,
				})
				continue
			}
			if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
				if _, ok := connMap[fromNode]; ok { // direct dependency
					// both paths empty means "map the whole output to the whole input"
					if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
						if inputFull == nil {
							inputFull = make(map[vo.NodeKey]struct{})
						}
						inputFull[fromNode] = struct{}{}
					} else {
						inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
					}
				} else { // indirect dependency
					if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
						if inputsNoDirectDependencyFull == nil {
							inputsNoDirectDependencyFull = make(map[vo.NodeKey]struct{})
						}
						inputsNoDirectDependencyFull[fromNode] = struct{}{}
					} else {
						inputsNoDirectDependency[fromNode] = append(inputsNoDirectDependency[fromNode],
							compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
					}
				}
			} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
				// n is inside a composite node and references a node of the parent workflow
				firstNodesInInnerWorkflow := true
				for _, conn := range connMap {
					if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
						// there is another node 'conn.FromNode' that connects to this node, while also at the same level
						firstNodesInInnerWorkflow = false
						break
					}
				}
				if firstNodesInInnerWorkflow { // one of the first nodes in sub workflow
					inputs[compose.START] = append(inputs[compose.START],
						compose.MapFieldPaths(
							// the START node of inner workflow will proxy for the fields required from parent workflow
							// the field path within START node is prepended by the parent node key
							joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)),
							swp.Path))
				} else { // not one of the first nodes in sub workflow, either succeeds other nodes or succeeds branches
					inputsNoDirectDependency[compose.START] = append(inputsNoDirectDependency[compose.START],
						compose.MapFieldPaths(
							// same as above, the START node of inner workflow proxies for the fields from parent workflow
							joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)),
							swp.Path))
				}
				fieldMapping := compose.MapFieldPaths(swp.Source.Ref.FromPath,
					// our parent node will proxy for these field mappings, prepending the 'fromNode' to paths
					joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
				// deduplicate mappings requested from the parent workflow
				added := false
				for _, existedFieldMapping := range inputsForParent[fromNode] {
					if existedFieldMapping.Equals(fieldMapping) {
						added = true
						break
					}
				}
				if !added {
					inputsForParent[fromNode] = append(inputsForParent[fromNode], fieldMapping)
				}
			}
		} else {
			return nil, fmt.Errorf("inputField's Val and Ref are both nil. path= %v", swp.Path)
		}
	}
	// portless incoming connections whose source contributes no data become pure control dependencies
	for fromNodeKey, conn := range connMap {
		if conn.FromPort != nil {
			continue
		}
		if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
			// parent-level predecessors are proxied by the inner workflow's START node
			fromNodeKey = compose.START
		} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
			continue
		}
		if _, ok := inputs[fromNodeKey]; !ok {
			if _, ok := inputsNoDirectDependency[fromNodeKey]; !ok {
				var hasFullInput, hasFullDataInput bool
				if inputFull != nil {
					if _, ok = inputFull[fromNodeKey]; ok {
						hasFullInput = true
					}
				}
				if inputsNoDirectDependencyFull != nil {
					if _, ok = inputsNoDirectDependencyFull[fromNodeKey]; ok {
						hasFullDataInput = true
					}
				}
				if !hasFullInput && !hasFullDataInput {
					dependencies = append(dependencies, fromNodeKey)
				}
			}
		}
	}
	return &dependencyInfo{
		inputs:                       inputs,
		inputsFull:                   inputFull,
		dependencies:                 dependencies,
		inputsNoDirectDependency:     inputsNoDirectDependency,
		inputsNoDirectDependencyFull: inputsNoDirectDependencyFull,
		staticValues:                 staticValues,
		variableInfos:                variableInfos,
		inputsForParent:              inputsForParent,
	}, nil
}
// fieldPathSplitter joins multi-segment field paths into a single map key when
// an inner workflow's START node proxies fields from the parent workflow.
const fieldPathSplitter = "#"

// joinFieldPath collapses a multi-segment field path into a single-segment
// path by joining the segments with fieldPathSplitter.
func joinFieldPath(f compose.FieldPath) compose.FieldPath {
	return compose.FieldPath{strings.Join(f, fieldPathSplitter)}
}
// resolveDependenciesAsParent resolves, from within the inner workflow, how the
// composite parent node n's output sources map onto the inner workflow's END
// node: which inner nodes feed it directly or indirectly, plus static values
// and variable references.
// Fix vs. previous version: the first branch tested `Ref == nil` to classify a
// source as static, which silently accepted sources with BOTH Val and Ref nil
// (recording a nil static value) and made the final error branch unreachable.
// It now tests `Val != nil` first, mirroring resolveDependencies.
func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
	var (
		// inputsFull and inputsNoDirectDependencyFull are NEVER used in this case,
		// because a composite node MUST use explicit field mappings from inner nodes as its output.
		inputs                   = make(map[vo.NodeKey][]*compose.FieldMapping)
		dependencies             []vo.NodeKey
		inputsNoDirectDependency = make(map[vo.NodeKey][]*compose.FieldMapping)
		// although staticValues are not used for current composite nodes,
		// they may be used in the future, so we calculate them none the less.
		staticValues []*staticValue
		// variableInfos are normally handled in state pre handler, but in the case of composite node's output,
		// we need to handle them within composite node's state post handler,
		variableInfos []*variableInfo
	)
	// incoming connections of n that originate from INNER nodes (same-level sources are skipped)
	connMap := make(map[vo.NodeKey]Connection)
	for _, conn := range w.connections {
		if conn.ToNode != n {
			continue
		}
		if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
			continue
		}
		connMap[conn.FromNode] = *conn
	}
	for _, swp := range sourceWithPaths {
		if swp.Source.Val != nil { // static literal output value (was: `Ref == nil`, see doc above)
			staticValues = append(staticValues, &staticValue{
				val:  swp.Source.Val,
				path: swp.Path,
			})
		} else if swp.Source.Ref != nil {
			if swp.Source.Ref.VariableType != nil {
				variableInfos = append(variableInfos, &variableInfo{
					varType:  *swp.Source.Ref.VariableType,
					fromPath: swp.Source.Ref.FromPath,
					toPath:   swp.Path,
				})
				continue
			}
			fromNode := swp.Source.Ref.FromNodeKey
			if fromNode == n {
				return nil, fmt.Errorf("node %s cannot refer to itself, fromPath= %v, toPath= %v", n,
					swp.Source.Ref.FromPath, swp.Path)
			}
			if ok := isParentOf(w.hierarchy, n, fromNode); ok {
				// the END node's input path is prefixed with the inner node's key
				if _, ok := connMap[fromNode]; ok { // direct dependency
					inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
				} else { // indirect dependency
					inputsNoDirectDependency[fromNode] = append(inputsNoDirectDependency[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
				}
			}
		} else {
			// both Val and Ref are nil: invalid field info (previously dead code)
			return nil, fmt.Errorf("composite node's output field's Val and Ref are both nil. path= %v", swp.Path)
		}
	}
	// portless inner connections without data mappings become pure control dependencies of END
	for fromNodeKey, conn := range connMap {
		if conn.FromPort != nil {
			continue
		}
		if _, ok := inputs[fromNodeKey]; !ok {
			if _, ok := inputsNoDirectDependency[fromNodeKey]; !ok {
				dependencies = append(dependencies, fromNodeKey)
			}
		}
	}
	return &dependencyInfo{
		inputs:                   inputs,
		dependencies:             dependencies,
		inputsNoDirectDependency: inputsNoDirectDependency,
		staticValues:             staticValues,
		variableInfos:            variableInfos,
	}, nil
}

View File

@@ -0,0 +1,84 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/compose"

	workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NewWorkflowFromNode builds a runnable Workflow that executes a single node
// (identified by nodeKey) of the given schema. Single-node runs always use
// Invoke; streaming is disabled.
// Fixes vs. previous version: the loop over sc.Nodes no longer shadows the
// target node's variable, and an unknown nodeKey now returns an error instead
// of dereferencing a nil schema.
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
	*Workflow, error) {
	sc.Init()
	target := sc.GetNode(nodeKey)
	if target == nil {
		return nil, fmt.Errorf("node not found in schema: %s", nodeKey)
	}
	wf := &Workflow{
		workflow:          compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
		hierarchy:         sc.Hierarchy,
		connections:       sc.Connections,
		schema:            sc,
		fromNode:          true,
		streamRun:         false, // single node run can only invoke
		requireCheckpoint: sc.requireCheckPoint,
		input:             target.InputTypes,
		output:            target.OutputTypes,
		terminatePlan:     vo.ReturnVariables,
	}
	// composite nodes (and all their children) are added as a unit
	compositeNodes := sc.GetCompositeNodes()
	processedNodeKey := make(map[vo.NodeKey]struct{})
	for i := range compositeNodes {
		cNode := compositeNodes[i]
		if err := wf.AddCompositeNode(ctx, cNode); err != nil {
			return nil, err
		}
		processedNodeKey[cNode.Parent.Key] = struct{}{}
		for _, child := range cNode.Children {
			processedNodeKey[child.Key] = struct{}{}
		}
	}
	// add all nodes other than composite nodes and their children
	for _, node := range sc.Nodes {
		if _, ok := processedNodeKey[node.Key]; !ok {
			if err := wf.AddNode(ctx, node); err != nil {
				return nil, err
			}
		}
	}
	wf.End().AddInput(string(nodeKey))
	var compileOpts []compose.GraphCompileOption
	compileOpts = append(compileOpts, opts...)
	if wf.requireCheckpoint {
		compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow2.GetRepository()))
	}
	r, err := wf.Compile(ctx, compileOpts...)
	if err != nil {
		return nil, err
	}
	wf.Runner = r
	return wf, nil
}

View File

@@ -0,0 +1,292 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"strconv"
"strings"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// WorkflowRunner prepares and drives one execution (or resumption) of a workflow.
type WorkflowRunner struct {
	basic        *entity.WorkflowBasic // identity/version info of the workflow being run
	input        string                // serialized input payload
	resumeReq    *entity.ResumeRequest // non-nil when resuming from an interrupt
	schema       *WorkflowSchema
	streamWriter *schema.StreamWriter[*entity.Message] // optional sink for streamed messages; closed by Prepare's goroutine
	config       vo.ExecuteConfig
	// the three fields below are populated by Prepare.
	executeID      int64
	eventChan      chan *execute.Event
	interruptEvent *entity.InterruptEvent // set only when resuming
}
// workflowRunOptions collects the optional settings applied by WorkflowRunnerOption.
type workflowRunOptions struct {
	input        string
	resumeReq    *entity.ResumeRequest
	streamWriter *schema.StreamWriter[*entity.Message]
	// NOTE(review): rootTokenCollector has no With* setter in this file and is
	// not read here — presumably set/consumed elsewhere; confirm before removal.
	rootTokenCollector *execute.TokenCollector
}

// WorkflowRunnerOption mutates workflowRunOptions; see the With* constructors.
type WorkflowRunnerOption func(*workflowRunOptions)
// WithInput sets the serialized input payload for the run.
func WithInput(input string) WorkflowRunnerOption {
	return func(o *workflowRunOptions) { o.input = input }
}

// WithResumeReq sets the resume request when resuming an interrupted execution.
func WithResumeReq(resumeReq *entity.ResumeRequest) WorkflowRunnerOption {
	return func(o *workflowRunOptions) { o.resumeReq = resumeReq }
}

// WithStreamWriter sets the stream writer that receives messages during the run.
func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOption {
	return func(o *workflowRunOptions) { o.streamWriter = sw }
}
// NewWorkflowRunner assembles a WorkflowRunner from the workflow basics, its
// schema, the execution config, and any optional settings.
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
	o := &workflowRunOptions{}
	for _, apply := range opts {
		apply(o)
	}
	return &WorkflowRunner{
		basic:        b,
		schema:       sc,
		config:       config,
		input:        o.input,
		resumeReq:    o.resumeReq,
		streamWriter: o.streamWriter,
	}
}
// Prepare readies an execution (fresh run or resume): it allocates or reuses
// the execute ID, validates and rewrites the interrupt event when resuming,
// builds the compose options (including nested state modifiers targeting the
// interrupted node), records the execution in the repository for fresh runs,
// and starts the event-handling goroutine.
// Returns the (possibly cancel/timeout-wrapped) context, the execute ID, the
// compose options, and a channel delivering the final execute event.
func (r *WorkflowRunner) Prepare(ctx context.Context) (
	context.Context,
	int64,
	[]einoCompose.Option,
	<-chan *execute.Event,
	error,
) {
	var (
		err       error
		executeID int64
		repo      = wf.GetRepository()
		resumeReq = r.resumeReq
		wb        = r.basic
		sc        = r.schema
		sw        = r.streamWriter
		config    = r.config
	)
	// fresh runs get a new execute ID; resumes reuse the original one
	if r.resumeReq == nil {
		executeID, err = repo.GenID(ctx)
		if err != nil {
			return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
		}
	} else {
		executeID = resumeReq.ExecuteID
	}
	eventChan := make(chan *execute.Event)
	var (
		interruptEvent *entity.InterruptEvent
		found          bool
	)
	if resumeReq != nil {
		// the resume request must match the first stored interrupt event
		interruptEvent, found, err = repo.GetFirstInterruptEvent(ctx, executeID)
		if err != nil {
			return ctx, 0, nil, nil, err
		}
		if !found {
			return ctx, 0, nil, nil, fmt.Errorf("interrupt event does not exist, id: %d", resumeReq.EventID)
		}
		if interruptEvent.ID != resumeReq.EventID {
			return ctx, 0, nil, nil, fmt.Errorf("interrupt event id mismatch, expect: %d, actual: %d", resumeReq.EventID, interruptEvent.ID)
		}
	}
	r.executeID = executeID
	r.eventChan = eventChan
	r.interruptEvent = interruptEvent
	ctx, composeOpts, err := r.designateOptions(ctx)
	if err != nil {
		return ctx, 0, nil, nil, err
	}
	if interruptEvent != nil {
		// build a state modifier option targeted at the interrupted node, wrapping
		// it once per level of NodePath (composite-node index segments use the
		// InterruptEventIndexPrefix convention and consume the following segment).
		var stateOpt einoCompose.Option
		stateModifier := GenStateModifierByEventType(interruptEvent.EventType,
			interruptEvent.NodeKey, resumeReq.ResumeData, r.config)
		if len(interruptEvent.NodePath) == 1 {
			// this interrupt event is within the top level workflow
			stateOpt = einoCompose.WithStateModifier(stateModifier)
		} else {
			currentI := len(interruptEvent.NodePath) - 2
			path := interruptEvent.NodePath[currentI]
			if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
				// this interrupt event is within a composite node
				indexStr := path[len(execute.InterruptEventIndexPrefix):]
				index, err := strconv.Atoi(indexStr)
				if err != nil {
					return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
				}
				currentI--
				parentNodeKey := interruptEvent.NodePath[currentI]
				stateOpt = einoCompose.WithLambdaOption(
					nodes.WithResumeIndex(index, stateModifier)).DesignateNode(parentNodeKey)
			} else { // this interrupt event is within a sub workflow
				subWorkflowNodeKey := interruptEvent.NodePath[currentI]
				stateOpt = einoCompose.WithLambdaOption(
					nodes.WithResumeIndex(0, stateModifier)).DesignateNode(subWorkflowNodeKey)
			}
			// wrap the option outward through every remaining ancestor level
			for i := currentI - 1; i >= 0; i-- {
				path := interruptEvent.NodePath[i]
				if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
					indexStr := path[len(execute.InterruptEventIndexPrefix):]
					index, err := strconv.Atoi(indexStr)
					if err != nil {
						return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
					}
					i--
					parentNodeKey := interruptEvent.NodePath[i]
					stateOpt = WrapOptWithIndex(stateOpt, vo.NodeKey(parentNodeKey), index)
				} else {
					stateOpt = WrapOpt(stateOpt, vo.NodeKey(path))
				}
			}
		}
		composeOpts = append(composeOpts, stateOpt)
		// question-answer interrupts (direct or behind an LLM tool call) persist the
		// resume data into the stored interrupt event
		if interruptEvent.EventType == entity.InterruptEventQuestion {
			modifiedData, err := qa.AppendInterruptData(interruptEvent.InterruptData, resumeReq.ResumeData)
			if err != nil {
				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data: %w", err)
			}
			interruptEvent.InterruptData = modifiedData
			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
			}
		} else if interruptEvent.EventType == entity.InterruptEventLLM &&
			interruptEvent.ToolInterruptEvent.EventType == entity.InterruptEventQuestion {
			modifiedData, err := qa.AppendInterruptData(interruptEvent.ToolInterruptEvent.InterruptData, resumeReq.ResumeData)
			if err != nil {
				return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data for LLM node: %w", err)
			}
			interruptEvent.ToolInterruptEvent.InterruptData = modifiedData
			if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
				return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
			}
		}
		// take the execution lock so only one resume proceeds
		success, currentStatus, err := repo.TryLockWorkflowExecution(ctx, executeID, resumeReq.EventID)
		if err != nil {
			return ctx, 0, nil, nil, fmt.Errorf("try lock workflow execution unexpected err: %w", err)
		}
		if !success {
			return ctx, 0, nil, nil, fmt.Errorf("workflow execution lock failed, current status is %v, executeID: %d", currentStatus, executeID)
		}
		logs.CtxInfof(ctx, "resuming with eventID: %d, executeID: %d, nodeKey: %s", interruptEvent.ID,
			executeID, interruptEvent.NodeKey)
	}
	if interruptEvent == nil {
		// fresh run: record the execution row before starting
		var logID string
		logID, _ = ctx.Value("log-id").(string)
		wfExec := &entity.WorkflowExecution{
			ID:                     executeID,
			WorkflowID:             wb.ID,
			Version:                wb.Version,
			SpaceID:                wb.SpaceID,
			ExecuteConfig:          config,
			Status:                 entity.WorkflowRunning,
			Input:                  ptr.Of(r.input),
			RootExecutionID:        executeID,
			NodeCount:              sc.NodeCount(),
			CurrentResumingEventID: ptr.Of(int64(0)),
			CommitID:               wb.CommitID,
			LogID:                  logID,
		}
		if err = repo.CreateWorkflowExecution(ctx, wfExec); err != nil {
			return ctx, 0, nil, nil, err
		}
	}
	cancelCtx, cancelFn := context.WithCancel(ctx)
	var timeoutFn context.CancelFunc
	if s := execute.GetStaticConfig(); s != nil {
		// background and foreground tasks have different configured timeouts
		timeout := ternary.IFElse(config.TaskType == vo.TaskTypeBackground, s.BackgroundRunTimeout, s.ForegroundRunTimeout)
		if timeout > 0 {
			cancelCtx, timeoutFn = context.WithTimeout(cancelCtx, timeout)
		}
	}
	cancelCtx = execute.InitExecutedNodesCounter(cancelCtx)
	lastEventChan := make(chan *execute.Event, 1)
	go func() {
		defer func() {
			if panicErr := recover(); panicErr != nil {
				logs.CtxErrorf(ctx, "panic when handling execute event: %v", safego.NewPanicErr(panicErr, debug.Stack()))
			}
		}()
		defer func() {
			if sw != nil {
				sw.Close()
			}
		}()
		// this goroutine should not use the cancelCtx because it needs to be alive to receive workflow cancel events
		lastEventChan <- execute.HandleExecuteEvent(ctx, executeID, eventChan, cancelFn, timeoutFn,
			repo, sw, config)
		close(lastEventChan)
	}()
	return cancelCtx, executeID, composeOpts, lastEventChan, nil
}

View File

@@ -0,0 +1,336 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"maps"
"reflect"
"sync"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// WorkflowSchema is the serializable description of a workflow graph: its
// nodes, the connections between them, and the parent/child hierarchy of
// composite nodes. The unexported fields are caches computed once in Init.
type WorkflowSchema struct {
	Nodes          []*NodeSchema             `json:"nodes"`
	Connections    []*Connection             `json:"connections"`
	Hierarchy      map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"`       // child node key-> parent node key
	GeneratedNodes []vo.NodeKey              `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode

	nodeMap           map[vo.NodeKey]*NodeSchema // won't serialize this
	compositeNodes    []*CompositeNode           // won't serialize this
	requireCheckPoint bool                       // won't serialize this
	requireStreaming  bool                       // computed in Init via doRequireStreaming
	once              sync.Once                  // guards one-time initialization in Init
}
// Connection is a directed edge between two nodes; FromPort, when set,
// identifies which output port (branch) of the source node the edge leaves from.
type Connection struct {
	FromNode vo.NodeKey `json:"from_node"`
	ToNode   vo.NodeKey `json:"to_node"`
	FromPort *string    `json:"from_port,omitempty"`
}
// ID returns a stable string identifier of the edge, including the source port
// when one is present.
func (c *Connection) ID() string {
	if c.FromPort == nil {
		return fmt.Sprintf("%v:%v", c.FromNode, c.ToNode)
	}
	return fmt.Sprintf("%s:%s:%v", c.FromNode, c.ToNode, *c.FromPort)
}
// CompositeNode pairs a parent node with the child nodes that make up its
// inner workflow.
type CompositeNode struct {
	Parent   *NodeSchema
	Children []*NodeSchema
}
// Init performs one-time initialization (guarded by sync.Once): builds the
// key->schema map, caches the composite nodes, and computes whether any node
// requires checkpointing or streaming.
// Fix vs. previous version: the result of doGetCompositeNodes() was discarded,
// wasting the computation and leaving GetCompositeNodes to recompute it lazily
// outside the sync.Once guard; the result is now cached here.
func (w *WorkflowSchema) Init() {
	w.once.Do(func() {
		w.nodeMap = make(map[vo.NodeKey]*NodeSchema, len(w.Nodes))
		for _, node := range w.Nodes {
			w.nodeMap[node.Key] = node
		}
		w.compositeNodes = w.doGetCompositeNodes()
		for _, node := range w.Nodes {
			if node.requireCheckpoint() {
				w.requireCheckPoint = true
				break
			}
		}
		w.requireStreaming = w.doRequireStreaming()
	})
}
// GetNode returns the schema of the node with the given key, or nil if absent.
func (w *WorkflowSchema) GetNode(key vo.NodeKey) *NodeSchema {
	return w.nodeMap[key]
}

// GetAllNodes returns the internal key->schema map of all nodes; callers must
// not mutate it.
func (w *WorkflowSchema) GetAllNodes() map[vo.NodeKey]*NodeSchema {
	return w.nodeMap // TODO: needs to calculate node count separately, considering batch mode nodes
}
// GetCompositeNodes returns the parent/children groupings derived from the
// schema's Hierarchy, computing and caching them on first use.
func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
	if w.compositeNodes != nil {
		return w.compositeNodes
	}
	w.compositeNodes = w.doGetCompositeNodes()
	return w.compositeNodes
}
// doGetCompositeNodes groups child node schemas under their parent according
// to the Hierarchy map, skipping entries whose parent or child key does not
// resolve in the node map. Returns nil when there is no hierarchy.
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
	if w.Hierarchy == nil {
		return nil
	}

	// Bucket resolvable children under their parent's key.
	childrenByParent := make(map[vo.NodeKey][]*NodeSchema)
	for childKey, parentKey := range w.Hierarchy {
		parent := w.nodeMap[parentKey]
		child := w.nodeMap[childKey]
		if parent == nil || child == nil {
			continue
		}
		childrenByParent[parentKey] = append(childrenByParent[parentKey], child)
	}

	// Materialize one CompositeNode per parent that has children.
	for parentKey, children := range childrenByParent {
		parent := w.nodeMap[parentKey]
		if parent == nil {
			continue
		}
		cNodes = append(cNodes, &CompositeNode{Parent: parent, Children: children})
	}
	return cNodes
}
// isInSameWorkflow reports whether the two nodes live at the same level of
// the hierarchy n: both are top-level, or both share the same parent.
// A nil hierarchy means the workflow has a single level, so any two nodes
// co-exist in it.
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
	if n == nil {
		return true
	}

	myParent, iHaveParent := n[nodeKey]
	theirParent, theyHaveParent := n[otherNodeKey]

	switch {
	case !iHaveParent && !theyHaveParent:
		return true // both top-level
	case iHaveParent != theyHaveParent:
		return false // one nested, one top-level
	default:
		return myParent == theirParent
	}
}
// isBelowOneLevel reports whether nodeKey is nested (has a parent in n)
// while otherNodeKey is top-level (has none).
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
	if n == nil {
		return false
	}
	_, nested := n[nodeKey]
	_, otherNested := n[otherNodeKey]
	return nested && !otherNested
}
// isParentOf reports whether nodeKey is the direct parent of otherNodeKey
// according to the hierarchy n.
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
	if n == nil {
		return false
	}
	parent, ok := n[otherNodeKey]
	return ok && parent == nodeKey
}
// IsEqual reports whether two workflow schemas are semantically equivalent:
// the same set of connections (by ID), the same hierarchy, and the same set
// of nodes compared field by field. Derived caches (nodeMap, compositeNodes,
// flags) are ignored.
func (w *WorkflowSchema) IsEqual(other *WorkflowSchema) bool {
	// connSet builds the set of connection identities for comparison.
	connSet := func(conns []*Connection) map[string]bool {
		set := make(map[string]bool, len(conns))
		for _, c := range conns {
			set[c.ID()] = true
		}
		return set
	}
	if !maps.Equal(connSet(w.Connections), connSet(other.Connections)) {
		return false
	}

	// Fix: the hierarchy (parent/child nesting) is part of the schema's shape
	// and must match too; it was previously not compared at all.
	if !maps.Equal(w.Hierarchy, other.Hierarchy) {
		return false
	}

	nodeSet := func(ns []*NodeSchema) map[vo.NodeKey]*NodeSchema {
		set := make(map[vo.NodeKey]*NodeSchema, len(ns))
		for _, n := range ns {
			set[n.Key] = n
		}
		return set
	}
	return maps.EqualFunc(nodeSet(w.Nodes), nodeSet(other.Nodes), func(a, b *NodeSchema) bool {
		// Fix: also compare the node Type; previously two nodes with the same
		// key and name but different types were considered equal.
		return a.Type == b.Type &&
			a.Name == b.Name &&
			reflect.DeepEqual(a.Configs, b.Configs) &&
			reflect.DeepEqual(a.InputTypes, b.InputTypes) &&
			reflect.DeepEqual(a.InputSources, b.InputSources) &&
			reflect.DeepEqual(a.OutputTypes, b.OutputTypes) &&
			reflect.DeepEqual(a.OutputSources, b.OutputSources) &&
			reflect.DeepEqual(a.ExceptionConfigs, b.ExceptionConfigs) &&
			reflect.DeepEqual(a.SubWorkflowBasic, b.SubWorkflowBasic)
	})
}
// NodeCount returns the number of user-authored nodes: nodes generated
// automatically for batch mode are excluded from the total.
func (w *WorkflowSchema) NodeCount() int32 {
	return int32(len(w.Nodes) - len(w.GeneratedNodes))
}
// doRequireStreaming reports whether streaming can actually occur in this
// workflow: there must be at least one node configured to produce a stream,
// at least one configured to consume one, and a data-flow path (via
// InputSources references) from some producer to some consumer.
func (w *WorkflowSchema) doRequireStreaming() bool {
	// Classify nodes by their declared streaming capability and per-node
	// stream configuration.
	producers := make(map[vo.NodeKey]bool)
	consumers := make(map[vo.NodeKey]bool)
	for _, node := range w.Nodes {
		meta := entity.NodeMetaByNodeType(node.Type)
		if meta != nil {
			sps := meta.ExecutableMeta.StreamingParadigms
			if _, ok := sps[entity.Stream]; ok {
				// Producer: supports the Stream paradigm and is configured
				// to generate a stream.
				if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
					producers[node.Key] = true
				}
			}
			if sps[entity.Transform] || sps[entity.Collect] {
				// Consumer: supports Transform/Collect and is configured to
				// require streaming input.
				if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
					consumers[node.Key] = true
				}
			}
		}
	}

	// Without both a producer and a consumer, no stream can ever flow.
	if len(producers) == 0 || len(consumers) == 0 {
		return false
	}

	// Build data-flow graph from InputSources
	adj := make(map[vo.NodeKey]map[vo.NodeKey]struct{})
	for _, node := range w.Nodes {
		for _, source := range node.InputSources {
			if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
				if _, ok := adj[source.Source.Ref.FromNodeKey]; !ok {
					adj[source.Source.Ref.FromNodeKey] = make(map[vo.NodeKey]struct{})
				}
				adj[source.Source.Ref.FromNodeKey][node.Key] = struct{}{}
			}
		}
	}

	// For each producer, traverse the graph to see if it can reach a consumer
	for p := range producers {
		// Breadth-first search with a visited set to tolerate cycles.
		q := []vo.NodeKey{p}
		visited := make(map[vo.NodeKey]bool)
		visited[p] = true
		for len(q) > 0 {
			curr := q[0]
			q = q[1:]
			if consumers[curr] {
				return true
			}
			for neighbor := range adj[curr] {
				if !visited[neighbor] {
					visited[neighbor] = true
					q = append(q, neighbor)
				}
			}
		}
	}
	return false
}
// fanInMergeConfigs returns, keyed by node key, the fan-in merge settings for
// nodes that receive streams from multiple distinct predecessor nodes and
// whose node type is declared StreamSourceEOFAware. Returns nil when the
// workflow does not require streaming at all.
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
	// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
	// then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
	// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
	// then, look up the NodeTypeMeta's ExecutableMeta info and see if it requires fan-in stream merge.
	if !w.requireStreaming {
		return nil
	}

	fanInNodes := make(map[vo.NodeKey]bool)
	for _, node := range w.Nodes {
		meta := entity.NodeMetaByNodeType(node.Type)
		if meta == nil {
			continue
		}
		sps := meta.ExecutableMeta.StreamingParadigms
		if !sps[entity.Transform] && !sps[entity.Collect] {
			continue
		}
		if node.StreamConfigs == nil || !node.StreamConfigs.RequireStreamingInput {
			continue
		}

		var (
			firstPred    vo.NodeKey
			hasFirstPred bool
		)
		for _, source := range node.InputSources {
			ref := source.Source.Ref
			if ref == nil || len(ref.FromNodeKey) == 0 {
				continue
			}
			// Fix: only count DISTINCT predecessor nodes. Multiple input
			// fields referencing the same upstream node are a single
			// incoming edge, not a fan-in.
			if hasFirstPred && ref.FromNodeKey != firstPred {
				fanInNodes[node.Key] = true
				break
			}
			firstPred, hasFirstPred = ref.FromNodeKey, true
		}
	}

	// Of the fan-in nodes, only those whose type is StreamSourceEOFAware need
	// the source-EOF-aware merge behavior.
	fanInConfigs := make(map[string]compose.FanInMergeConfig)
	for nodeKey := range fanInNodes {
		if m := entity.NodeMetaByNodeType(w.GetNode(nodeKey).Type); m != nil && m.StreamSourceEOFAware {
			fanInConfigs[string(nodeKey)] = compose.FanInMergeConfig{
				StreamMergeWithSourceEOF: true,
			}
		}
	}
	return fanInConfigs
}

View File

@@ -0,0 +1,333 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strings"
"github.com/cloudwego/eino/components/tool"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// answerKey is the output-map key under which a workflow stores its final
// answer content when the terminate plan is "use answer content".
const answerKey = "output"
// invokableWorkflow adapts a compiled workflow into a non-streaming tool
// implementation: it holds the tool info, the invoke entry point, the
// terminate plan, the workflow entity, its schema, and the repository used
// to look up interrupt events.
type invokableWorkflow struct {
	info          *schema.ToolInfo
	invoke        func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
	terminatePlan vo.TerminatePlan
	wfEntity      *entity.Workflow
	sc            *WorkflowSchema
	repo          wf.Repository
}
// NewInvokableWorkflow wraps a workflow's invoke function and metadata as a
// non-streaming wf.ToolFromWorkflow.
func NewInvokableWorkflow(info *schema.ToolInfo,
	invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
	terminatePlan vo.TerminatePlan,
	wfEntity *entity.Workflow,
	sc *WorkflowSchema,
	repo wf.Repository,
) wf.ToolFromWorkflow {
	return &invokableWorkflow{
		info:          info,
		invoke:        invoke,
		terminatePlan: terminatePlan,
		wfEntity:      wfEntity,
		sc:            sc,
		repo:          repo,
	}
}
// Info returns the tool's static metadata; the context is unused.
func (i *invokableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
	return i.info, nil
}
// InvokableRun executes the underlying workflow as a non-streaming tool call.
//
// Flow:
//  1. Determine whether this call resumes a previously interrupted run for
//     the same tool call ID; if an interrupt exists but belongs to a
//     different execute ID, re-interrupt immediately instead of resuming.
//  2. For a fresh run, unmarshal the JSON arguments and convert them against
//     the entry node's declared output types (the workflow's input schema).
//  3. Prepare the workflow runner and invoke it; an interrupt from the
//     underlying graph is translated into an entity.ToolInterruptEvent.
//  4. Per the terminate plan, return either the marshalled output variables
//     or the plain answer content with the finished-marker suffix stripped.
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
	rInfo, allIEs := execute.GetResumeRequest(opts...)

	var (
		previouslyInterrupted bool
		callID                = einoCompose.GetToolCallID(ctx)
		previousExecuteID     int64
	)
	for interruptedCallID := range allIEs {
		if callID == interruptedCallID {
			previouslyInterrupted = true
			previousExecuteID = allIEs[interruptedCallID].ExecuteID
			break
		}
	}

	// NOTE(review): rInfo is dereferenced here without a nil check; this
	// assumes allIEs is only non-empty when a resume request is present —
	// TODO confirm against execute.GetResumeRequest's contract.
	if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
		logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
		return "", einoCompose.InterruptAndRerun
	}

	cfg := execute.GetExecuteConfig(opts...)

	var runOpts []WorkflowRunnerOption
	if rInfo != nil {
		runOpts = append(runOpts, WithResumeReq(rInfo))
	} else {
		runOpts = append(runOpts, WithInput(argumentsInJSON))
	}
	if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
		runOpts = append(runOpts, WithStreamWriter(sw))
	}

	var (
		cancelCtx context.Context
		executeID int64
		callOpts  []einoCompose.Option
		in        map[string]any
		err       error
		ws        *nodes.ConversionWarnings
	)
	if rInfo == nil { // fresh run: parse and validate the tool arguments
		if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
			return "", err
		}

		var entryNode *NodeSchema
		for _, node := range i.sc.Nodes {
			if node.Type == entity.NodeTypeEntry {
				entryNode = node
				break
			}
		}
		if entryNode == nil {
			panic("entry node not found in tool workflow")
		}

		// The entry node's OutputTypes describe the workflow's input schema;
		// conversion warnings are logged but do not fail the call.
		in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
		if err != nil {
			return "", err
		} else if ws != nil {
			logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
		}
	}

	cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx)
	if err != nil {
		return "", err
	}

	out, err := i.invoke(cancelCtx, in, callOpts...)
	if err != nil {
		// Surface a graph interrupt to the caller as a ToolInterruptEvent
		// carrying the first stored interrupt event for this execution.
		if _, ok := einoCompose.ExtractInterruptInfo(err); ok {
			firstIE, found, err := i.repo.GetFirstInterruptEvent(ctx, executeID)
			if err != nil {
				return "", err
			}
			if !found {
				return "", fmt.Errorf("interrupt event does not exist, wfExeID: %d", executeID)
			}
			return "", einoCompose.NewInterruptAndRerunErr(&entity.ToolInterruptEvent{
				ToolCallID:     einoCompose.GetToolCallID(ctx),
				ToolName:       i.info.Name,
				ExecuteID:      executeID,
				InterruptEvent: firstIE,
			})
		}
		return "", err
	}

	if i.terminatePlan == vo.ReturnVariables {
		return sonic.MarshalString(out)
	}

	content, ok := out[answerKey]
	if !ok {
		return "", fmt.Errorf("no answer found when terminate plan is use answer content. out: %v", out)
	}
	contentStr, ok := content.(string)
	if !ok {
		return "", fmt.Errorf("answer content is not string. content: %v", content)
	}

	// TrimSuffix is already a no-op when the marker is absent, so the
	// previous HasSuffix pre-check was redundant.
	return strings.TrimSuffix(contentStr, nodes.KeyIsFinished), nil
}
// TerminatePlan reports how the workflow ends: returning variables or
// returning answer content.
func (i *invokableWorkflow) TerminatePlan() vo.TerminatePlan {
	return i.terminatePlan
}
// GetWorkflow returns the workflow entity this tool wraps.
func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
	return i.wfEntity
}
// streamableWorkflow adapts a compiled workflow into a streaming tool
// implementation; it mirrors invokableWorkflow but exposes a stream entry
// point instead of invoke.
type streamableWorkflow struct {
	info          *schema.ToolInfo
	stream        func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
	terminatePlan vo.TerminatePlan
	wfEntity      *entity.Workflow
	sc            *WorkflowSchema
	repo          wf.Repository
}
// NewStreamableWorkflow wraps a workflow's stream function and metadata as a
// streaming wf.ToolFromWorkflow.
func NewStreamableWorkflow(info *schema.ToolInfo,
	stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
	terminatePlan vo.TerminatePlan,
	wfEntity *entity.Workflow,
	sc *WorkflowSchema,
	repo wf.Repository,
) wf.ToolFromWorkflow {
	return &streamableWorkflow{
		info:          info,
		stream:        stream,
		terminatePlan: terminatePlan,
		wfEntity:      wfEntity,
		sc:            sc,
		repo:          repo,
	}
}
// Info returns the tool's static metadata; the context is unused.
func (s *streamableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
	return s.info, nil
}
// StreamableRun executes the underlying workflow as a streaming tool call.
//
// Flow:
//  1. Determine whether this call resumes a previously interrupted run for
//     the same tool call ID; if an interrupt exists but belongs to a
//     different execute ID, re-interrupt immediately instead of resuming.
//  2. For a fresh run, unmarshal the JSON arguments and convert them against
//     the entry node's declared output types (the workflow's input schema).
//  3. Prepare the workflow runner and start the stream; an interrupt from
//     the underlying graph is translated into an entity.ToolInterruptEvent.
//  4. Convert each streamed output map into its answer-content string with
//     the finished-marker suffix stripped.
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) {
	rInfo, allIEs := execute.GetResumeRequest(opts...)

	var (
		previouslyInterrupted bool
		callID                = einoCompose.GetToolCallID(ctx)
		previousExecuteID     int64
	)
	for interruptedCallID := range allIEs {
		if callID == interruptedCallID {
			previouslyInterrupted = true
			previousExecuteID = allIEs[interruptedCallID].ExecuteID
			break
		}
	}

	// NOTE(review): rInfo is dereferenced here without a nil check; this
	// assumes allIEs is only non-empty when a resume request is present —
	// TODO confirm against execute.GetResumeRequest's contract.
	if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
		logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
		return nil, einoCompose.InterruptAndRerun
	}

	cfg := execute.GetExecuteConfig(opts...)

	var runOpts []WorkflowRunnerOption
	if rInfo != nil {
		runOpts = append(runOpts, WithResumeReq(rInfo))
	} else {
		runOpts = append(runOpts, WithInput(argumentsInJSON))
	}
	if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
		runOpts = append(runOpts, WithStreamWriter(sw))
	}

	var (
		cancelCtx context.Context
		executeID int64
		callOpts  []einoCompose.Option
		in        map[string]any
		err       error
		ws        *nodes.ConversionWarnings
	)
	if rInfo == nil { // fresh run: parse and validate the tool arguments
		if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
			return nil, err
		}

		var entryNode *NodeSchema
		for _, node := range s.sc.Nodes {
			if node.Type == entity.NodeTypeEntry {
				entryNode = node
				break
			}
		}
		if entryNode == nil {
			panic("entry node not found in tool workflow")
		}

		// The entry node's OutputTypes describe the workflow's input schema;
		// conversion warnings are logged but do not fail the call.
		in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
		if err != nil {
			return nil, err
		} else if ws != nil {
			logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
		}
	}

	cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx)
	if err != nil {
		return nil, err
	}

	outStream, err := s.stream(cancelCtx, in, callOpts...)
	if err != nil {
		// Surface a graph interrupt to the caller as a ToolInterruptEvent
		// carrying the first stored interrupt event for this execution.
		if _, ok := einoCompose.ExtractInterruptInfo(err); ok {
			firstIE, found, err := s.repo.GetFirstInterruptEvent(ctx, executeID)
			if err != nil {
				return nil, err
			}
			if !found {
				return nil, fmt.Errorf("interrupt event does not exist, wfExeID: %d", executeID)
			}
			return nil, einoCompose.NewInterruptAndRerunErr(&entity.ToolInterruptEvent{
				ToolCallID:     einoCompose.GetToolCallID(ctx),
				ToolName:       s.info.Name,
				ExecuteID:      executeID,
				InterruptEvent: firstIE,
			})
		}
		return nil, err
	}

	return schema.StreamReaderWithConvert(outStream, func(in map[string]any) (string, error) {
		// Consistency fix: use the shared answerKey constant instead of a
		// duplicated "output" literal (same runtime value).
		content, ok := in[answerKey]
		if !ok {
			return "", fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
		}
		contentStr, ok := content.(string)
		if !ok {
			return "", fmt.Errorf("output content is not string. content: %v", content)
		}
		// TrimSuffix is already a no-op when the marker is absent, so the
		// previous HasSuffix pre-check was redundant.
		return strings.TrimSuffix(contentStr, nodes.KeyIsFinished), nil
	}), nil
}
// TerminatePlan reports how the workflow ends: returning variables or
// returning answer content.
func (s *streamableWorkflow) TerminatePlan() vo.TerminatePlan {
	return s.terminatePlan
}
// GetWorkflow returns the workflow entity this tool wraps.
func (s *streamableWorkflow) GetWorkflow() *entity.Workflow {
	return s.wfEntity
}