refactor: how to add a node type in workflow (#558)

shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -1,181 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
func (s *NodeSchema) OutputPortCount() (int, bool) {
var hasExceptionPort bool
if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
hasExceptionPort = true
}
switch s.Type {
case entity.NodeTypeSelector:
return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
case entity.NodeTypeQuestionAnswer:
if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
} else {
return 2, hasExceptionPort
}
}
return 1, hasExceptionPort
case entity.NodeTypeIntentDetector:
intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
return len(intents) + 1, hasExceptionPort
default:
return 1, hasExceptionPort
}
}
type BranchMapping struct {
Normal []map[string]bool
Exception map[string]bool
}
const (
DefaultBranch = "default"
BranchFmt = "branch_%d"
)
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
if bMapping == nil {
return nil, errors.New("no branch mapping")
}
endNodes := make(map[string]bool)
for i := range bMapping.Normal {
for k := range bMapping.Normal[i] {
endNodes[k] = true
}
}
if bMapping.Exception != nil {
for k := range bMapping.Exception {
endNodes[k] = true
}
}
switch s.Type {
case entity.NodeTypeSelector:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
choice := in[selector.SelectKey].(int)
if choice < 0 || choice > len(bMapping.Normal) {
return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
}
choices := make(map[string]bool, len((bMapping.Normal)[choice]))
for k := range (bMapping.Normal)[choice] {
choices[k] = true
}
return choices, nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
case entity.NodeTypeQuestionAnswer:
conf := s.Configs.(map[string]any)
if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
if !ok {
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
}
if optionID.(string) == "other" {
return (bMapping.Normal)[len(bMapping.Normal)-1], nil
}
if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0
return (bMapping.Normal)[0], nil
}
optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
if !ok {
return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
}
if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
}
return (bMapping.Normal)[optionIDInt], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}
return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
case entity.NodeTypeIntentDetector:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bMapping.Exception, nil
}
classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
if !ok {
return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
}
// The intent detector node's default branch uses classificationId=0, but the current implementation stores the default branch as the last element of the array.
// Therefore, when classificationId=0 it is mapped to the last index of the array,
// and all other classification ids have their index reduced by 1.
id, err := cast.ToInt64E(classificationId)
if err != nil {
return nil, err
}
realID := id - 1
if realID >= int64(len(bMapping.Normal)) {
return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
}
if realID < 0 {
realID = int64(len(bMapping.Normal)) - 1
}
return (bMapping.Normal)[realID], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
default:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bMapping.Exception, nil
}
return (bMapping.Normal)[0], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}
}
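For orientation, here is a minimal sketch (node keys and the selectorSchema variable are hypothetical) of the BranchMapping a two-clause selector would receive: one Normal entry per clause port, a final Normal entry for the default port, and an Exception entry when the exception branch is enabled.

// Hypothetical wiring for a selector with two clauses plus default and exception ports.
bm := &BranchMapping{
    Normal: []map[string]bool{
        {"node_a": true}, // branch_0: first clause matched
        {"node_b": true}, // branch_1: second clause matched
        {"node_c": true}, // default branch (last Normal entry)
    },
    Exception: map[string]bool{"node_err": true},
}
branch, err := selectorSchema.GetBranch(bm) // selectorSchema: an assumed *NodeSchema of type NodeTypeSelector
if err != nil {
    // handle error
}
// branch (*compose.GraphBranch) is then attached to the graph at the selector's key.
_ = branch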

View File

@@ -1,194 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
type selectorCallbackField struct {
Key string `json:"key"`
Type vo.DataType `json:"type"`
Value any `json:"value"`
}
type selectorCondition struct {
Left selectorCallbackField `json:"left"`
Operator vo.OperatorType `json:"operator"`
Right *selectorCallbackField `json:"right,omitempty"`
}
type selectorBranch struct {
Conditions []*selectorCondition `json:"conditions"`
Logic vo.LogicType `json:"logic"`
Name string `json:"name"`
}
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
return func(_ context.Context, in map[string]any) (map[string]any, error) {
config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
count := len(config)
output := make([]*selectorBranch, count)
for _, source := range s.InputSources {
targetPath := source.Path
if len(targetPath) == 2 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: []*selectorCondition{
{
Operator: config[index].Single.ToCanvasOperatorType(),
},
},
Logic: selector.ClauseRelationAND.ToVOLogicType(),
}
}
if targetPath[1] == selector.LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := sc.Hierarchy[s.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
}
parentNode := sc.GetNode(parentNodeKey)
output[index].Conditions[0].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: "",
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else if targetPath[1] == selector.RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[0].Right = &selectorCallbackField{
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: rightV,
}
}
} else if len(targetPath) == 3 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
multi := config[index].Multi
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: make([]*selectorCondition, len(multi.Clauses)),
Logic: multi.Relation.ToVOLogicType(),
}
}
clauseIndexStr := targetPath[1]
clauseIndex, err := strconv.Atoi(clauseIndexStr)
if err != nil {
return nil, err
}
clause := multi.Clauses[clauseIndex]
if output[index].Conditions[clauseIndex] == nil {
output[index].Conditions[clauseIndex] = &selectorCondition{
Operator: clause.ToCanvasOperatorType(),
}
}
if targetPath[2] == selector.LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := sc.Hierarchy[s.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
}
parentNode := sc.GetNode(parentNodeKey)
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: "",
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else if targetPath[2] == selector.RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: rightV,
}
}
}
}
return map[string]any{"branches": output}, nil
}
}
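To make the converter's result easier to picture, here is a hedged sketch of the value it returns for a selector whose first branch is a single clause; the key and values are made up, and only the shape under the "branches" key matters.

// Illustrative only: shape of the callback input built above for one single-clause branch.
callbackInput := map[string]any{
    "branches": []*selectorBranch{
        {
            Logic: selector.ClauseRelationAND.ToVOLogicType(),
            Conditions: []*selectorCondition{
                {
                    // Key is "<source node name>.<field path>" for node references, "" for
                    // parent-intermediate variables; Type mirrors the input's vo.DataType.
                    Left:  selectorCallbackField{Key: "NodeA.score", Value: 42},
                    Right: &selectorCallbackField{Value: 42},
                    // Operator comes from config[index].Single.ToCanvasOperatorType().
                },
            },
        },
    },
}
_ = callbackInput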

View File

@@ -31,7 +31,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
@@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
rootHandler := execute.NewRootWorkflowHandler(
wb,
executeID,
workflowSC.requireCheckPoint,
workflowSC.RequireCheckpoint(),
eventChan,
resumedEvent,
exeCfg,
@@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
var nodeOpt einoCompose.Option
if ns.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
ptr.Of(ns.Configs.(*exit.Config).TerminatePlan))
} else if ns.Type != entity.NodeTypeLambda {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
}
@@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
}
}
if workflowSC.requireCheckPoint {
if workflowSC.RequireCheckpoint() {
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
}
@@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
parentHandler *execute.WorkflowHandler,
ns *NodeSchema,
ns *schema2.NodeSchema,
pathPrefix ...string) (opts []einoCompose.Option, err error) {
var (
resumeEvent = r.interruptEvent
@@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
var nodeOpt einoCompose.Option
if subNS.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan))
} else {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
}
@@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
return opts, nil
}
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
sw *schema.StreamWriter[*entity.Message]) (
opts []einoCompose.Option, err error) {
// this is an LLM node.
@@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
}
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
cfg := ns.Configs.(*llm.Config)
fcParams := cfg.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
// TODO: try to avoid getting the workflow tool all over again
@@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
}
@@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
}
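In the new schema, ns.Configs holds the node package's typed config struct instead of a string-keyed map; a brief sketch of the access pattern used above (exit.Config.TerminatePlan and llm.Config.FCParam are the fields this diff relies on):

// Old schema: settings were pulled from a map[string]any by string key, e.g.
//   plan := mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
// New schema: ns.Configs holds the node package's own config struct, so the lookup
// becomes a type assertion plus a field access. For an Exit node:
plan := ns.Configs.(*exit.Config).TerminatePlan
_ = plan
// and likewise ns.Configs.(*llm.Config).FCParam for an LLM node.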

View File

@@ -25,12 +25,13 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// outputValueFiller fills the output value with nil if the key is not present in the output map.
// If a node emits a stream as output, the node needs to handle these absent keys in the stream itself.
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) {
if len(s.OutputTypes) == 0 {
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
return output, nil
@@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st
// inputValueFiller fills the input value with a default value (zero or nil) if it is not present in the map.
// If a node accepts a stream as input, the node needs to handle these absent keys in the stream itself.
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) {
if len(s.InputTypes) == 0 {
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
return input, nil
@@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri
}
}
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context,
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
newI := make(map[string]any)

View File

@@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNodeSchema_OutputValueFiller(t *testing.T) {
@@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &NodeSchema{
s := &schema.NodeSchema{
OutputTypes: tt.fields.Outputs,
}
got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
got, err := outputValueFiller(s)(context.Background(), tt.fields.In)
if len(tt.wantErr) > 0 {
assert.Error(t, err)

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Node struct {
Lambda *compose.Lambda
}
// New instantiates the actual node type from NodeSchema.
func New(ctx context.Context, s *schema.NodeSchema,
inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node
sc *schema.WorkflowSchema, // the workflow this NodeSchema is in
deps *dependencyInfo, // the dependencies for this node, pre-calculated by the workflow engine
) (_ *Node, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
}
}()
var fullSources map[string]*schema.SourceInfo
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
if fullSources, err = GetFullSources(s, sc, deps); err != nil {
return nil, err
}
s.FullSources = fullSources
}
// if the NodeSchema's Configs implements NodeBuilder, use it to build the node
nb, ok := s.Configs.(schema.NodeBuilder)
if ok {
opts := []schema.BuildOption{
schema.WithWorkflowSchema(sc),
schema.WithInnerWorkflow(inner),
}
// build the actual InvokableNode, etc.
n, err := nb.Build(ctx, s, opts...)
if err != nil {
return nil, err
}
// wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda
return toNode(s, n), nil
}
switch s.Type {
case entity.NodeTypeLambda:
if s.Lambda == nil {
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
}
return &Node{Lambda: s.Lambda}, nil
case entity.NodeTypeSubWorkflow:
subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint())
if err != nil {
return nil, err
}
return toNode(s, subWorkflow), nil
default:
panic(fmt.Sprintf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type))
}
}
func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) {
var opts []WorkflowOption
opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID))
if requireCheckpoint {
opts = append(opts, WithParentRequireCheckpoint())
}
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
}
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
if err != nil {
return nil, err
}
return &subworkflow.SubWorkflow{
Runner: wf.Runner,
}, nil
}
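As a rough guide to the extension point above, here is a hedged sketch of what a new node package might provide so that New can build it through the NodeBuilder path. The package, type names, and fields are invented for illustration, and the Build signature is inferred from the nb.Build(ctx, s, opts...) call site rather than quoted from the actual schema.NodeBuilder definition.

// Hypothetical new node package, e.g. internal/nodes/greeter.
package greeter

import (
    "context"

    "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)

// Config is what would be placed in NodeSchema.Configs for the new node type.
type Config struct {
    Greeting string
}

// Build is assumed to satisfy schema.NodeBuilder; it mirrors the nb.Build(ctx, s, opts...)
// call in New above and returns the runnable node value handed to toNode.
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
    return &Greeter{greeting: c.Greeting}, nil
}

// Greeter is the runnable node.
type Greeter struct {
    greeting string
}

// Invoke is assumed to match nodes.InvokableNode, which toNode adapts into an eino
// Lambda via compose.InvokeWOOpt[map[string]any, map[string]any].
func (g *Greeter) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
    out := make(map[string]any, len(in)+1)
    for k, v := range in {
        out[k] = v
    }
    out["greeting"] = g.greeting
    return out, nil
}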

View File

@@ -33,6 +33,8 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct {
maxRetry int64
errProcessType vo.ErrorProcessType
dataOnErr func(ctx context.Context) map[string]any
callbackEnabled bool
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
streamPreProcessors []func(ctx context.Context,
@@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct {
init []func(context.Context) (context.Context, error)
i compose.Invoke[map[string]any, map[string]any, O]
s compose.Stream[map[string]any, map[string]any, O]
c compose.Collect[map[string]any, map[string]any, O]
t compose.Transform[map[string]any, map[string]any, O]
}
func newNodeRunConfig[O any](ns *NodeSchema,
func newNodeRunConfig[O any](ns *schema2.NodeSchema,
i compose.Invoke[map[string]any, map[string]any, O],
s compose.Stream[map[string]any, map[string]any, O],
c compose.Collect[map[string]any, map[string]any, O],
t compose.Transform[map[string]any, map[string]any, O],
opts *newNodeOptions) *nodeRunConfig[O] {
meta := entity.NodeMetaByNodeType(ns.Type)
@@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema,
keyFinishedMarkerTrimmer(),
}
if meta.PreFillZero {
preProcessors = append(preProcessors, ns.inputValueFiller())
preProcessors = append(preProcessors, inputValueFiller(ns))
}
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
if meta.PostFillNil {
postProcessors = append(postProcessors, ns.outputValueFiller())
postProcessors = append(postProcessors, outputValueFiller(ns))
}
streamPreProcessors := []func(ctx context.Context,
@@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema,
},
}
if meta.PreFillZero {
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns))
}
if meta.UseCtxCache {
opts.init = append([]func(ctx context.Context) (context.Context, error){
func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
},
}, opts.init...)
}
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
@@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema,
maxRetry: maxRetry,
errProcessType: errProcessType,
dataOnErr: dataOnErr,
callbackEnabled: meta.CallbackEnabled,
preProcessors: preProcessors,
postProcessors: postProcessors,
streamPreProcessors: streamPreProcessors,
@@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema,
init: opts.init,
i: i,
s: s,
c: c,
t: t,
}
}
func newNodeRunConfigWOOpt(ns *NodeSchema,
func newNodeRunConfigWOOpt(ns *schema2.NodeSchema,
i compose.InvokeWOOpt[map[string]any, map[string]any],
s compose.StreamWOOpt[map[string]any, map[string]any],
c compose.CollectWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any],
opts *newNodeOptions) *nodeRunConfig[any] {
var (
iWO compose.Invoke[map[string]any, map[string]any, any]
sWO compose.Stream[map[string]any, map[string]any, any]
cWO compose.Collect[map[string]any, map[string]any, any]
tWO compose.Transform[map[string]any, map[string]any, any]
)
@@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema,
}
}
if c != nil {
cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) {
return c(ctx, in)
}
}
if t != nil {
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
return t(ctx, input)
}
}
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts)
return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts)
}
type newNodeOptions struct {
@@ -180,57 +199,100 @@ type newNodeOptions struct {
init []func(context.Context) (context.Context, error)
}
type newNodeOption func(*newNodeOptions)
func toNode(ns *schema2.NodeSchema, r any) *Node {
iWOpt, _ := r.(nodes.InvokableNodeWOpt)
sWOpt, _ := r.(nodes.StreamableNodeWOpt)
cWOpt, _ := r.(nodes.CollectableNodeWOpt)
tWOpt, _ := r.(nodes.TransformableNodeWOpt)
iWOOpt, _ := r.(nodes.InvokableNode)
sWOOpt, _ := r.(nodes.StreamableNode)
cWOOpt, _ := r.(nodes.CollectableNode)
tWOOpt, _ := r.(nodes.TransformableNode)
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.callbackInputConverter = f
var wOpt, wOOpt bool
if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil {
wOpt = true
}
}
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.callbackOutputConverter = f
if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil {
wOOpt = true
}
if wOpt && wOOpt {
panic("a node's different streaming methods needs to be consistent: " +
"they should ALL have NodeOption or None should have them")
}
if !wOpt && !wOOpt {
panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode")
}
}
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.init = append(opts.init, f)
}
}
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
ci, ok := r.(nodes.CallbackInputConverted)
if ok {
options.callbackInputConverter = ci.ToCallbackInput
}
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode()
}
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
co, ok := r.(nodes.CallbackOutputConverted)
if ok {
options.callbackOutputConverter = co.ToCallbackOutput
}
return newNodeRunConfig(ns, i, nil, nil, options).toNode()
}
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
init, ok := r.(nodes.Initializer)
if ok {
options.init = append(options.init, init.Init)
}
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
}
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
if wOpt {
var (
i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption]
s compose.Stream[map[string]any, map[string]any, nodes.NodeOption]
c compose.Collect[map[string]any, map[string]any, nodes.NodeOption]
t compose.Transform[map[string]any, map[string]any, nodes.NodeOption]
)
if iWOpt != nil {
i = iWOpt.Invoke
}
if sWOpt != nil {
s = sWOpt.Stream
}
if cWOpt != nil {
c = cWOpt.Collect
}
if tWOpt != nil {
t = tWOpt.Transform
}
return newNodeRunConfig(ns, i, s, c, t, options).toNode()
}
return newNodeRunConfig(ns, i, s, nil, options).toNode()
var (
i compose.InvokeWOOpt[map[string]any, map[string]any]
s compose.StreamWOOpt[map[string]any, map[string]any]
c compose.CollectWOOpt[map[string]any, map[string]any]
t compose.TransformWOOpts[map[string]any, map[string]any]
)
if iWOOpt != nil {
i = iWOOpt.Invoke
}
if sWOOpt != nil {
s = sWOOpt.Stream
}
if cWOOpt != nil {
c = cWOOpt.Collect
}
if tWOOpt != nil {
t = tWOOpt.Transform
}
return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode()
}
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
@@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.
func (nc *nodeRunConfig[O]) toNode() *Node {
var opts []compose.LambdaOpt
opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
opts = append(opts, compose.WithLambdaCallbackEnable(true))
if nc.callbackEnabled {
opts = append(opts, compose.WithLambdaCallbackEnable(true))
}
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
if err != nil {
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
@@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C
}
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
if !r.callbackEnabled {
return ctx, nil
}
if r.callbackInputConverter != nil {
convertedInput, err := r.callbackInputConverter(ctx, input)
if err != nil {
@@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
context.Context, *schema.StreamReader[map[string]any], error) {
if !r.callbackEnabled {
return ctx, input, nil
}
if r.callbackInputConverter != nil {
copied := input.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
@@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
}
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
output["isSuccess"] = true
}
if !r.callbackEnabled {
return nil
}
if r.callbackOutputConverter != nil {
convertedOutput, err := r.callbackOutputConverter(ctx, output)
if err != nil {
@@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
*schema.StreamReader[map[string]any], error) {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
}
if !r.callbackEnabled {
return output, nil
}
if r.callbackOutputConverter != nil {
copied := output.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
@@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
if r.interrupted {
if r.callbackEnabled {
_ = callbacks.OnError(ctx, err)
}
_ = callbacks.OnError(ctx, err)
return nil, false
}
@@ -653,22 +696,20 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
msg := sErr.Msg()
switch r.errProcessType {
case vo.ErrorProcessTypeDefault:
case vo.ErrorProcessTypeReturnDefaultData:
d := r.dataOnErr(ctx)
d["errorBody"] = map[string]any{
"errorMessage": msg,
"errorCode": code,
}
d["isSuccess"] = false
if r.callbackEnabled {
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: d,
RawOutput: d,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: d,
RawOutput: d,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
return d, true
case vo.ErrorProcessTypeExceptionBranch:
s := make(map[string]any)
@@ -677,20 +718,16 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
"errorCode": code,
}
s["isSuccess"] = false
if r.callbackEnabled {
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: s,
RawOutput: s,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: s,
RawOutput: s,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
return s, true
default:
if r.callbackEnabled {
_ = callbacks.OnError(ctx, sErr)
}
_ = callbacks.OnError(ctx, sErr)
return nil, false
}
}
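For completeness, a hedged sketch of the optional hooks toNode detects on the built node value; the receiver type is hypothetical, and the method signatures follow the converter and init function types that toNode assigns into newNodeOptions above.

// Hypothetical hooks on a node value, in addition to one of the required
// Invoke/Stream/Collect/Transform methods (see the greeter sketch earlier).
type echoHooks struct{}

// nodes.CallbackInputConverted: wired in as the callback input converter.
func (e *echoHooks) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
    return in, nil
}

// nodes.CallbackOutputConverted: used by the runner's onEnd/onEndStream when
// reporting node output.
func (e *echoHooks) ToCallbackOutput(_ context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
    return &nodes.StructuredCallbackOutput{Output: out, RawOutput: out}, nil
}

// nodes.Initializer: appended to the run config's init functions before execution.
func (e *echoHooks) Init(ctx context.Context) (context.Context, error) {
    return ctx, nil
}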

View File

@@ -1,580 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
)
type NodeSchema struct {
Key vo.NodeKey `json:"key"`
Name string `json:"name"`
Type entity.NodeType `json:"type"`
// Configs are node-specific configurations with pre-defined config keys and values.
// They do not participate in request-time field mapping, nor serve as the node's static values.
// In short, these Configs are INTERNAL to the node's implementation; the workflow layer is not aware of them.
Configs any `json:"configs,omitempty"`
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
Lambda *compose.Lambda // not serializable, used for internal test.
}
type ExceptionConfig struct {
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
DataOnErr string `json:"data_on_err,omitempty"` // data to return on error; effective when ProcessType is Default
}
type StreamConfig struct {
// whether this node can produce genuine streaming output.
// Does not include nodes that merely pass a stream through as they receive it.
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
// whether this node prioritizes streaming input over non-streaming input.
// Does not include nodes that can accept both and have no preference.
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
type Node struct {
Lambda *compose.Lambda
}
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
}
}()
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
return nil, err
}
}
switch s.Type {
case entity.NodeTypeLambda:
if s.Lambda == nil {
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
}
return &Node{Lambda: s.Lambda}, nil
case entity.NodeTypeLLM:
conf, err := s.ToLLMConfig(ctx)
if err != nil {
return nil, err
}
l, err := llm.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
case entity.NodeTypeSelector:
conf := s.ToSelectorConfig()
sl, err := selector.NewSelector(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
case entity.NodeTypeBatch:
if inner == nil {
return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
}
conf, err := s.ToBatchConfig(inner)
if err != nil {
return nil, err
}
b, err := batch.NewBatch(ctx, conf)
if err != nil {
return nil, err
}
return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
case entity.NodeTypeVariableAggregator:
conf, err := s.ToVariableAggregatorConfig()
if err != nil {
return nil, err
}
va, err := variableaggregator.NewVariableAggregator(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, va.Invoke, va.Transform,
withCallbackInputConverter(va.ToCallbackInput),
withCallbackOutputConverter(va.ToCallbackOutput),
withInit(va.Init)), nil
case entity.NodeTypeTextProcessor:
conf, err := s.ToTextProcessorConfig()
if err != nil {
return nil, err
}
tp, err := textprocessor.NewTextProcessor(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, tp.Invoke), nil
case entity.NodeTypeHTTPRequester:
conf, err := s.ToHTTPRequesterConfig()
if err != nil {
return nil, err
}
hr, err := httprequester.NewHTTPRequester(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
case entity.NodeTypeContinue:
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
return map[string]any{}, nil
}
return invokableNode(s, i), nil
case entity.NodeTypeBreak:
b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
if err != nil {
return nil, err
}
return invokableNode(s, b.DoBreak), nil
case entity.NodeTypeVariableAssigner:
handler := variable.GetVariableHandler()
conf, err := s.ToVariableAssignerConfig(handler)
if err != nil {
return nil, err
}
va, err := variableassigner.NewVariableAssigner(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, va.Assign), nil
case entity.NodeTypeVariableAssignerWithinLoop:
conf, err := s.ToVariableAssignerInLoopConfig()
if err != nil {
return nil, err
}
va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, va.Assign), nil
case entity.NodeTypeLoop:
conf, err := s.ToLoopConfig(inner)
if err != nil {
return nil, err
}
l, err := loop.NewLoop(ctx, conf)
if err != nil {
return nil, err
}
return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
case entity.NodeTypeQuestionAnswer:
conf, err := s.ToQAConfig(ctx)
if err != nil {
return nil, err
}
qA, err := qa.NewQuestionAnswer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
case entity.NodeTypeInputReceiver:
conf, err := s.ToInputReceiverConfig()
if err != nil {
return nil, err
}
inputR, err := receiver.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
case entity.NodeTypeOutputEmitter:
conf, err := s.ToOutputEmitterConfig(sc)
if err != nil {
return nil, err
}
e, err := emitter.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
case entity.NodeTypeEntry:
conf, err := s.ToEntryConfig(ctx)
if err != nil {
return nil, err
}
e, err := entry.NewEntry(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, e.Invoke), nil
case entity.NodeTypeExit:
terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
if terminalPlan == vo.ReturnVariables {
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
if in == nil {
return map[string]any{}, nil
}
return in, nil
}
return invokableNode(s, i), nil
}
conf, err := s.ToOutputEmitterConfig(sc)
if err != nil {
return nil, err
}
e, err := emitter.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
case entity.NodeTypeDatabaseCustomSQL:
conf, err := s.ToDatabaseCustomSQLConfig()
if err != nil {
return nil, err
}
sqlER, err := database.NewCustomSQL(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, sqlER.Execute), nil
case entity.NodeTypeDatabaseQuery:
conf, err := s.ToDatabaseQueryConfig()
if err != nil {
return nil, err
}
query, err := database.NewQuery(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
case entity.NodeTypeDatabaseInsert:
conf, err := s.ToDatabaseInsertConfig()
if err != nil {
return nil, err
}
insert, err := database.NewInsert(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
case entity.NodeTypeDatabaseUpdate:
conf, err := s.ToDatabaseUpdateConfig()
if err != nil {
return nil, err
}
update, err := database.NewUpdate(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
case entity.NodeTypeDatabaseDelete:
conf, err := s.ToDatabaseDeleteConfig()
if err != nil {
return nil, err
}
del, err := database.NewDelete(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
case entity.NodeTypeKnowledgeIndexer:
conf, err := s.ToKnowledgeIndexerConfig()
if err != nil {
return nil, err
}
w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, w.Store), nil
case entity.NodeTypeKnowledgeRetriever:
conf, err := s.ToKnowledgeRetrieveConfig()
if err != nil {
return nil, err
}
r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Retrieve), nil
case entity.NodeTypeKnowledgeDeleter:
conf, err := s.ToKnowledgeDeleterConfig()
if err != nil {
return nil, err
}
r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Delete), nil
case entity.NodeTypeCodeRunner:
conf, err := s.ToCodeRunnerConfig()
if err != nil {
return nil, err
}
r, err := code.NewCodeRunner(ctx, conf)
if err != nil {
return nil, err
}
initFn := func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
}
return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
case entity.NodeTypePlugin:
conf, err := s.ToPluginConfig()
if err != nil {
return nil, err
}
r, err := plugin.NewPlugin(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Invoke), nil
case entity.NodeTypeCreateConversation:
conf, err := s.ToCreateConversationConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewCreateConversation(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Create), nil
case entity.NodeTypeMessageList:
conf, err := s.ToMessageListConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewMessageList(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.List), nil
case entity.NodeTypeClearMessage:
conf, err := s.ToClearMessageConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewClearMessage(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Clear), nil
case entity.NodeTypeIntentDetector:
conf, err := s.ToIntentDetectorConfig(ctx)
if err != nil {
return nil, err
}
r, err := intentdetector.NewIntentDetector(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Invoke), nil
case entity.NodeTypeSubWorkflow:
conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
if err != nil {
return nil, err
}
r, err := subworkflow.NewSubWorkflow(ctx, conf)
if err != nil {
return nil, err
}
return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
case entity.NodeTypeJsonSerialization:
conf, err := s.ToJsonSerializationConfig()
if err != nil {
return nil, err
}
js, err := json.NewJsonSerializer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, js.Invoke), nil
case entity.NodeTypeJsonDeserialization:
conf, err := s.ToJsonDeserializationConfig()
if err != nil {
return nil, err
}
jd, err := json.NewJsonDeserializer(ctx, conf)
if err != nil {
return nil, err
}
initFn := func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
}
return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
default:
panic("not implemented")
}
}
func (s *NodeSchema) IsEnableUserQuery() bool {
if s == nil {
return false
}
if s.Type != entity.NodeTypeEntry {
return false
}
if len(s.OutputSources) == 0 {
return false
}
for _, source := range s.OutputSources {
fieldPath := source.Path
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
return true
}
}
return false
}
func (s *NodeSchema) IsEnableChatHistory() bool {
if s == nil {
return false
}
switch s.Type {
case entity.NodeTypeLLM:
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
return llmParam.EnableChatHistory
case entity.NodeTypeIntentDetector:
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
return llmParam.EnableChatHistory
default:
return false
}
}
func (s *NodeSchema) IsRefGlobalVariable() bool {
for _, source := range s.InputSources {
if source.IsRefGlobalVariable() {
return true
}
}
for _, source := range s.OutputSources {
if source.IsRefGlobalVariable() {
return true
}
}
return false
}
func (s *NodeSchema) requireCheckpoint() bool {
if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver {
return true
}
if s.Type == entity.NodeTypeLLM {
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
if fcParams != nil && fcParams.WorkflowFCParam != nil {
return true
}
}
if s.Type == entity.NodeTypeSubWorkflow {
s.SubWorkflowSchema.Init()
if s.SubWorkflowSchema.requireCheckPoint {
return true
}
}
return false
}

View File

@@ -32,8 +32,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -46,9 +48,9 @@ type State struct {
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
@@ -71,8 +73,8 @@ func init() {
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
@@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
return s.GroupChoices[nodeKey]
}
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
}
choice, ok := choices[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
}
if choice == -1 { // this group picks none of the elements
return nodes.FieldNotStream, nil
return schema2.FieldNotStream, nil
}
sInfos, ok := s.SourceInfos[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
}
groupInfo, ok := sInfos[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
}
if groupInfo.SubSources == nil {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
}
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
}
if subInfo.FieldType != nodes.FieldMaybeStream {
if subInfo.FieldType != schema2.FieldMaybeStream {
return subInfo.FieldType, nil
}
@@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
}
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
result := make(map[string]nodes.FieldStreamType)
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
result := make(map[string]schema2.FieldStreamType)
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return result, nil
@@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
ExecutedNodes: make(map[vo.NodeKey]bool),
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
GroupChoices: make(map[vo.NodeKey]map[string]int),
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
LLMToResumeData: make(map[vo.NodeKey]string),
@@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
}
}
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePreHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
@@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
return in, nil
})
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
state.Inputs[s.Key] = in
@@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
if len(handlers) > 0 || !stream {
handlerForVars := s.statePreHandlerForVars()
handlerForVars := statePreHandlerForVars(s)
if handlerForVars != nil {
handlers = append(handlers, handlerForVars)
}
@@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
if s.Type == entity.NodeTypeVariableAggregator {
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
state.SourceInfos[s.Key] = s.FullSources
return in, nil
})
}
handlerForVars := s.streamStatePreHandlerForVars()
handlerForVars := streamStatePreHandlerForVars(s)
if handlerForVars != nil {
streamHandlers = append(streamHandlers, handlerForVars)
}
@@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
return nil
}
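Illustrative call site (a sketch under assumed names, not from this commit): with the handlers now package-level functions that take the schema explicitly, a caller inside this package could collect them as graph options; ns and requiresStream are hypothetical:
opts := []compose.GraphAddNodeOpt{
statePreHandler(ns, requiresStream),
statePostHandler(ns, requiresStream),
}
_ = opts // passed along when the node is added to the eino graph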
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
// check the node's inputs: if any of them are variables, use the state's variableHandler to get the variables and set them on the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
}
}
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// check the node's inputs: if any of them are variables, get the variables and merge them into the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
}
}
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// if it does not have source info, do not add this pre handler
if s.Configs == nil {
return nil
@@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
return nil
case entity.NodeTypeExit:
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
if terminatePlan != vo.ReturnVariables {
return nil
}
@@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
// all other nodes can only accept non-stream inputs, relying on Eino's automatic stream concatenation.
}
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
sourceInfo := s.FullSources
if len(sourceInfo) == 0 {
return nil
}
@@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
var (
anyStream bool
checker func(source *nodes.SourceInfo) bool
checker func(source *schema2.SourceInfo) bool
)
checker = func(source *nodes.SourceInfo) bool {
if source.FieldType != nodes.FieldNotStream {
checker = func(source *schema2.SourceInfo) bool {
if source.FieldType != schema2.FieldNotStream {
return true
}
for _, subSource := range source.SubSources {
@@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
resolved := map[string]resolvedStreamSource{}
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
if source.IsIntermediate {
result = &resolvedStreamSource{
intermediate: true,
@@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
streamType := source.FieldType
if streamType == nodes.FieldMaybeStream {
if streamType == schema2.FieldMaybeStream {
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
if err != nil {
return nil, err
}
}
if streamType == nodes.FieldNotStream {
if streamType == schema2.FieldNotStream {
return nil, nil
}
@@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
}
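A data-shape sketch (illustration only, not from this commit) of what the recursive checker above walks: an intermediate source with a single leaf that may stream is enough to install the stream pre-handler; the values are hypothetical:
example := &schema2.SourceInfo{
IsIntermediate: true,
FieldType:      schema2.FieldNotStream,
SubSources: map[string]*schema2.SourceInfo{
"0": {FieldType: schema2.FieldMaybeStream, FromNodeKey: "llm_node_key"},
},
}
// checker(example) reports true because a nested FieldType is not FieldNotStream
_ = example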
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePostHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
@@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.streamStatePostHandlerForVars()
forVars := streamStatePostHandlerForVars(s)
if forVars != nil {
streamHandlers = append(streamHandlers, forVars)
}
@@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.statePostHandlerForVars()
forVars := statePostHandlerForVars(s)
if forVars != nil {
handlers = append(handlers, forVars)
}
@@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return compose.WithStatePostHandler(handler)
}
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
// check the node's output sources: if any of them are variables,
// use the state's variableHandler to get the variables and set them on the output
var vars []*vo.FieldInfo
@@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
}
}
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
// check the node's output sources: if any of them are variables, get the variables and merge them into the output
var vars []*vo.FieldInfo
for _, output := range s.OutputSources {


@@ -21,19 +21,20 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// SetFullSources calculates REAL input sources for a node.
// GetFullSources calculates REAL input sources for a node.
// It may differ from the NodeSchema's InputSources for the following reasons:
// 1. an inner node under a composite node may refer to a field from a node in its parent workflow;
// this reference is instead routed to and sourced from the inner workflow's start node.
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
// 3. also, some nodes may have implicit input sources not defined in their NodeSchema's InputSources.
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
fullSource := make(map[string]*nodes.SourceInfo)
func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) (
map[string]*schema.SourceInfo, error) {
fullSource := make(map[string]*schema.SourceInfo)
var fieldInfos []vo.FieldInfo
for _, s := range dep.staticValues {
fieldInfos = append(fieldInfos, vo.FieldInfo{
@@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
tInfo = tInfo.Properties[path[j]]
}
if current, ok := currentSource[path[j]]; !ok {
currentSource[path[j]] = &nodes.SourceInfo{
currentSource[path[j]] = &schema.SourceInfo{
IsIntermediate: true,
FieldType: nodes.FieldNotStream,
FieldType: schema.FieldNotStream,
TypeInfo: tInfo,
SubSources: make(map[string]*nodes.SourceInfo),
SubSources: make(map[string]*schema.SourceInfo),
}
} else if !current.IsIntermediate {
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
}
currentSource = currentSource[path[j]].SubSources
@@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
// static values or variables
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
currentSource[lastPath] = &nodes.SourceInfo{
currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false,
FieldType: nodes.FieldNotStream,
FieldType: schema.FieldNotStream,
TypeInfo: tInfo,
}
continue
@@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
fromNodeKey := fInfo.Source.Ref.FromNodeKey
var (
streamType nodes.FieldStreamType
streamType schema.FieldStreamType
err error
)
if len(fromNodeKey) > 0 {
if fromNodeKey == compose.START {
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
} else {
fromNode, ok := allNS[fromNodeKey]
if !ok {
return fmt.Errorf("node %s not found", fromNodeKey)
fromNode := sc.GetNode(fromNodeKey)
if fromNode == nil {
return nil, fmt.Errorf("node %s not found", fromNodeKey)
}
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil {
return err
return nil, err
}
}
}
currentSource[lastPath] = &nodes.SourceInfo{
currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false,
FieldType: streamType,
FromNodeKey: fromNodeKey,
@@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
}
}
s.Configs.(map[string]any)["FullSources"] = fullSource
return nil
}
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
if s.Type == entity.NodeTypeExit {
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
if len(path) == 1 && path[0] == "output" {
return nodes.FieldIsStream, nil
}
}
return nodes.FieldNotStream, nil
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
subSC := s.SubWorkflowSchema
subExit := subSC.GetNode(entity.ExitNodeKey)
subStreamType, err := subExit.IsStreamingField(path, nil)
if err != nil {
return nodes.FieldNotStream, err
}
return subStreamType, nil
} else if s.Type == entity.NodeTypeVariableAggregator {
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range s.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return nodes.FieldNotStream, nil
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode, ok := allNS[fromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range s.InputSources {
if fInfo.Path[0] == path[0] { // belongs to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
if err != nil {
return nodes.FieldNotStream, err
}
if subStreamType == nodes.FieldMaybeStream {
return nodes.FieldMaybeStream, nil
} else if subStreamType == nodes.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return nodes.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return nodes.FieldNotStream, nil
}
return nodes.FieldMaybeStream, nil
}
}
if s.Type != entity.NodeTypeLLM {
return nodes.FieldNotStream, nil
}
if len(path) != 1 {
return nodes.FieldNotStream, nil
}
outputs := s.OutputTypes
if len(outputs) != 1 && len(outputs) != 2 {
return nodes.FieldNotStream, nil
}
var outputKey string
for key, output := range outputs {
if output.Type != vo.DataTypeString {
return nodes.FieldNotStream, nil
}
if key != "reasoning_content" {
if len(outputKey) > 0 {
return nodes.FieldNotStream, nil
}
outputKey = key
}
}
field := path[0]
if field == "reasoning_content" || field == outputKey {
return nodes.FieldIsStream, nil
}
return nodes.FieldNotStream, nil
return fullSource, nil
}
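A sketch (illustration only, not from this commit) of the kind of entry GetFullSources produces; the field name, node key, path, and stream type below are hypothetical:
// a non-intermediate field sourced from an upstream node whose output may stream
fullSource["query"] = &schema.SourceInfo{
IsIntermediate: false,
FieldType:      schema.FieldMaybeStream,
FromNodeKey:    "llm_node_key",
FromPath:       compose.FieldPath{"output"},
}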


@@ -28,6 +28,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestBatch(t *testing.T) {
@@ -52,7 +55,7 @@ func TestBatch(t *testing.T) {
return in, nil
}
lambdaNode1 := &compose2.NodeSchema{
lambdaNode1 := &schema.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1),
@@ -86,7 +89,7 @@ func TestBatch(t *testing.T) {
},
},
}
lambdaNode2 := &compose2.NodeSchema{
lambdaNode2 := &schema.NodeSchema{
Key: "index",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2),
@@ -103,7 +106,7 @@ func TestBatch(t *testing.T) {
},
}
lambdaNode3 := &compose2.NodeSchema{
lambdaNode3 := &schema.NodeSchema{
Key: "consumer",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3),
@@ -135,23 +138,22 @@ func TestBatch(t *testing.T) {
},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
Key: "batch_node_key",
Type: entity.NodeTypeBatch,
ns := &schema.NodeSchema{
Key: "batch_node_key",
Type: entity.NodeTypeBatch,
Configs: &batch.Config{},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"array_1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_1"},
},
},
@@ -160,7 +162,7 @@ func TestBatch(t *testing.T) {
Path: compose.FieldPath{"array_2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_2"},
},
},
@@ -214,11 +216,11 @@ func TestBatch(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -246,18 +248,18 @@ func TestBatch(t *testing.T) {
return map[string]any{"success": true}, nil
}
parentLambdaNode := &compose2.NodeSchema{
parentLambdaNode := &schema.NodeSchema{
Key: "parent_predecessor_1",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(parentLambda),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
parentLambdaNode,
ns,
exit,
exitN,
lambdaNode1,
lambdaNode2,
lambdaNode3,
@@ -267,7 +269,7 @@ func TestBatch(t *testing.T) {
"index": "batch_node_key",
"consumer": "batch_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entity.EntryNodeKey,
ToNode: "parent_predecessor_1",


@@ -40,7 +40,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/internal/testutil"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
@@ -108,22 +112,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "{{sys_prompt}}",
"UserPrompt": "{{query}}",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "{{sys_prompt}}",
UserPrompt: "{{query}}",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@@ -132,7 +134,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"sys_prompt"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"sys_prompt"},
},
},
@@ -141,7 +143,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"query"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@@ -162,11 +164,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -181,20 +183,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -228,27 +230,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?",
"OutputFormat": llm.FormatJSON,
"IgnoreException": true,
"DefaultOutput": map[string]any{
"country_name": "unknown",
"area_size": int64(0),
},
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "what's the largest country in the world and it's area size in square kilometers?",
OutputFormat: llm.FormatJSON,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@@ -264,11 +259,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -292,20 +287,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -337,22 +332,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "list the top 5 largest countries in the world",
"OutputFormat": llm.FormatMarkdown,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "list the top 5 largest countries in the world",
OutputFormat: llm.FormatMarkdown,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@@ -363,11 +356,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -382,20 +375,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -456,22 +449,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
openaiNode := &compose2.NodeSchema{
openaiNode := &schema2.NodeSchema{
Key: "openai_llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "plan a 10 day family visit to China.",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "plan a 10 day family visit to China.",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@@ -482,14 +473,14 @@ func TestLLM(t *testing.T) {
},
}
deepseekNode := &compose2.NodeSchema{
deepseekNode := &schema2.NodeSchema{
Key: "deepseek_llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@@ -503,12 +494,11 @@ func TestLLM(t *testing.T) {
},
}
emitterNode := &compose2.NodeSchema{
emitterNode := &schema2.NodeSchema{
Key: "emitter_node_key",
Type: entity.NodeTypeOutputEmitter,
Configs: map[string]any{
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
"Mode": nodes.Streaming,
Configs: &emitter.Config{
Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
},
InputSources: []*vo.FieldInfo{
{
@@ -542,7 +532,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"inputObj"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"inputObj"},
},
},
@@ -551,7 +541,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"input2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"input2"},
},
},
@@ -559,11 +549,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.UseAnswerContent,
Configs: &exit.Config{
TerminatePlan: vo.UseAnswerContent,
},
InputSources: []*vo.FieldInfo{
{
@@ -596,17 +586,17 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
openaiNode,
deepseekNode,
emitterNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: openaiNode.Key,
},
{
@@ -614,7 +604,7 @@ func TestLLM(t *testing.T) {
ToNode: emitterNode.Key,
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: deepseekNode.Key,
},
{
@@ -623,7 +613,7 @@ func TestLLM(t *testing.T) {
},
{
FromNode: emitterNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}


@@ -26,15 +26,20 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
_continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestLoop(t *testing.T) {
t.Run("by iteration", func(t *testing.T) {
// start -> loop_node_key[innerNode -> continue] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -54,31 +59,30 @@ func TestLoop(t *testing.T) {
},
}
continueNode := &compose2.NodeSchema{
Key: "continueNode",
Type: entity.NodeTypeContinue,
continueNode := &schema.NodeSchema{
Key: "continueNode",
Type: entity.NodeTypeContinue,
Configs: &_continue.Config{},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByIteration,
Configs: &loop.Config{
LoopType: loop.ByIteration,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{loop.Count},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"count"},
},
},
@@ -97,11 +101,11 @@ func TestLoop(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -116,11 +120,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
continueNode,
},
@@ -128,7 +132,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"continueNode": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@@ -142,12 +146,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -168,7 +172,7 @@ func TestLoop(t *testing.T) {
t.Run("infinite", func(t *testing.T) {
// start -> loop_node_key[innerNode -> break] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -188,24 +192,23 @@ func TestLoop(t *testing.T) {
},
}
breakNode := &compose2.NodeSchema{
Key: "breakNode",
Type: entity.NodeTypeBreak,
breakNode := &schema.NodeSchema{
Key: "breakNode",
Type: entity.NodeTypeBreak,
Configs: &_break.Config{},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.Infinite,
Configs: &loop.Config{
LoopType: loop.Infinite,
},
OutputSources: []*vo.FieldInfo{
{
@@ -220,11 +223,11 @@ func TestLoop(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -239,11 +242,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
breakNode,
},
@@ -251,7 +254,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"breakNode": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@@ -265,12 +268,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -290,14 +293,14 @@ func TestLoop(t *testing.T) {
t.Run("by array", func(t *testing.T) {
// start -> loop_node_key[innerNode -> variable_assign] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
item1 := in["item1"].(string)
item2 := in["item2"].(string)
count := in["count"].(int)
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
return map[string]any{"total": count + len(item1) + len(item2)}, nil
}),
InputSources: []*vo.FieldInfo{
{
@@ -330,16 +333,18 @@ func TestLoop(t *testing.T) {
},
}
assigner := &compose2.NodeSchema{
assigner := &schema.NodeSchema{
Key: "assigner",
Type: entity.NodeTypeVariableAssignerWithinLoop,
Configs: []*variableassigner.Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
Configs: &variableassigner.InLoopConfig{
Pairs: []*variableassigner.Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"total"},
},
Right: compose.FieldPath{"total"},
},
},
InputSources: []*vo.FieldInfo{
@@ -355,19 +360,17 @@ func TestLoop(t *testing.T) {
},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -382,12 +385,13 @@ func TestLoop(t *testing.T) {
},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByArray,
"IntermediateVars": map[string]*vo.TypeInfo{
Configs: &loop.Config{
LoopType: loop.ByArray,
InputArrays: []string{"items1", "items2"},
IntermediateVars: map[string]*vo.TypeInfo{
"count": {
Type: vo.DataTypeInteger,
},
@@ -408,7 +412,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items1"},
},
},
@@ -417,7 +421,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items2"},
},
},
@@ -442,11 +446,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
assigner,
},
@@ -454,7 +458,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"assigner": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@@ -468,12 +472,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}


@@ -43,8 +43,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
@@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) {
mockey.Mock(workflow.GetRepository).Return(repo).Build()
t.Run("answer directly, no structured output", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
}}
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerDirectly,
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerDirectly,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerByChoices,
"ChoiceType": qa.FixedChoices,
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
"LLMParams": &model.LLMParams{},
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerByChoices,
ChoiceType: qa.FixedChoices,
FixedChoices: []string{"{{choice1}}", "{{choice2}}"},
LLMParams: &model.LLMParams{},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice1"},
},
},
@@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice2"},
},
},
@@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) {
},
}
lambda := &compose2.NodeSchema{
lambda := &schema2.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) {
}),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
lambda,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"),
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_1"),
},
{
@@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) {
},
{
FromNode: "lambda",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
@@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) {
})
t.Run("answer with dynamic choices", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerByChoices,
"ChoiceType": qa.DynamicChoices,
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerByChoices,
ChoiceType: qa.DynamicChoices,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{qa.DynamicChoicesKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choices"},
},
},
@@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) {
},
}
lambda := &compose2.NodeSchema{
lambda := &schema2.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) {
}),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
lambda,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"),
},
{
FromNode: "lambda",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "qa_node_key",
@@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) {
},
}
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
@@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerDirectly,
"ExtractFromAnswer": true,
"AdditionalSystemPromptTpl": "{{prompt}}",
"MaxAnswerCount": 2,
"LLMParams": &model.LLMParams{},
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerDirectly,
ExtractFromAnswer: true,
AdditionalSystemPromptTpl: "{{prompt}}",
MaxAnswerCount: 2,
LLMParams: &model.LLMParams{},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"prompt"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"prompt"},
},
},
@@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}


@@ -26,26 +26,28 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestAddSelector(t *testing.T) {
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
}}
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) {
}, nil
}
lambdaNode1 := &compose2.NodeSchema{
lambdaNode1 := &schema.NodeSchema{
Key: "lambda1",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1),
@@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) {
}, nil
}
LambdaNode2 := &compose2.NodeSchema{
LambdaNode2 := &schema.NodeSchema{
Key: "lambda2",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2),
@@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) {
}, nil
}
lambdaNode3 := &compose2.NodeSchema{
lambdaNode3 := &schema.NodeSchema{
Key: "lambda3",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3),
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "selector",
Type: entity.NodeTypeSelector,
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{
{
Single: ptr.Of(selector.OperatorEqual),
},
@@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"0", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key1"},
},
},
@@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key2"},
},
},
@@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.RightKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key3"},
},
},
@@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "1", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key4"},
},
},
@@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
ns,
lambdaNode1,
LambdaNode2,
lambdaNode3,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "selector",
},
{
@@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) {
},
{
FromNode: "selector",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("default"),
},
{
FromNode: "lambda1",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "lambda2",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "lambda3",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
branches, err := schema.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
ctx := context.Background()
@@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) {
}
func TestVariableAggregator(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "va",
Type: entity.NodeTypeVariableAggregator,
Configs: map[string]any{
"MergeStrategy": variableaggregator.FirstNotNullValue,
"GroupToLen": map[string]int{
Configs: &variableaggregator.Config{
MergeStrategy: variableaggregator.FirstNotNullValue,
GroupLen: map[string]int{
"Group1": 1,
"Group2": 1,
},
"GroupOrder": []string{
GroupOrder: []string{
"Group1",
"Group2",
},
@@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group1", "0"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"},
},
},
@@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group2", "0"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Int1"},
},
},
@@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "va",
},
{
FromNode: "va",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) {
func TestTextProcessor(t *testing.T) {
t.Run("split", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "tp",
Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{
"Type": textprocessor.SplitText,
"Separators": []string{"|"},
Configs: &textprocessor.Config{
Type: textprocessor.SplitText,
Separators: []string{"|"},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"String"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str"},
},
},
@@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
ns,
entry,
exit,
entryN,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "tp",
},
{
FromNode: "tp",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) {
})
t.Run("concat", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "tp",
Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{
"Type": textprocessor.ConcatText,
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
"ConcatChar": "\t",
Configs: &textprocessor.Config{
Type: textprocessor.ConcatText,
Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
ConcatChar: "\t",
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"String1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"},
},
},
@@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str2"},
},
},
@@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String3"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str3"},
},
},
@@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
ns,
entry,
exit,
entryN,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "tp",
},
{
FromNode: "tp",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
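A summary sketch (illustration only, not from this diff): across these tests a node is now declared with a typed Config struct from its own node package instead of a map[string]any, and workflows whose connections carry FromPort build their branches before Init. Names and values below are hypothetical:
tp := &schema.NodeSchema{
Key:  "tp_example",
Type: entity.NodeTypeTextProcessor,
Configs: &textprocessor.Config{
Type:       textprocessor.ConcatText,
Tpl:        "{{a}}-{{b}}",
ConcatChar: ",",
},
}
_ = tp
// for branch-bearing workflows, as in the selector and QA tests above:
//   branches, err := schema.BuildBranches(ws.Connections)
//   ws.Branches = branches
//   ws.Init()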


@@ -1,652 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strconv"
"time"
einomodel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
return &entry.Config{
DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
OutputTypes: s.OutputTypes,
}, nil
}
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
llmConf := &llm.Config{
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs),
OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
InputFields: s.InputTypes,
OutputFields: s.OutputTypes,
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
if llmParams == nil {
return nil, fmt.Errorf("llm node llmParams is required")
}
var (
err error
chatModel, fallbackM einomodel.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo llm.ModelWithInfo
)
chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
metaConfigs := s.ExceptionConfigs
if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
if backupModelParams != nil {
fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
}
if fallbackM == nil {
modelWithInfo = llm.NewModel(chatModel, info)
} else {
modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
llmConf.ChatModel = modelWithInfo
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
llmConf.Tools = append(llmConf.Tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if llmConf.ToolsReturnDirectly == nil {
llmConf.ToolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
llmConf.ToolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
PluginEntity: crossplugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
llmConf.Tools = inInvokableTools
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
cfg := &llm.KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := totRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
cfg.RetrievalStrategy = &llm.RetrievalStrategy{
RetrievalStrategy: &crossknowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
llmConf.KnowledgeRecallConfig = cfg
}
}
return llmConf, nil
}
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
return &selector.Config{
Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs),
}
}
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
for i, oneConf := range conf {
if oneConf.Single != nil {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey})
if ok {
out = append(out, selector.Operants{Left: left, Right: right})
} else {
out = append(out, selector.Operants{Left: left})
}
} else if oneConf.Multi != nil {
multiClause := make([]*selector.Operants, 0)
for j := range oneConf.Multi.Clauses {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey})
if ok {
multiClause = append(multiClause, &selector.Operants{Left: left, Right: right})
} else {
multiClause = append(multiClause, &selector.Operants{Left: left})
}
}
out = append(out, selector.Operants{Multi: multiClause})
} else {
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
}
}
return out, nil
}
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
conf := &batch.Config{
BatchNodeKey: s.Key,
InnerWorkflow: inner,
Outputs: s.OutputSources,
}
for key, tInfo := range s.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
conf.InputArrays = append(conf.InputArrays, key)
}
return conf, nil
}
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
return &variableaggregator.Config{
MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy),
GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
NodeKey: s.Key,
InputSources: s.InputSources,
GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs),
}, nil
}
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
converted = make(map[string]map[int]any)
for k, value := range in {
m, ok := value.(map[string]any)
if !ok {
panic(errors.New("value is not a map[string]any"))
}
converted[k] = make(map[int]any, len(m))
for i, sv := range m {
index, err := strconv.Atoi(i)
if err != nil {
panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err))
}
converted[k][index] = sv
}
}
return converted
}
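A small worked example of what this converter produces (hypothetical values, assuming s is the *NodeSchema of a variable aggregator node): the outer keys are group names, the inner keys are stringified indices, and the result re-keys each group by int.

// Input shape coming from upstream nodes: group name -> "0", "1", ... -> value.
in := map[string]any{
    "Group1": map[string]any{"0": "hello", "1": 42},
}
// Output shape consumed by the variable aggregator:
// map[string]map[int]any{"Group1": {0: "hello", 1: 42}}
out := s.variableAggregatorInputConverter(in)
_ = out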
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
defer func() {
if r := recover(); r != nil {
err = safego.NewPanicErr(r, debug.Stack())
}
}()
return s.variableAggregatorInputConverter(input), nil
}
return schema.StreamReaderWithConvert(in, converter)
}
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
return &textprocessor.Config{
Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type),
Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)),
ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)),
Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}, nil
}
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
return &json.SerializationConfig{
InputTypes: s.InputTypes,
}, nil
}
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
return &json.DeserializationConfig{
OutputFields: s.OutputTypes,
}, nil
}
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
return &httprequester.Config{
URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
Method: mustGetKey[string]("Method", s.Configs),
Timeout: mustGetKey[time.Duration]("Timeout", s.Configs),
RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs),
MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
}, nil
}
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
return &variableassigner.Config{
Pairs: s.Configs.([]*variableassigner.Pair),
Handler: handler,
}, nil
}
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
return &variableassigner.Config{
Pairs: s.Configs.([]*variableassigner.Pair),
}, nil
}
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
conf := &loop.Config{
LoopNodeKey: s.Key,
LoopType: mustGetKey[loop.Type]("LoopType", s.Configs),
IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
Outputs: s.OutputSources,
Inner: inner,
}
for key, tInfo := range s.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars
continue
}
conf.InputArrays = append(conf.InputArrays, key)
}
return conf, nil
}
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
conf := &qa.Config{
QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs),
AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs),
ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs),
AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
OutputFields: s.OutputTypes,
NodeKey: s.Key,
}
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
if llmParams != nil {
m, _, err := model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
conf.Model = m
}
return conf, nil
}
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
return &receiver.Config{
OutputTypes: s.OutputTypes,
NodeKey: s.Key,
OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
}, nil
}
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
conf := &emitter.Config{
Template: getKeyOrZero[string]("Template", s.Configs),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}
return conf, nil
}
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
return &database.CustomSQLConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs),
OutputConfig: s.OutputTypes,
CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
return &database.QueryConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs),
OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Limit: mustGetKey[int64]("Limit", s.Configs),
Op: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
return &database.InsertConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
OutputConfig: s.OutputTypes,
Inserter: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
return &database.DeleteConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Deleter: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
return &database.UpdateConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Updater: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
return &knowledge.IndexerConfig{
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
return &knowledge.RetrieveConfig{
KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
Retriever: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
return &knowledge.DeleterConfig{
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
return &plugin.Config{
PluginID: mustGetKey[int64]("PluginID", s.Configs),
ToolID: mustGetKey[int64]("ToolID", s.Configs),
PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
PluginService: crossplugin.GetPluginService(),
}, nil
}
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
return &code.Config{
Code: mustGetKey[string]("Code", s.Configs),
Language: mustGetKey[coderunner.Language]("Language", s.Configs),
OutputConfig: s.OutputTypes,
Runner: crosscode.GetCodeRunner(),
}, nil
}
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
return &conversation.CreateConversationConfig{
Creator: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
return &conversation.ClearMessageConfig{
Clearer: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
return &conversation.MessageListConfig{
Lister: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
cfg := &intentdetector.Config{
Intents: mustGetKey[[]string]("Intents", s.Configs),
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs),
}
llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
m, _, err := model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
cfg.ChatModel = m
return cfg, nil
}
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
var opts []WorkflowOption
opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
if requireCheckpoint {
opts = append(opts, WithParentRequireCheckpoint())
}
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
}
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
if err != nil {
return nil, err
}
return &subworkflow.Config{
Runner: wf.Runner,
}, nil
}
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
switch s {
case 0:
return crossknowledge.SearchTypeSemantic, nil
case 1:
return crossknowledge.SearchTypeHybrid, nil
case 20:
return crossknowledge.SearchTypeFullText, nil
default:
return "", fmt.Errorf("invalid retrieval search type %v", s)
}
}

View File

@@ -1,107 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"reflect"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
func getKeyOrZero[T any](key string, cfg any) T {
var zero T
if cfg == nil {
return zero
}
m, ok := cfg.(map[string]any)
if !ok {
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
}
if len(m) == 0 {
return zero
}
if v, ok := m[key]; ok {
return v.(T)
}
return zero
}
func mustGetKey[T any](key string, cfg any) T {
if cfg == nil {
panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
}
m, ok := cfg.(map[string]any)
if !ok {
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
}
if _, ok := m[key]; !ok {
panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
}
v, ok := m[key].(T)
if !ok {
panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key])))
}
return v
}
func (s *NodeSchema) SetConfigKV(key string, value any) {
if s.Configs == nil {
s.Configs = make(map[string]any)
}
s.Configs.(map[string]any)[key] = value
}
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
if s.InputTypes == nil {
s.InputTypes = make(map[string]*vo.TypeInfo)
}
s.InputTypes[key] = t
}
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
s.InputSources = append(s.InputSources, info...)
}
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
if s.OutputTypes == nil {
s.OutputTypes = make(map[string]*vo.TypeInfo)
}
s.OutputTypes[key] = t
}
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
s.OutputSources = append(s.OutputSources, info...)
}
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
if s.Type != entity.NodeTypeSubWorkflow {
return 0, "", false
}
return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true
}
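
Before this refactor, these generic helpers were the only way node code read its configuration. A brief sketch of how they behave, with hypothetical values:

cfg := map[string]any{"Limit": int64(100)}

limit := mustGetKey[int64]("Limit", cfg)       // 100
timeout := getKeyOrZero[int64]("Timeout", cfg) // key absent: returns the zero value, 0

// Both helpers panic rather than return an error when the map is malformed:
// mustGetKey[string]("Limit", cfg) panics because the stored value is an int64,
// and mustGetKey[int64]("Missing", cfg) panics because the key does not exist.
_, _ = limit, timeout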

View File

@@ -29,6 +29,8 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
@@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolute essentials
*workflow
hierarchy map[vo.NodeKey]vo.NodeKey
connections []*Connection
connections []*schema.Connection
requireCheckpoint bool
entry *compose.WorkflowNode
inner bool
@@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
input map[string]*vo.TypeInfo
output map[string]*vo.TypeInfo
terminatePlan vo.TerminatePlan
schema *WorkflowSchema
schema *schema.WorkflowSchema
}
type workflowOptions struct {
@@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
}
}
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
sc.Init()
wf := &Workflow{
@@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
schema: sc,
}
wf.streamRun = sc.requireStreaming
wf.requireCheckpoint = sc.requireCheckPoint
wf.streamRun = sc.RequireStreaming()
wf.requireCheckpoint = sc.RequireCheckpoint()
wfOpts := &workflowOptions{}
for _, opt := range opts {
@@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
processedNodeKey[child.Key] = struct{}{}
}
}
// add all nodes other than composite nodes and their children
for _, ns := range sc.Nodes {
if _, ok := processedNodeKey[ns.Key]; !ok {
@@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
}
if ns.Type == entity.NodeTypeExit {
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
}
}
@@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
}
fanInConfigs := sc.fanInMergeConfigs()
fanInConfigs := sc.FanInMergeConfigs()
if len(fanInConfigs) > 0 {
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
}
@@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
carryOvers map[vo.NodeKey][]*compose.FieldMapping
}
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
_, err := w.addNodeInternal(ctx, ns, nil)
return err
}
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
inner, err := w.getInnerWorkflow(ctx, cNode)
if err != nil {
return err
@@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
return err
}
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
return w.addNodeInternal(ctx, cNode, nil)
}
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
key := ns.Key
var deps *dependencyInfo
@@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
innerWorkflow = inner.inner
}
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
if err != nil {
return nil, err
}
@@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
var opts []compose.GraphAddNodeOpt
opts = append(opts, compose.WithNodeName(string(ns.Key)))
preHandler := ns.StatePreHandler(w.streamRun)
preHandler := statePreHandler(ns, w.streamRun)
if preHandler != nil {
opts = append(opts, preHandler)
}
postHandler := ns.StatePostHandler(w.streamRun)
postHandler := statePostHandler(ns, w.streamRun)
if postHandler != nil {
opts = append(opts, postHandler)
}
@@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
w.entry = wNode
}
outputPortCount, hasExceptionPort := ns.OutputPortCount()
if outputPortCount > 1 || hasExceptionPort {
bMapping, err := w.resolveBranch(key, outputPortCount)
if err != nil {
return nil, err
}
b := w.schema.GetBranch(ns.Key)
if b != nil {
if b.OnlyException() {
_ = w.AddBranch(string(key), b.GetExceptionBranch())
} else {
bb, ok := ns.Configs.(schema.BranchBuilder)
if !ok {
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
}
branch, err := ns.GetBranch(bMapping)
if err != nil {
return nil, err
}
br, err := b.GetFullBranch(ctx, bb)
if err != nil {
return nil, err
}
_ = w.AddBranch(string(key), branch)
_ = w.AddBranch(string(key), br)
}
}
return deps.inputsForParent, nil
@@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
return w.workflow.Compile(ctx, opts...)
}
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*NodeSchema)
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
for _, n := range cNode.Children {
innerNodes[n.Key] = n
}
// trim the connections, only keep the connections that are related to the inner workflow
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
innerConnections := make([]*Connection, 0)
innerConnections := make([]*schema.Connection, 0)
for i := range w.schema.Connections {
conn := w.schema.Connections[i]
if _, ok := innerNodes[conn.FromNode]; ok {
@@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
// we will not try to drill, instead, just take value using a.b.c.
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
for nKey, fms := range d.inputs {
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
continue
@@ -638,55 +643,6 @@ type variableInfo struct {
toPath compose.FieldPath
}
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
m := make([]map[string]bool, portCount)
var exception map[string]bool
for _, conn := range w.connections {
if conn.FromNode != n {
continue
}
if conn.FromPort == nil {
continue
}
if *conn.FromPort == "default" { // default condition
if m[portCount-1] == nil {
m[portCount-1] = make(map[string]bool)
}
m[portCount-1][string(conn.ToNode)] = true
} else if *conn.FromPort == "branch_error" {
if exception == nil {
exception = make(map[string]bool)
}
exception[string(conn.ToNode)] = true
} else {
if !strings.HasPrefix(*conn.FromPort, "branch_") {
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
}
index := (*conn.FromPort)[7:]
i, err := strconv.Atoi(index)
if err != nil {
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
}
if i < 0 || i >= portCount {
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
}
if m[i] == nil {
m[i] = make(map[string]bool)
}
m[i][string(conn.ToNode)] = true
}
}
return &BranchMapping{
Normal: m,
Exception: exception,
}, nil
}
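For context, a hedged sketch of the port-naming convention this removed function parsed (node keys here are made up; ptr.Of is the generic helper from pkg/lang/ptr):

// A selector node "sel" with two conditions exposes ports "branch_0", "branch_1" and
// "default"; a node with an exception branch would additionally use port "branch_error".
conns := []*Connection{
    {FromNode: "sel", ToNode: "a", FromPort: ptr.Of("branch_0")},
    {FromNode: "sel", ToNode: "b", FromPort: ptr.Of("branch_1")},
    {FromNode: "sel", ToNode: "c", FromPort: ptr.Of("default")},
}
_ = conns
// resolveBranch("sel", 3) over these connections would yield:
//   Normal[0] = {"a": true}, Normal[1] = {"b": true}, Normal[2] = {"c": true}, Exception = nil.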
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
var (
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
@@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
@@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
if inputFull == nil {
@@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
}
}
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
firstNodesInInnerWorkflow := true
for _, conn := range connMap {
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
firstNodesInInnerWorkflow = false
break
@@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
fromNodeKey = compose.START
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
continue
}
@@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
variableInfos []*variableInfo
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
}
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
continue
}
@@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
swp.Source.Ref.FromPath, swp.Path)
}
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
} else { // indirect dependency

View File

@@ -23,9 +23,10 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
*Workflow, error) {
sc.Init()
ns := sc.GetNode(nodeKey)
@@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod
schema: sc,
fromNode: true,
streamRun: false, // single node run can only invoke
requireCheckpoint: sc.requireCheckPoint,
requireCheckpoint: sc.RequireCheckpoint(),
input: ns.InputTypes,
output: ns.OutputTypes,
terminatePlan: vo.ReturnVariables,

View File

@@ -32,6 +32,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -42,7 +43,7 @@ type WorkflowRunner struct {
basic *entity.WorkflowBasic
input string
resumeReq *entity.ResumeRequest
schema *WorkflowSchema
schema *schema2.WorkflowSchema
streamWriter *schema.StreamWriter[*entity.Message]
config vo.ExecuteConfig
@@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp
}
}
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
options := &workflowRunOptions{}
for _, opt := range opts {
opt(options)

View File

@@ -1,336 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"maps"
"reflect"
"sync"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type WorkflowSchema struct {
Nodes []*NodeSchema `json:"nodes"`
Connections []*Connection `json:"connections"`
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
nodeMap map[vo.NodeKey]*NodeSchema // won't serialize this
compositeNodes []*CompositeNode // won't serialize this
requireCheckPoint bool // won't serialize this
requireStreaming bool
once sync.Once
}
type Connection struct {
FromNode vo.NodeKey `json:"from_node"`
ToNode vo.NodeKey `json:"to_node"`
FromPort *string `json:"from_port,omitempty"`
}
func (c *Connection) ID() string {
if c.FromPort != nil {
return fmt.Sprintf("%s:%s:%v", c.FromNode, c.ToNode, *c.FromPort)
}
return fmt.Sprintf("%v:%v", c.FromNode, c.ToNode)
}
type CompositeNode struct {
Parent *NodeSchema
Children []*NodeSchema
}
func (w *WorkflowSchema) Init() {
w.once.Do(func() {
w.nodeMap = make(map[vo.NodeKey]*NodeSchema)
for _, node := range w.Nodes {
w.nodeMap[node.Key] = node
}
w.doGetCompositeNodes()
for _, node := range w.Nodes {
if node.requireCheckpoint() {
w.requireCheckPoint = true
break
}
}
w.requireStreaming = w.doRequireStreaming()
})
}
func (w *WorkflowSchema) GetNode(key vo.NodeKey) *NodeSchema {
return w.nodeMap[key]
}
func (w *WorkflowSchema) GetAllNodes() map[vo.NodeKey]*NodeSchema {
return w.nodeMap // TODO: needs to calculate node count separately, considering batch mode nodes
}
func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
if w.compositeNodes == nil {
w.compositeNodes = w.doGetCompositeNodes()
}
return w.compositeNodes
}
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
if w.Hierarchy == nil {
return nil
}
// Build parent to children mapping
parentToChildren := make(map[vo.NodeKey][]*NodeSchema)
for childKey, parentKey := range w.Hierarchy {
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
if childSchema := w.nodeMap[childKey]; childSchema != nil {
parentToChildren[parentKey] = append(parentToChildren[parentKey], childSchema)
}
}
}
// Create composite nodes
for parentKey, children := range parentToChildren {
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
cNodes = append(cNodes, &CompositeNode{
Parent: parentSchema,
Children: children,
})
}
}
return cNodes
}
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return true
}
myParents, myParentExists := n[nodeKey]
theirParents, theirParentExists := n[otherNodeKey]
if !myParentExists && !theirParentExists {
return true
}
if !myParentExists || !theirParentExists {
return false
}
return myParents == theirParents
}
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return false
}
_, myParentExists := n[nodeKey]
_, theirParentExists := n[otherNodeKey]
return myParentExists && !theirParentExists
}
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return false
}
theirParent, theirParentExists := n[otherNodeKey]
return theirParentExists && theirParent == nodeKey
}
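These three helpers encode all the hierarchy relationships the builder needs. A quick illustration with a hypothetical hierarchy where inner node "i" is a child of composite node "b" and "x" is a top-level node:

h := map[vo.NodeKey]vo.NodeKey{"i": "b"} // child -> parent

_ = isInSameWorkflow(h, "b", "x") // true: both are top-level nodes
_ = isInSameWorkflow(h, "i", "x") // false: "i" lives inside composite node "b"
_ = isBelowOneLevel(h, "i", "x")  // true: "i" is one level below top-level "x"
_ = isParentOf(h, "b", "i")       // true: "b" is the parent of "i"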
func (w *WorkflowSchema) IsEqual(other *WorkflowSchema) bool {
otherConnectionsMap := make(map[string]bool, len(other.Connections))
for _, connection := range other.Connections {
otherConnectionsMap[connection.ID()] = true
}
connectionsMap := make(map[string]bool, len(other.Connections))
for _, connection := range w.Connections {
connectionsMap[connection.ID()] = true
}
if !maps.Equal(otherConnectionsMap, connectionsMap) {
return false
}
otherNodeMap := make(map[vo.NodeKey]*NodeSchema, len(other.Nodes))
for _, node := range other.Nodes {
otherNodeMap[node.Key] = node
}
nodeMap := make(map[vo.NodeKey]*NodeSchema, len(w.Nodes))
for _, node := range w.Nodes {
nodeMap[node.Key] = node
}
if !maps.EqualFunc(otherNodeMap, nodeMap, func(node *NodeSchema, other *NodeSchema) bool {
if node.Name != other.Name {
return false
}
if !reflect.DeepEqual(node.Configs, other.Configs) {
return false
}
if !reflect.DeepEqual(node.InputTypes, other.InputTypes) {
return false
}
if !reflect.DeepEqual(node.InputSources, other.InputSources) {
return false
}
if !reflect.DeepEqual(node.OutputTypes, other.OutputTypes) {
return false
}
if !reflect.DeepEqual(node.OutputSources, other.OutputSources) {
return false
}
if !reflect.DeepEqual(node.ExceptionConfigs, other.ExceptionConfigs) {
return false
}
if !reflect.DeepEqual(node.SubWorkflowBasic, other.SubWorkflowBasic) {
return false
}
return true
}) {
return false
}
return true
}
func (w *WorkflowSchema) NodeCount() int32 {
return int32(len(w.Nodes) - len(w.GeneratedNodes))
}
func (w *WorkflowSchema) doRequireStreaming() bool {
producers := make(map[vo.NodeKey]bool)
consumers := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type)
if meta != nil {
sps := meta.ExecutableMeta.StreamingParadigms
if _, ok := sps[entity.Stream]; ok {
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
producers[node.Key] = true
}
}
if sps[entity.Transform] || sps[entity.Collect] {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
consumers[node.Key] = true
}
}
}
}
if len(producers) == 0 || len(consumers) == 0 {
return false
}
// Build data-flow graph from InputSources
adj := make(map[vo.NodeKey]map[vo.NodeKey]struct{})
for _, node := range w.Nodes {
for _, source := range node.InputSources {
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if _, ok := adj[source.Source.Ref.FromNodeKey]; !ok {
adj[source.Source.Ref.FromNodeKey] = make(map[vo.NodeKey]struct{})
}
adj[source.Source.Ref.FromNodeKey][node.Key] = struct{}{}
}
}
}
// For each producer, traverse the graph to see if it can reach a consumer
for p := range producers {
q := []vo.NodeKey{p}
visited := make(map[vo.NodeKey]bool)
visited[p] = true
for len(q) > 0 {
curr := q[0]
q = q[1:]
if consumers[curr] {
return true
}
for neighbor := range adj[curr] {
if !visited[neighbor] {
visited[neighbor] = true
q = append(q, neighbor)
}
}
}
}
return false
}
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
// If the workflow does not require streaming at all, no fan-in merge configs are needed.
// Otherwise, find the nodes whose streaming paradigm is 'transform' or 'collect' and that have
// more than one data predecessor: those are the fan-in nodes.
// Finally, check each fan-in node's NodeTypeMeta.ExecutableMeta to see whether it requires fan-in stream merge.
if !w.requireStreaming {
return nil
}
fanInNodes := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type)
if meta != nil {
sps := meta.ExecutableMeta.StreamingParadigms
if sps[entity.Transform] || sps[entity.Collect] {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
var predecessor *vo.NodeKey
for _, source := range node.InputSources {
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if predecessor != nil {
fanInNodes[node.Key] = true
break
}
predecessor = &source.Source.Ref.FromNodeKey
}
}
}
}
}
}
fanInConfigs := make(map[string]compose.FanInMergeConfig)
for nodeKey := range fanInNodes {
if m := entity.NodeMetaByNodeType(w.GetNode(nodeKey).Type); m != nil {
if m.StreamSourceEOFAware {
fanInConfigs[string(nodeKey)] = compose.FanInMergeConfig{
StreamMergeWithSourceEOF: true,
}
}
}
}
return fanInConfigs
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -41,7 +42,7 @@ type invokableWorkflow struct {
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *WorkflowSchema
sc *schema2.WorkflowSchema
repo wf.Repository
}
@@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow,
sc *WorkflowSchema,
sc *schema2.WorkflowSchema,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &invokableWorkflow{
@@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", err
}
var entryNode *NodeSchema
var entryNode *schema2.NodeSchema
for _, node := range i.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
@@ -190,7 +191,7 @@ type streamableWorkflow struct {
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *WorkflowSchema
sc *schema2.WorkflowSchema
repo wf.Repository
}
@@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow,
sc *WorkflowSchema,
sc *schema2.WorkflowSchema,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &streamableWorkflow{
@@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, err
}
var entryNode *NodeSchema
var entryNode *schema2.NodeSchema
for _, node := range s.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node