refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -1,181 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) OutputPortCount() (int, bool) {
|
||||
var hasExceptionPort bool
|
||||
if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
|
||||
*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
hasExceptionPort = true
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
|
||||
if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
|
||||
return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
|
||||
} else {
|
||||
return 2, hasExceptionPort
|
||||
}
|
||||
}
|
||||
return 1, hasExceptionPort
|
||||
case entity.NodeTypeIntentDetector:
|
||||
intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
|
||||
return len(intents) + 1, hasExceptionPort
|
||||
default:
|
||||
return 1, hasExceptionPort
|
||||
}
|
||||
}
|
||||
|
||||
type BranchMapping struct {
|
||||
Normal []map[string]bool
|
||||
Exception map[string]bool
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBranch = "default"
|
||||
BranchFmt = "branch_%d"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
|
||||
if bMapping == nil {
|
||||
return nil, errors.New("no branch mapping")
|
||||
}
|
||||
|
||||
endNodes := make(map[string]bool)
|
||||
for i := range bMapping.Normal {
|
||||
for k := range bMapping.Normal[i] {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
if bMapping.Exception != nil {
|
||||
for k := range bMapping.Exception {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
choice := in[selector.SelectKey].(int)
|
||||
if choice < 0 || choice > len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
|
||||
}
|
||||
|
||||
choices := make(map[string]bool, len((bMapping.Normal)[choice]))
|
||||
for k := range (bMapping.Normal)[choice] {
|
||||
choices[k] = true
|
||||
}
|
||||
|
||||
return choices, nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf := s.Configs.(map[string]any)
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
|
||||
choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
if optionID.(string) == "other" {
|
||||
return (bMapping.Normal)[len(bMapping.Normal)-1], nil
|
||||
}
|
||||
|
||||
if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
|
||||
optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
|
||||
}
|
||||
|
||||
if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[optionIDInt], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
|
||||
|
||||
case entity.NodeTypeIntentDetector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
|
||||
}
|
||||
|
||||
// Intent detector the node default branch uses classificationId=0. But currently scene, the implementation uses default as the last element of the array.
|
||||
// Therefore, when classificationId=0, it needs to be converted into the node corresponding to the last index of the array.
|
||||
// Other options also need to reduce the index by 1.
|
||||
id, err := cast.ToInt64E(classificationId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realID := id - 1
|
||||
|
||||
if realID >= int64(len(bMapping.Normal)) {
|
||||
return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
|
||||
}
|
||||
|
||||
if realID < 0 {
|
||||
realID = int64(len(bMapping.Normal)) - 1
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[realID], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
default:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
type selectorCallbackField struct {
|
||||
Key string `json:"key"`
|
||||
Type vo.DataType `json:"type"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type selectorCondition struct {
|
||||
Left selectorCallbackField `json:"left"`
|
||||
Operator vo.OperatorType `json:"operator"`
|
||||
Right *selectorCallbackField `json:"right,omitempty"`
|
||||
}
|
||||
|
||||
type selectorBranch struct {
|
||||
Conditions []*selectorCondition `json:"conditions"`
|
||||
Logic vo.LogicType `json:"logic"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
return func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
count := len(config)
|
||||
|
||||
output := make([]*selectorBranch, count)
|
||||
|
||||
for _, source := range s.InputSources {
|
||||
targetPath := source.Path
|
||||
if len(targetPath) == 2 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: []*selectorCondition{
|
||||
{
|
||||
Operator: config[index].Single.ToCanvasOperatorType(),
|
||||
},
|
||||
},
|
||||
Logic: selector.ClauseRelationAND.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[1] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[1] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[0].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
} else if len(targetPath) == 3 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multi := config[index].Multi
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: make([]*selectorCondition, len(multi.Clauses)),
|
||||
Logic: multi.Relation.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
clauseIndexStr := targetPath[1]
|
||||
clauseIndex, err := strconv.Atoi(clauseIndexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clause := multi.Clauses[clauseIndex]
|
||||
|
||||
if output[index].Conditions[clauseIndex] == nil {
|
||||
output[index].Conditions[clauseIndex] = &selectorCondition{
|
||||
Operator: clause.ToCanvasOperatorType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[2] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[2] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"branches": output}, nil
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,9 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
@@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
rootHandler := execute.NewRootWorkflowHandler(
|
||||
wb,
|
||||
executeID,
|
||||
workflowSC.requireCheckPoint,
|
||||
workflowSC.RequireCheckpoint(),
|
||||
eventChan,
|
||||
resumedEvent,
|
||||
exeCfg,
|
||||
@@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
var nodeOpt einoCompose.Option
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
|
||||
ptr.Of(ns.Configs.(*exit.Config).TerminatePlan))
|
||||
} else if ns.Type != entity.NodeTypeLambda {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
|
||||
}
|
||||
@@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
if workflowSC.requireCheckPoint {
|
||||
if workflowSC.RequireCheckpoint() {
|
||||
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
|
||||
}
|
||||
|
||||
@@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in
|
||||
|
||||
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
parentHandler *execute.WorkflowHandler,
|
||||
ns *NodeSchema,
|
||||
ns *schema2.NodeSchema,
|
||||
pathPrefix ...string) (opts []einoCompose.Option, err error) {
|
||||
var (
|
||||
resumeEvent = r.interruptEvent
|
||||
@@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
var nodeOpt einoCompose.Option
|
||||
if subNS.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
|
||||
ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan))
|
||||
} else {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
|
||||
}
|
||||
@@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
|
||||
sw *schema.StreamWriter[*entity.Message]) (
|
||||
opts []einoCompose.Option, err error) {
|
||||
// this is a LLM node.
|
||||
@@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
||||
panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
|
||||
}
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
|
||||
cfg := ns.Configs.(*llm.Config)
|
||||
fcParams := cfg.FCParam
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
// TODO: try to avoid getting the workflow tool all over again
|
||||
@@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
||||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
@@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
||||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,12 +25,13 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// outputValueFiller will fill the output value with nil if the key is not present in the output map.
|
||||
// if a node emits stream as output, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
if len(s.OutputTypes) == 0 {
|
||||
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
return output, nil
|
||||
@@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st
|
||||
|
||||
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
|
||||
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
if len(s.InputTypes) == 0 {
|
||||
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
return input, nil
|
||||
@@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
|
||||
func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context,
|
||||
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
|
||||
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
|
||||
newI := make(map[string]any)
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNodeSchema_OutputValueFiller(t *testing.T) {
|
||||
@@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &NodeSchema{
|
||||
s := &schema.NodeSchema{
|
||||
OutputTypes: tt.fields.Outputs,
|
||||
}
|
||||
|
||||
got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
|
||||
got, err := outputValueFiller(s)(context.Background(), tt.fields.In)
|
||||
|
||||
if len(tt.wantErr) > 0 {
|
||||
assert.Error(t, err)
|
||||
|
||||
118
backend/domain/workflow/internal/compose/node_builder.go
Normal file
118
backend/domain/workflow/internal/compose/node_builder.go
Normal file
@@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
// New instantiates the actual node type from NodeSchema.
|
||||
func New(ctx context.Context, s *schema.NodeSchema,
|
||||
inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node
|
||||
sc *schema.WorkflowSchema, // the workflow this NodeSchema is in
|
||||
deps *dependencyInfo, // the dependency for this node pre-calculated by workflow engine
|
||||
) (_ *Node, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
var fullSources map[string]*schema.SourceInfo
|
||||
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
|
||||
if fullSources, err = GetFullSources(s, sc, deps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.FullSources = fullSources
|
||||
}
|
||||
|
||||
// if NodeSchema's Configs implements NodeBuilder, will use it to build the node
|
||||
nb, ok := s.Configs.(schema.NodeBuilder)
|
||||
if ok {
|
||||
opts := []schema.BuildOption{
|
||||
schema.WithWorkflowSchema(sc),
|
||||
schema.WithInnerWorkflow(inner),
|
||||
}
|
||||
|
||||
// build the actual InvokableNode, etc.
|
||||
n, err := nb.Build(ctx, s, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda
|
||||
return toNode(s, n), nil
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeLambda:
|
||||
if s.Lambda == nil {
|
||||
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
|
||||
}
|
||||
|
||||
return &Node{Lambda: s.Lambda}, nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toNode(s, subWorkflow), nil
|
||||
default:
|
||||
panic(fmt.Sprintf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type))
|
||||
}
|
||||
}
|
||||
|
||||
func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) {
|
||||
var opts []WorkflowOption
|
||||
opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID))
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, WithParentRequireCheckpoint())
|
||||
}
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subworkflow.SubWorkflow{
|
||||
Runner: wf.Runner,
|
||||
}, nil
|
||||
}
|
||||
@@ -33,6 +33,8 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
@@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct {
|
||||
maxRetry int64
|
||||
errProcessType vo.ErrorProcessType
|
||||
dataOnErr func(ctx context.Context) map[string]any
|
||||
callbackEnabled bool
|
||||
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
streamPreProcessors []func(ctx context.Context,
|
||||
@@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct {
|
||||
init []func(context.Context) (context.Context, error)
|
||||
i compose.Invoke[map[string]any, map[string]any, O]
|
||||
s compose.Stream[map[string]any, map[string]any, O]
|
||||
c compose.Collect[map[string]any, map[string]any, O]
|
||||
t compose.Transform[map[string]any, map[string]any, O]
|
||||
}
|
||||
|
||||
func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
func newNodeRunConfig[O any](ns *schema2.NodeSchema,
|
||||
i compose.Invoke[map[string]any, map[string]any, O],
|
||||
s compose.Stream[map[string]any, map[string]any, O],
|
||||
c compose.Collect[map[string]any, map[string]any, O],
|
||||
t compose.Transform[map[string]any, map[string]any, O],
|
||||
opts *newNodeOptions) *nodeRunConfig[O] {
|
||||
meta := entity.NodeMetaByNodeType(ns.Type)
|
||||
@@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
keyFinishedMarkerTrimmer(),
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
preProcessors = append(preProcessors, ns.inputValueFiller())
|
||||
preProcessors = append(preProcessors, inputValueFiller(ns))
|
||||
}
|
||||
|
||||
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
if meta.PostFillNil {
|
||||
postProcessors = append(postProcessors, ns.outputValueFiller())
|
||||
postProcessors = append(postProcessors, outputValueFiller(ns))
|
||||
}
|
||||
|
||||
streamPreProcessors := []func(ctx context.Context,
|
||||
@@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
},
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
|
||||
streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns))
|
||||
}
|
||||
|
||||
if meta.UseCtxCache {
|
||||
opts.init = append([]func(ctx context.Context) (context.Context, error){
|
||||
func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
},
|
||||
}, opts.init...)
|
||||
}
|
||||
|
||||
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
|
||||
@@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
maxRetry: maxRetry,
|
||||
errProcessType: errProcessType,
|
||||
dataOnErr: dataOnErr,
|
||||
callbackEnabled: meta.CallbackEnabled,
|
||||
preProcessors: preProcessors,
|
||||
postProcessors: postProcessors,
|
||||
streamPreProcessors: streamPreProcessors,
|
||||
@@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
init: opts.init,
|
||||
i: i,
|
||||
s: s,
|
||||
c: c,
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func newNodeRunConfigWOOpt(ns *NodeSchema,
|
||||
func newNodeRunConfigWOOpt(ns *schema2.NodeSchema,
|
||||
i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
s compose.StreamWOOpt[map[string]any, map[string]any],
|
||||
c compose.CollectWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any],
|
||||
opts *newNodeOptions) *nodeRunConfig[any] {
|
||||
var (
|
||||
iWO compose.Invoke[map[string]any, map[string]any, any]
|
||||
sWO compose.Stream[map[string]any, map[string]any, any]
|
||||
cWO compose.Collect[map[string]any, map[string]any, any]
|
||||
tWO compose.Transform[map[string]any, map[string]any, any]
|
||||
)
|
||||
|
||||
@@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema,
|
||||
}
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) {
|
||||
return c(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
|
||||
return t(ctx, input)
|
||||
}
|
||||
}
|
||||
|
||||
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts)
|
||||
return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts)
|
||||
}
|
||||
|
||||
type newNodeOptions struct {
|
||||
@@ -180,57 +199,100 @@ type newNodeOptions struct {
|
||||
init []func(context.Context) (context.Context, error)
|
||||
}
|
||||
|
||||
type newNodeOption func(*newNodeOptions)
|
||||
func toNode(ns *schema2.NodeSchema, r any) *Node {
|
||||
iWOpt, _ := r.(nodes.InvokableNodeWOpt)
|
||||
sWOpt, _ := r.(nodes.StreamableNodeWOpt)
|
||||
cWOpt, _ := r.(nodes.CollectableNodeWOpt)
|
||||
tWOpt, _ := r.(nodes.TransformableNodeWOpt)
|
||||
iWOOpt, _ := r.(nodes.InvokableNode)
|
||||
sWOOpt, _ := r.(nodes.StreamableNode)
|
||||
cWOOpt, _ := r.(nodes.CollectableNode)
|
||||
tWOOpt, _ := r.(nodes.TransformableNode)
|
||||
|
||||
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackInputConverter = f
|
||||
var wOpt, wOOpt bool
|
||||
if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil {
|
||||
wOpt = true
|
||||
}
|
||||
}
|
||||
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackOutputConverter = f
|
||||
if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil {
|
||||
wOOpt = true
|
||||
}
|
||||
|
||||
if wOpt && wOOpt {
|
||||
panic("a node's different streaming methods needs to be consistent: " +
|
||||
"they should ALL have NodeOption or None should have them")
|
||||
}
|
||||
|
||||
if !wOpt && !wOOpt {
|
||||
panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode")
|
||||
}
|
||||
}
|
||||
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.init = append(opts.init, f)
|
||||
}
|
||||
}
|
||||
|
||||
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
ci, ok := r.(nodes.CallbackInputConverted)
|
||||
if ok {
|
||||
options.callbackInputConverter = ci.ToCallbackInput
|
||||
}
|
||||
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
co, ok := r.(nodes.CallbackOutputConverted)
|
||||
if ok {
|
||||
options.callbackOutputConverter = co.ToCallbackOutput
|
||||
}
|
||||
|
||||
return newNodeRunConfig(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
init, ok := r.(nodes.Initializer)
|
||||
if ok {
|
||||
options.init = append(options.init, init.Init)
|
||||
}
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
|
||||
}
|
||||
|
||||
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
if wOpt {
|
||||
var (
|
||||
i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption]
|
||||
s compose.Stream[map[string]any, map[string]any, nodes.NodeOption]
|
||||
c compose.Collect[map[string]any, map[string]any, nodes.NodeOption]
|
||||
t compose.Transform[map[string]any, map[string]any, nodes.NodeOption]
|
||||
)
|
||||
|
||||
if iWOpt != nil {
|
||||
i = iWOpt.Invoke
|
||||
}
|
||||
|
||||
if sWOpt != nil {
|
||||
s = sWOpt.Stream
|
||||
}
|
||||
|
||||
if cWOpt != nil {
|
||||
c = cWOpt.Collect
|
||||
}
|
||||
|
||||
if tWOpt != nil {
|
||||
t = tWOpt.Transform
|
||||
}
|
||||
|
||||
return newNodeRunConfig(ns, i, s, c, t, options).toNode()
|
||||
}
|
||||
return newNodeRunConfig(ns, i, s, nil, options).toNode()
|
||||
|
||||
var (
|
||||
i compose.InvokeWOOpt[map[string]any, map[string]any]
|
||||
s compose.StreamWOOpt[map[string]any, map[string]any]
|
||||
c compose.CollectWOOpt[map[string]any, map[string]any]
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any]
|
||||
)
|
||||
|
||||
if iWOOpt != nil {
|
||||
i = iWOOpt.Invoke
|
||||
}
|
||||
|
||||
if sWOOpt != nil {
|
||||
s = sWOOpt.Stream
|
||||
}
|
||||
|
||||
if cWOOpt != nil {
|
||||
c = cWOOpt.Collect
|
||||
}
|
||||
|
||||
if tWOOpt != nil {
|
||||
t = tWOOpt.Transform
|
||||
}
|
||||
|
||||
return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode()
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
@@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.
|
||||
func (nc *nodeRunConfig[O]) toNode() *Node {
|
||||
var opts []compose.LambdaOpt
|
||||
opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
|
||||
opts = append(opts, compose.WithLambdaCallbackEnable(true))
|
||||
|
||||
if nc.callbackEnabled {
|
||||
opts = append(opts, compose.WithLambdaCallbackEnable(true))
|
||||
}
|
||||
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
|
||||
@@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, nil
|
||||
}
|
||||
if r.callbackInputConverter != nil {
|
||||
convertedInput, err := r.callbackInputConverter(ctx, input)
|
||||
if err != nil {
|
||||
@@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont
|
||||
|
||||
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
|
||||
context.Context, *schema.StreamReader[map[string]any], error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, input, nil
|
||||
}
|
||||
|
||||
if r.callbackInputConverter != nil {
|
||||
copied := input.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
|
||||
@@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
|
||||
output["isSuccess"] = true
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
convertedOutput, err := r.callbackOutputConverter(ctx, output)
|
||||
if err != nil {
|
||||
@@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error
|
||||
|
||||
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
|
||||
*schema.StreamReader[map[string]any], error) {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
|
||||
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
|
||||
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return output, nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
copied := output.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
@@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe
|
||||
|
||||
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
|
||||
if r.interrupted {
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
}
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -653,22 +696,20 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
||||
msg := sErr.Msg()
|
||||
|
||||
switch r.errProcessType {
|
||||
case vo.ErrorProcessTypeDefault:
|
||||
case vo.ErrorProcessTypeReturnDefaultData:
|
||||
d := r.dataOnErr(ctx)
|
||||
d["errorBody"] = map[string]any{
|
||||
"errorMessage": msg,
|
||||
"errorCode": code,
|
||||
}
|
||||
d["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: d,
|
||||
RawOutput: d,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: d,
|
||||
RawOutput: d,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
return d, true
|
||||
case vo.ErrorProcessTypeExceptionBranch:
|
||||
s := make(map[string]any)
|
||||
@@ -677,20 +718,16 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
||||
"errorCode": code,
|
||||
}
|
||||
s["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: s,
|
||||
RawOutput: s,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: s,
|
||||
RawOutput: s,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
return s, true
|
||||
default:
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, sErr)
|
||||
}
|
||||
_ = callbacks.OnError(ctx, sErr)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,580 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
)
|
||||
|
||||
type NodeSchema struct {
|
||||
Key vo.NodeKey `json:"key"`
|
||||
Name string `json:"name"`
|
||||
|
||||
Type entity.NodeType `json:"type"`
|
||||
|
||||
// Configs are node specific configurations with pre-defined config key and config value.
|
||||
// Will not participate in request-time field mapping, nor as node's static values.
|
||||
// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
|
||||
Configs any `json:"configs,omitempty"`
|
||||
|
||||
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
|
||||
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
|
||||
|
||||
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
|
||||
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop
|
||||
|
||||
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
|
||||
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
|
||||
|
||||
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
|
||||
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
|
||||
|
||||
Lambda *compose.Lambda // not serializable, used for internal test.
|
||||
}
|
||||
|
||||
type ExceptionConfig struct {
|
||||
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
|
||||
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
|
||||
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
|
||||
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
|
||||
}
|
||||
|
||||
type StreamConfig struct {
|
||||
// whether this node has the ability to produce genuine streaming output.
|
||||
// not include nodes that only passes stream down as they receives them
|
||||
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
|
||||
// whether this node prioritize streaming input over none-streaming input.
|
||||
// not include nodes that can accept both and does not have preference.
|
||||
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
|
||||
sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
|
||||
if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeLambda:
|
||||
if s.Lambda == nil {
|
||||
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
|
||||
}
|
||||
|
||||
return &Node{Lambda: s.Lambda}, nil
|
||||
case entity.NodeTypeLLM:
|
||||
conf, err := s.ToLLMConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := llm.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeSelector:
|
||||
conf := s.ToSelectorConfig()
|
||||
|
||||
sl, err := selector.NewSelector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeBatch:
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
|
||||
}
|
||||
|
||||
conf, err := s.ToBatchConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := batch.NewBatch(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
|
||||
case entity.NodeTypeVariableAggregator:
|
||||
conf, err := s.ToVariableAggregatorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableaggregator.NewVariableAggregator(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, va.Invoke, va.Transform,
|
||||
withCallbackInputConverter(va.ToCallbackInput),
|
||||
withCallbackOutputConverter(va.ToCallbackOutput),
|
||||
withInit(va.Init)), nil
|
||||
case entity.NodeTypeTextProcessor:
|
||||
conf, err := s.ToTextProcessorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, err := textprocessor.NewTextProcessor(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, tp.Invoke), nil
|
||||
case entity.NodeTypeHTTPRequester:
|
||||
conf, err := s.ToHTTPRequesterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hr, err := httprequester.NewHTTPRequester(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeContinue:
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
case entity.NodeTypeBreak:
|
||||
b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, b.DoBreak), nil
|
||||
case entity.NodeTypeVariableAssigner:
|
||||
handler := variable.GetVariableHandler()
|
||||
|
||||
conf, err := s.ToVariableAssignerConfig(handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
va, err := variableassigner.NewVariableAssigner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeVariableAssignerWithinLoop:
|
||||
conf, err := s.ToVariableAssignerInLoopConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeLoop:
|
||||
conf, err := s.ToLoopConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := loop.NewLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf, err := s.ToQAConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qA, err := qa.NewQuestionAnswer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeInputReceiver:
|
||||
conf, err := s.ToInputReceiverConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputR, err := receiver.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeEntry:
|
||||
conf, err := s.ToEntryConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e, err := entry.NewEntry(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, e.Invoke), nil
|
||||
case entity.NodeTypeExit:
|
||||
terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
if terminalPlan == vo.ReturnVariables {
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
if in == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
}
|
||||
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeDatabaseCustomSQL:
|
||||
conf, err := s.ToDatabaseCustomSQLConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlER, err := database.NewCustomSQL(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, sqlER.Execute), nil
|
||||
case entity.NodeTypeDatabaseQuery:
|
||||
conf, err := s.ToDatabaseQueryConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, err := database.NewQuery(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseInsert:
|
||||
conf, err := s.ToDatabaseInsertConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insert, err := database.NewInsert(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseUpdate:
|
||||
conf, err := s.ToDatabaseUpdateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
update, err := database.NewUpdate(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseDelete:
|
||||
conf, err := s.ToDatabaseDeleteConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
del, err := database.NewDelete(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
|
||||
case entity.NodeTypeKnowledgeIndexer:
|
||||
conf, err := s.ToKnowledgeIndexerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, w.Store), nil
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
conf, err := s.ToKnowledgeRetrieveConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Retrieve), nil
|
||||
case entity.NodeTypeKnowledgeDeleter:
|
||||
conf, err := s.ToKnowledgeDeleterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Delete), nil
|
||||
case entity.NodeTypeCodeRunner:
|
||||
conf, err := s.ToCodeRunnerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := code.NewCodeRunner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
|
||||
case entity.NodeTypePlugin:
|
||||
conf, err := s.ToPluginConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := plugin.NewPlugin(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeCreateConversation:
|
||||
conf, err := s.ToCreateConversationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewCreateConversation(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Create), nil
|
||||
case entity.NodeTypeMessageList:
|
||||
conf, err := s.ToMessageListConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewMessageList(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.List), nil
|
||||
case entity.NodeTypeClearMessage:
|
||||
conf, err := s.ToClearMessageConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewClearMessage(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Clear), nil
|
||||
case entity.NodeTypeIntentDetector:
|
||||
conf, err := s.ToIntentDetectorConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := intentdetector.NewIntentDetector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := subworkflow.NewSubWorkflow(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
|
||||
case entity.NodeTypeJsonSerialization:
|
||||
conf, err := s.ToJsonSerializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := json.NewJsonSerializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, js.Invoke), nil
|
||||
case entity.NodeTypeJsonDeserialization:
|
||||
conf, err := s.ToJsonDeserializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jd, err := json.NewJsonDeserializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
|
||||
default:
|
||||
panic("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableUserQuery() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.Type != entity.NodeTypeEntry {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.OutputSources) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, source := range s.OutputSources {
|
||||
fieldPath := source.Path
|
||||
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableChatHistory() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
|
||||
case entity.NodeTypeLLM:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
case entity.NodeTypeIntentDetector:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsRefGlobalVariable() bool {
|
||||
for _, source := range s.InputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, source := range s.OutputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *NodeSchema) requireCheckpoint() bool {
|
||||
if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeLLM {
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil && fcParams.WorkflowFCParam != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeSubWorkflow {
|
||||
s.SubWorkflowSchema.Init()
|
||||
if s.SubWorkflowSchema.requireCheckPoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -32,8 +32,10 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@@ -46,9 +48,9 @@ type State struct {
|
||||
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
|
||||
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
|
||||
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
|
||||
@@ -71,8 +73,8 @@ func init() {
|
||||
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
|
||||
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
|
||||
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
|
||||
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
|
||||
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
|
||||
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
|
||||
@@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
|
||||
return s.GroupChoices[nodeKey]
|
||||
}
|
||||
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
choice, ok := choices[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if choice == -1 { // this group picks none of the elements
|
||||
return nodes.FieldNotStream, nil
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
sInfos, ok := s.SourceInfos[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
groupInfo, ok := sInfos[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if groupInfo.SubSources == nil {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
}
|
||||
|
||||
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
}
|
||||
|
||||
if subInfo.FieldType != nodes.FieldMaybeStream {
|
||||
if subInfo.FieldType != schema2.FieldMaybeStream {
|
||||
return subInfo.FieldType, nil
|
||||
}
|
||||
|
||||
@@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
|
||||
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
|
||||
}
|
||||
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
|
||||
result := make(map[string]nodes.FieldStreamType)
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
|
||||
result := make(map[string]schema2.FieldStreamType)
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return result, nil
|
||||
@@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
|
||||
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
|
||||
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
|
||||
ExecutedNodes: make(map[vo.NodeKey]bool),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
|
||||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
@@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePreHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
|
||||
@@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
}
|
||||
return in, nil
|
||||
})
|
||||
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
|
||||
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
|
||||
state.Inputs[s.Key] = in
|
||||
@@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
}
|
||||
|
||||
if len(handlers) > 0 || !stream {
|
||||
handlerForVars := s.statePreHandlerForVars()
|
||||
handlerForVars := statePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
handlers = append(handlers, handlerForVars)
|
||||
}
|
||||
@@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
|
||||
if s.Type == entity.NodeTypeVariableAggregator {
|
||||
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
state.SourceInfos[s.Key] = s.FullSources
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
|
||||
handlerForVars := s.streamStatePreHandlerForVars()
|
||||
handlerForVars := streamStatePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
streamHandlers = append(streamHandlers, handlerForVars)
|
||||
}
|
||||
@@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
|
||||
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
@@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
@@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// if it does not have source info, do not add this pre handler
|
||||
if s.Configs == nil {
|
||||
return nil
|
||||
@@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
|
||||
return nil
|
||||
case entity.NodeTypeExit:
|
||||
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
|
||||
if terminatePlan != vo.ReturnVariables {
|
||||
return nil
|
||||
}
|
||||
@@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
|
||||
}
|
||||
|
||||
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
sourceInfo := s.FullSources
|
||||
if len(sourceInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
|
||||
var (
|
||||
anyStream bool
|
||||
checker func(source *nodes.SourceInfo) bool
|
||||
checker func(source *schema2.SourceInfo) bool
|
||||
)
|
||||
checker = func(source *nodes.SourceInfo) bool {
|
||||
if source.FieldType != nodes.FieldNotStream {
|
||||
checker = func(source *schema2.SourceInfo) bool {
|
||||
if source.FieldType != schema2.FieldNotStream {
|
||||
return true
|
||||
}
|
||||
for _, subSource := range source.SubSources {
|
||||
@@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
resolved := map[string]resolvedStreamSource{}
|
||||
|
||||
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
if source.IsIntermediate {
|
||||
result = &resolvedStreamSource{
|
||||
intermediate: true,
|
||||
@@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
}
|
||||
|
||||
streamType := source.FieldType
|
||||
if streamType == nodes.FieldMaybeStream {
|
||||
if streamType == schema2.FieldMaybeStream {
|
||||
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if streamType == nodes.FieldNotStream {
|
||||
if streamType == schema2.FieldNotStream {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePostHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
|
||||
@@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.streamStatePostHandlerForVars()
|
||||
forVars := streamStatePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
streamHandlers = append(streamHandlers, forVars)
|
||||
}
|
||||
@@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.statePostHandlerForVars()
|
||||
forVars := statePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
handlers = append(handlers, forVars)
|
||||
}
|
||||
@@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
return compose.WithStatePostHandler(handler)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
|
||||
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variable,
|
||||
// use the state's variableHandler to get the variables and set them to the output
|
||||
var vars []*vo.FieldInfo
|
||||
@@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
|
||||
var vars []*vo.FieldInfo
|
||||
for _, output := range s.OutputSources {
|
||||
|
||||
@@ -21,19 +21,20 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
// SetFullSources calculates REAL input sources for a node.
|
||||
// GetFullSources calculates REAL input sources for a node.
|
||||
// It may be different from a NodeSchema's InputSources because of the following reasons:
|
||||
// 1. a inner node under composite node may refer to a field from a node in its parent workflow,
|
||||
// this is instead routed to and sourced from the inner workflow's start node.
|
||||
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
|
||||
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
|
||||
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
|
||||
fullSource := make(map[string]*nodes.SourceInfo)
|
||||
func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) (
|
||||
map[string]*schema.SourceInfo, error) {
|
||||
fullSource := make(map[string]*schema.SourceInfo)
|
||||
var fieldInfos []vo.FieldInfo
|
||||
for _, s := range dep.staticValues {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
@@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
||||
tInfo = tInfo.Properties[path[j]]
|
||||
}
|
||||
if current, ok := currentSource[path[j]]; !ok {
|
||||
currentSource[path[j]] = &nodes.SourceInfo{
|
||||
currentSource[path[j]] = &schema.SourceInfo{
|
||||
IsIntermediate: true,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
FieldType: schema.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
SubSources: make(map[string]*nodes.SourceInfo),
|
||||
SubSources: make(map[string]*schema.SourceInfo),
|
||||
}
|
||||
} else if !current.IsIntermediate {
|
||||
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
|
||||
return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
|
||||
}
|
||||
|
||||
currentSource = currentSource[path[j]].SubSources
|
||||
@@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
||||
|
||||
// static values or variables
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
currentSource[lastPath] = &schema.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
FieldType: schema.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
}
|
||||
continue
|
||||
@@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
||||
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
var (
|
||||
streamType nodes.FieldStreamType
|
||||
streamType schema.FieldStreamType
|
||||
err error
|
||||
)
|
||||
if len(fromNodeKey) > 0 {
|
||||
if fromNodeKey == compose.START {
|
||||
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
|
||||
streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
|
||||
} else {
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("node %s not found", fromNodeKey)
|
||||
fromNode := sc.GetNode(fromNodeKey)
|
||||
if fromNode == nil {
|
||||
return nil, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
currentSource[lastPath] = &schema.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: streamType,
|
||||
FromNodeKey: fromNodeKey,
|
||||
@@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
||||
}
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)["FullSources"] = fullSource
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
|
||||
if s.Type == entity.NodeTypeExit {
|
||||
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
|
||||
if len(path) == 1 && path[0] == "output" {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
|
||||
subSC := s.SubWorkflowSchema
|
||||
subExit := subSC.GetNode(entity.ExitNodeKey)
|
||||
subStreamType, err := subExit.IsStreamingField(path, nil)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
return subStreamType, nil
|
||||
} else if s.Type == entity.NodeTypeVariableAggregator {
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range s.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range s.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == nodes.FieldMaybeStream {
|
||||
return nodes.FieldMaybeStream, nil
|
||||
} else if subStreamType == nodes.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldMaybeStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type != entity.NodeTypeLLM {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) != 1 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
outputs := s.OutputTypes
|
||||
if len(outputs) != 1 && len(outputs) != 2 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
var outputKey string
|
||||
for key, output := range outputs {
|
||||
if output.Type != vo.DataTypeString {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if key != "reasoning_content" {
|
||||
if len(outputKey) > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
outputKey = key
|
||||
}
|
||||
}
|
||||
|
||||
field := path[0]
|
||||
if field == "reasoning_content" || field == outputKey {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
return fullSource, nil
|
||||
}
|
||||
|
||||
@@ -28,6 +28,9 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestBatch(t *testing.T) {
|
||||
@@ -52,7 +55,7 @@ func TestBatch(t *testing.T) {
|
||||
return in, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
lambdaNode1 := &schema.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
@@ -86,7 +89,7 @@ func TestBatch(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
lambdaNode2 := &compose2.NodeSchema{
|
||||
lambdaNode2 := &schema.NodeSchema{
|
||||
Key: "index",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
@@ -103,7 +106,7 @@ func TestBatch(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
lambdaNode3 := &schema.NodeSchema{
|
||||
Key: "consumer",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
@@ -135,23 +138,22 @@ func TestBatch(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "batch_node_key",
|
||||
Type: entity.NodeTypeBatch,
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "batch_node_key",
|
||||
Type: entity.NodeTypeBatch,
|
||||
Configs: &batch.Config{},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"array_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"array_1"},
|
||||
},
|
||||
},
|
||||
@@ -160,7 +162,7 @@ func TestBatch(t *testing.T) {
|
||||
Path: compose.FieldPath{"array_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"array_2"},
|
||||
},
|
||||
},
|
||||
@@ -214,11 +216,11 @@ func TestBatch(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -246,18 +248,18 @@ func TestBatch(t *testing.T) {
|
||||
return map[string]any{"success": true}, nil
|
||||
}
|
||||
|
||||
parentLambdaNode := &compose2.NodeSchema{
|
||||
parentLambdaNode := &schema.NodeSchema{
|
||||
Key: "parent_predecessor_1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(parentLambda),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
parentLambdaNode,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambdaNode1,
|
||||
lambdaNode2,
|
||||
lambdaNode3,
|
||||
@@ -267,7 +269,7 @@ func TestBatch(t *testing.T) {
|
||||
"index": "batch_node_key",
|
||||
"consumer": "batch_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entity.EntryNodeKey,
|
||||
ToNode: "parent_predecessor_1",
|
||||
|
||||
@@ -40,7 +40,11 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
@@ -108,22 +112,20 @@ func TestLLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "{{sys_prompt}}",
|
||||
"UserPrompt": "{{query}}",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "{{sys_prompt}}",
|
||||
UserPrompt: "{{query}}",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
@@ -132,7 +134,7 @@ func TestLLM(t *testing.T) {
|
||||
Path: compose.FieldPath{"sys_prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"sys_prompt"},
|
||||
},
|
||||
},
|
||||
@@ -141,7 +143,7 @@ func TestLLM(t *testing.T) {
|
||||
Path: compose.FieldPath{"query"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
@@ -162,11 +164,11 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -181,20 +183,20 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -228,27 +230,20 @@ func TestLLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?",
|
||||
"OutputFormat": llm.FormatJSON,
|
||||
"IgnoreException": true,
|
||||
"DefaultOutput": map[string]any{
|
||||
"country_name": "unknown",
|
||||
"area_size": int64(0),
|
||||
},
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "what's the largest country in the world and it's area size in square kilometers?",
|
||||
OutputFormat: llm.FormatJSON,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
@@ -264,11 +259,11 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -292,20 +287,20 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -337,22 +332,20 @@ func TestLLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "list the top 5 largest countries in the world",
|
||||
"OutputFormat": llm.FormatMarkdown,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "list the top 5 largest countries in the world",
|
||||
OutputFormat: llm.FormatMarkdown,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
@@ -363,11 +356,11 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -382,20 +375,20 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -456,22 +449,20 @@ func TestLLM(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
openaiNode := &compose2.NodeSchema{
|
||||
openaiNode := &schema2.NodeSchema{
|
||||
Key: "openai_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "plan a 10 day family visit to China.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "plan a 10 day family visit to China.",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
@@ -482,14 +473,14 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
deepseekNode := &compose2.NodeSchema{
|
||||
deepseekNode := &schema2.NodeSchema{
|
||||
Key: "deepseek_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
@@ -503,12 +494,11 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
emitterNode := &compose2.NodeSchema{
|
||||
emitterNode := &schema2.NodeSchema{
|
||||
Key: "emitter_node_key",
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Configs: map[string]any{
|
||||
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
|
||||
"Mode": nodes.Streaming,
|
||||
Configs: &emitter.Config{
|
||||
Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -542,7 +532,7 @@ func TestLLM(t *testing.T) {
|
||||
Path: compose.FieldPath{"inputObj"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"inputObj"},
|
||||
},
|
||||
},
|
||||
@@ -551,7 +541,7 @@ func TestLLM(t *testing.T) {
|
||||
Path: compose.FieldPath{"input2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"input2"},
|
||||
},
|
||||
},
|
||||
@@ -559,11 +549,11 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.UseAnswerContent,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.UseAnswerContent,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -596,17 +586,17 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
openaiNode,
|
||||
deepseekNode,
|
||||
emitterNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: openaiNode.Key,
|
||||
},
|
||||
{
|
||||
@@ -614,7 +604,7 @@ func TestLLM(t *testing.T) {
|
||||
ToNode: emitterNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: deepseekNode.Key,
|
||||
},
|
||||
{
|
||||
@@ -623,7 +613,7 @@ func TestLLM(t *testing.T) {
|
||||
},
|
||||
{
|
||||
FromNode: emitterNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -26,15 +26,20 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
|
||||
_continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
t.Run("by iteration", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->continue] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
@@ -54,31 +59,30 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
continueNode := &compose2.NodeSchema{
|
||||
Key: "continueNode",
|
||||
Type: entity.NodeTypeContinue,
|
||||
continueNode := &schema.NodeSchema{
|
||||
Key: "continueNode",
|
||||
Type: entity.NodeTypeContinue,
|
||||
Configs: &_continue.Config{},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByIteration,
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.ByIteration,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{loop.Count},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
},
|
||||
},
|
||||
@@ -97,11 +101,11 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -116,11 +120,11 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
continueNode,
|
||||
},
|
||||
@@ -128,7 +132,7 @@ func TestLoop(t *testing.T) {
|
||||
"innerNode": "loop_node_key",
|
||||
"continueNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
@@ -142,12 +146,12 @@ func TestLoop(t *testing.T) {
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -168,7 +172,7 @@ func TestLoop(t *testing.T) {
|
||||
|
||||
t.Run("infinite", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->break] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
@@ -188,24 +192,23 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
breakNode := &compose2.NodeSchema{
|
||||
Key: "breakNode",
|
||||
Type: entity.NodeTypeBreak,
|
||||
breakNode := &schema.NodeSchema{
|
||||
Key: "breakNode",
|
||||
Type: entity.NodeTypeBreak,
|
||||
Configs: &_break.Config{},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.Infinite,
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.Infinite,
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -220,11 +223,11 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -239,11 +242,11 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
breakNode,
|
||||
},
|
||||
@@ -251,7 +254,7 @@ func TestLoop(t *testing.T) {
|
||||
"innerNode": "loop_node_key",
|
||||
"breakNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
@@ -265,12 +268,12 @@ func TestLoop(t *testing.T) {
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -290,14 +293,14 @@ func TestLoop(t *testing.T) {
|
||||
t.Run("by array", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->variable_assign] -> end
|
||||
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
item1 := in["item1"].(string)
|
||||
item2 := in["item2"].(string)
|
||||
count := in["count"].(int)
|
||||
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
|
||||
return map[string]any{"total": count + len(item1) + len(item2)}, nil
|
||||
}),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -330,16 +333,18 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
assigner := &compose2.NodeSchema{
|
||||
assigner := &schema.NodeSchema{
|
||||
Key: "assigner",
|
||||
Type: entity.NodeTypeVariableAssignerWithinLoop,
|
||||
Configs: []*variableassigner.Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
Configs: &variableassigner.InLoopConfig{
|
||||
Pairs: []*variableassigner.Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"total"},
|
||||
},
|
||||
Right: compose.FieldPath{"total"},
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
@@ -355,19 +360,17 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -382,12 +385,13 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByArray,
|
||||
"IntermediateVars": map[string]*vo.TypeInfo{
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.ByArray,
|
||||
InputArrays: []string{"items1", "items2"},
|
||||
IntermediateVars: map[string]*vo.TypeInfo{
|
||||
"count": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
@@ -408,7 +412,7 @@ func TestLoop(t *testing.T) {
|
||||
Path: compose.FieldPath{"items1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"items1"},
|
||||
},
|
||||
},
|
||||
@@ -417,7 +421,7 @@ func TestLoop(t *testing.T) {
|
||||
Path: compose.FieldPath{"items2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"items2"},
|
||||
},
|
||||
},
|
||||
@@ -442,11 +446,11 @@ func TestLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
assigner,
|
||||
},
|
||||
@@ -454,7 +458,7 @@ func TestLoop(t *testing.T) {
|
||||
"innerNode": "loop_node_key",
|
||||
"assigner": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
@@ -468,12 +472,12 @@ func TestLoop(t *testing.T) {
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -43,8 +43,11 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
|
||||
@@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
mockey.Mock(workflow.GetRepository).Return(repo).Build()
|
||||
|
||||
t.Run("answer directly, no structured output", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerDirectly,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
@@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.FixedChoices,
|
||||
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
|
||||
"LLMParams": &model.LLMParams{},
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerByChoices,
|
||||
ChoiceType: qa.FixedChoices,
|
||||
FixedChoices: []string{"{{choice1}}", "{{choice2}}"},
|
||||
LLMParams: &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
@@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
Path: compose.FieldPath{"choice1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choice1"},
|
||||
},
|
||||
},
|
||||
@@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
Path: compose.FieldPath{"choice2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choice2"},
|
||||
},
|
||||
},
|
||||
@@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
lambda := &schema2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
@@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_1"),
|
||||
},
|
||||
{
|
||||
@@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
branches, err := schema2.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
@@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("answer with dynamic choices", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.DynamicChoices,
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerByChoices,
|
||||
ChoiceType: qa.DynamicChoices,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
@@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
Path: compose.FieldPath{qa.DynamicChoicesKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choices"},
|
||||
},
|
||||
},
|
||||
@@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
lambda := &schema2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
@@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
@@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
branches, err := schema2.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
@@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
"ExtractFromAnswer": true,
|
||||
"AdditionalSystemPromptTpl": "{{prompt}}",
|
||||
"MaxAnswerCount": 2,
|
||||
"LLMParams": &model.LLMParams{},
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerDirectly,
|
||||
ExtractFromAnswer: true,
|
||||
AdditionalSystemPromptTpl: "{{prompt}}",
|
||||
MaxAnswerCount: 2,
|
||||
LLMParams: &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
@@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
Path: compose.FieldPath{"prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"prompt"},
|
||||
},
|
||||
},
|
||||
@@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -26,26 +26,28 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestAddSelector(t *testing.T) {
|
||||
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
lambdaNode1 := &schema.NodeSchema{
|
||||
Key: "lambda1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
@@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
LambdaNode2 := &compose2.NodeSchema{
|
||||
LambdaNode2 := &schema.NodeSchema{
|
||||
Key: "lambda2",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
@@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
lambdaNode3 := &schema.NodeSchema{
|
||||
Key: "lambda3",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "selector",
|
||||
Type: entity.NodeTypeSelector,
|
||||
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
|
||||
Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{
|
||||
{
|
||||
Single: ptr.Of(selector.OperatorEqual),
|
||||
},
|
||||
@@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) {
|
||||
Path: compose.FieldPath{"0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key1"},
|
||||
},
|
||||
},
|
||||
@@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) {
|
||||
Path: compose.FieldPath{"1", "0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key2"},
|
||||
},
|
||||
},
|
||||
@@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) {
|
||||
Path: compose.FieldPath{"1", "0", selector.RightKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key3"},
|
||||
},
|
||||
},
|
||||
@@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) {
|
||||
Path: compose.FieldPath{"1", "1", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key4"},
|
||||
},
|
||||
},
|
||||
@@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
lambdaNode1,
|
||||
LambdaNode2,
|
||||
lambdaNode3,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "selector",
|
||||
},
|
||||
{
|
||||
@@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) {
|
||||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("default"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda1",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda2",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda3",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
branches, err := schema.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVariableAggregator(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "va",
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Configs: map[string]any{
|
||||
"MergeStrategy": variableaggregator.FirstNotNullValue,
|
||||
"GroupToLen": map[string]int{
|
||||
Configs: &variableaggregator.Config{
|
||||
MergeStrategy: variableaggregator.FirstNotNullValue,
|
||||
GroupLen: map[string]int{
|
||||
"Group1": 1,
|
||||
"Group2": 1,
|
||||
},
|
||||
"GroupOrder": []string{
|
||||
GroupOrder: []string{
|
||||
"Group1",
|
||||
"Group2",
|
||||
},
|
||||
@@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) {
|
||||
Path: compose.FieldPath{"Group1", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
@@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) {
|
||||
Path: compose.FieldPath{"Group2", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Int1"},
|
||||
},
|
||||
},
|
||||
@@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "va",
|
||||
},
|
||||
{
|
||||
FromNode: "va",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) {
|
||||
|
||||
func TestTextProcessor(t *testing.T) {
|
||||
t.Run("split", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.SplitText,
|
||||
"Separators": []string{"|"},
|
||||
Configs: &textprocessor.Config{
|
||||
Type: textprocessor.SplitText,
|
||||
Separators: []string{"|"},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str"},
|
||||
},
|
||||
},
|
||||
@@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
entryN,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("concat", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
@@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.ConcatText,
|
||||
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
|
||||
"ConcatChar": "\t",
|
||||
Configs: &textprocessor.Config{
|
||||
Type: textprocessor.ConcatText,
|
||||
Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
|
||||
ConcatChar: "\t",
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
@@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) {
|
||||
Path: compose.FieldPath{"String2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str2"},
|
||||
},
|
||||
},
|
||||
@@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) {
|
||||
Path: compose.FieldPath{"String3"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str3"},
|
||||
},
|
||||
},
|
||||
@@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
entryN,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,652 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
einomodel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
|
||||
return &entry.Config{
|
||||
DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
|
||||
OutputTypes: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
|
||||
llmConf := &llm.Config{
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs),
|
||||
OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
|
||||
InputFields: s.InputTypes,
|
||||
OutputFields: s.OutputTypes,
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
|
||||
if llmParams == nil {
|
||||
return nil, fmt.Errorf("llm node llmParams is required")
|
||||
}
|
||||
var (
|
||||
err error
|
||||
chatModel, fallbackM einomodel.BaseChatModel
|
||||
info, fallbackI *modelmgr.Model
|
||||
modelWithInfo llm.ModelWithInfo
|
||||
)
|
||||
|
||||
chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaConfigs := s.ExceptionConfigs
|
||||
if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
|
||||
backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackM == nil {
|
||||
modelWithInfo = llm.NewModel(chatModel, info)
|
||||
} else {
|
||||
modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
|
||||
}
|
||||
llmConf.ChatModel = modelWithInfo
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
|
||||
workflowToolConfig := vo.WorkflowToolConfig{}
|
||||
if wf.FCSetting != nil {
|
||||
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
|
||||
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, workflowToolConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.Tools = append(llmConf.Tools, wfTool)
|
||||
if wfTool.TerminatePlan() == vo.UseAnswerContent {
|
||||
if llmConf.ToolsReturnDirectly == nil {
|
||||
llmConf.ToolsReturnDirectly = make(map[string]bool)
|
||||
}
|
||||
toolInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.ToolsReturnDirectly[toolInfo.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.PluginFCParam != nil {
|
||||
pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
|
||||
var (
|
||||
requestParameters []*workflow3.APIParameter
|
||||
responseParameters []*workflow3.APIParameter
|
||||
)
|
||||
if p.FCSetting != nil {
|
||||
requestParameters = p.FCSetting.RequestParameters
|
||||
responseParameters = p.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
if req, ok := pluginToolsInvokableReq[pid]; ok {
|
||||
req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
}
|
||||
} else {
|
||||
pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
|
||||
PluginEntity: crossplugin.Entity{
|
||||
PluginID: pid,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
|
||||
toolID: {
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
},
|
||||
},
|
||||
IsDraft: p.IsDraft,
|
||||
}
|
||||
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
|
||||
}
|
||||
}
|
||||
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
|
||||
for _, req := range pluginToolsInvokableReq {
|
||||
toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range toolMap {
|
||||
inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
|
||||
}
|
||||
}
|
||||
if len(inInvokableTools) > 0 {
|
||||
llmConf.Tools = inInvokableTools
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
|
||||
kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel()
|
||||
if kwChatModel == nil {
|
||||
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
cfg := &llm.KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := totRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.RetrievalStrategy = &llm.RetrievalStrategy{
|
||||
RetrievalStrategy: &crossknowledge.RetrievalStrategy{
|
||||
TopK: ptr.Of(setting.TopK),
|
||||
MinScore: ptr.Of(setting.MinScore),
|
||||
SearchType: searchType,
|
||||
EnableNL2SQL: setting.UseNL2SQL,
|
||||
EnableQueryRewrite: setting.UseRewrite,
|
||||
EnableRerank: setting.UseRerank,
|
||||
},
|
||||
NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode),
|
||||
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
|
||||
}
|
||||
|
||||
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
|
||||
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
|
||||
kid, err := strconv.ParseInt(kw.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
|
||||
llmConf.KnowledgeRecallConfig = cfg
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return llmConf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
|
||||
return &selector.Config{
|
||||
Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
|
||||
conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey})
|
||||
if ok {
|
||||
out = append(out, selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
out = append(out, selector.Operants{Left: left})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := make([]*selector.Operants, 0)
|
||||
for j := range oneConf.Multi.Clauses {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
|
||||
}
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey})
|
||||
if ok {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left})
|
||||
}
|
||||
}
|
||||
out = append(out, selector.Operants{Multi: multiClause})
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
|
||||
conf := &batch.Config{
|
||||
BatchNodeKey: s.Key,
|
||||
InnerWorkflow: inner,
|
||||
Outputs: s.OutputSources,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
|
||||
return &variableaggregator.Config{
|
||||
MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy),
|
||||
GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
NodeKey: s.Key,
|
||||
InputSources: s.InputSources,
|
||||
GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
|
||||
converted = make(map[string]map[int]any)
|
||||
|
||||
for k, value := range in {
|
||||
m, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
panic(errors.New("value is not a map[string]any"))
|
||||
}
|
||||
converted[k] = make(map[int]any, len(m))
|
||||
for i, sv := range m {
|
||||
index, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err))
|
||||
}
|
||||
converted[k][index] = sv
|
||||
}
|
||||
}
|
||||
|
||||
return converted
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
|
||||
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = safego.NewPanicErr(r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
return s.variableAggregatorInputConverter(input), nil
|
||||
}
|
||||
return schema.StreamReaderWithConvert(in, converter)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
|
||||
return &textprocessor.Config{
|
||||
Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type),
|
||||
Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)),
|
||||
ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)),
|
||||
Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
|
||||
return &json.SerializationConfig{
|
||||
InputTypes: s.InputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
|
||||
return &json.DeserializationConfig{
|
||||
OutputFields: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
|
||||
return &httprequester.Config{
|
||||
URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
|
||||
AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
|
||||
BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
|
||||
Method: mustGetKey[string]("Method", s.Configs),
|
||||
Timeout: mustGetKey[time.Duration]("Timeout", s.Configs),
|
||||
RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs),
|
||||
MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
Handler: handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
|
||||
conf := &loop.Config{
|
||||
LoopNodeKey: s.Key,
|
||||
LoopType: mustGetKey[loop.Type]("LoopType", s.Configs),
|
||||
IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
|
||||
Outputs: s.OutputSources,
|
||||
Inner: inner,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
|
||||
conf := &qa.Config{
|
||||
QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs),
|
||||
AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
|
||||
ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
|
||||
FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs),
|
||||
ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
|
||||
MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs),
|
||||
AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
|
||||
OutputFields: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
if llmParams != nil {
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf.Model = m
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
|
||||
return &receiver.Config{
|
||||
OutputTypes: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
|
||||
conf := &emitter.Config{
|
||||
Template: getKeyOrZero[string]("Template", s.Configs),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
|
||||
return &database.CustomSQLConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
|
||||
return &database.QueryConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs),
|
||||
OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
|
||||
ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Limit: mustGetKey[int64]("Limit", s.Configs),
|
||||
Op: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
|
||||
return &database.InsertConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Inserter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
|
||||
return &database.DeleteConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Deleter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
|
||||
return &database.UpdateConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Updater: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
|
||||
return &knowledge.IndexerConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
|
||||
ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
|
||||
KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
|
||||
return &knowledge.RetrieveConfig{
|
||||
KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
|
||||
RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
|
||||
Retriever: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
|
||||
return &knowledge.DeleterConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
|
||||
return &plugin.Config{
|
||||
PluginID: mustGetKey[int64]("PluginID", s.Configs),
|
||||
ToolID: mustGetKey[int64]("ToolID", s.Configs),
|
||||
PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
|
||||
PluginService: crossplugin.GetPluginService(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
|
||||
return &code.Config{
|
||||
Code: mustGetKey[string]("Code", s.Configs),
|
||||
Language: mustGetKey[coderunner.Language]("Language", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Runner: crosscode.GetCodeRunner(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
|
||||
return &conversation.CreateConversationConfig{
|
||||
Creator: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
|
||||
return &conversation.ClearMessageConfig{
|
||||
Clearer: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
|
||||
return &conversation.MessageListConfig{
|
||||
Lister: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
|
||||
cfg := &intentdetector.Config{
|
||||
Intents: mustGetKey[[]string]("Intents", s.Configs),
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.ChatModel = m
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
|
||||
var opts []WorkflowOption
|
||||
opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, WithParentRequireCheckpoint())
|
||||
}
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subworkflow.Config{
|
||||
Runner: wf.Runner,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return crossknowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return crossknowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return crossknowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func getKeyOrZero[T any](key string, cfg any) T {
|
||||
var zero T
|
||||
if cfg == nil {
|
||||
return zero
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if len(m) == 0 {
|
||||
return zero
|
||||
}
|
||||
|
||||
if v, ok := m[key]; ok {
|
||||
return v.(T)
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
func mustGetKey[T any](key string, cfg any) T {
|
||||
if cfg == nil {
|
||||
panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if _, ok := m[key]; !ok {
|
||||
panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
|
||||
}
|
||||
|
||||
v, ok := m[key].(T)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key])))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetConfigKV(key string, value any) {
|
||||
if s.Configs == nil {
|
||||
s.Configs = make(map[string]any)
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)[key] = value
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
|
||||
if s.InputTypes == nil {
|
||||
s.InputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.InputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
|
||||
s.InputSources = append(s.InputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
if s.OutputTypes == nil {
|
||||
s.OutputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.OutputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
|
||||
if s.Type != entity.NodeTypeSubWorkflow {
|
||||
return 0, "", false
|
||||
}
|
||||
|
||||
return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true
|
||||
}
|
||||
@@ -29,6 +29,8 @@ import (
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
@@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
|
||||
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
|
||||
*workflow
|
||||
hierarchy map[vo.NodeKey]vo.NodeKey
|
||||
connections []*Connection
|
||||
connections []*schema.Connection
|
||||
requireCheckpoint bool
|
||||
entry *compose.WorkflowNode
|
||||
inner bool
|
||||
@@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
|
||||
input map[string]*vo.TypeInfo
|
||||
output map[string]*vo.TypeInfo
|
||||
terminatePlan vo.TerminatePlan
|
||||
schema *WorkflowSchema
|
||||
schema *schema.WorkflowSchema
|
||||
}
|
||||
|
||||
type workflowOptions struct {
|
||||
@@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
sc.Init()
|
||||
|
||||
wf := &Workflow{
|
||||
@@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
schema: sc,
|
||||
}
|
||||
|
||||
wf.streamRun = sc.requireStreaming
|
||||
wf.requireCheckpoint = sc.requireCheckPoint
|
||||
wf.streamRun = sc.RequireStreaming()
|
||||
wf.requireCheckpoint = sc.RequireCheckpoint()
|
||||
|
||||
wfOpts := &workflowOptions{}
|
||||
for _, opt := range opts {
|
||||
@@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
processedNodeKey[child.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// add all nodes other than composite nodes and their children
|
||||
for _, ns := range sc.Nodes {
|
||||
if _, ok := processedNodeKey[ns.Key]; !ok {
|
||||
@@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
}
|
||||
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
|
||||
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
|
||||
}
|
||||
|
||||
fanInConfigs := sc.fanInMergeConfigs()
|
||||
fanInConfigs := sc.FanInMergeConfigs()
|
||||
if len(fanInConfigs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
|
||||
}
|
||||
@@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
|
||||
carryOvers map[vo.NodeKey][]*compose.FieldMapping
|
||||
}
|
||||
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
|
||||
_, err := w.addNodeInternal(ctx, ns, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
|
||||
inner, err := w.getInnerWorkflow(ctx, cNode)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
return w.addNodeInternal(ctx, cNode, nil)
|
||||
}
|
||||
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
key := ns.Key
|
||||
var deps *dependencyInfo
|
||||
|
||||
@@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
innerWorkflow = inner.inner
|
||||
}
|
||||
|
||||
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
|
||||
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
var opts []compose.GraphAddNodeOpt
|
||||
opts = append(opts, compose.WithNodeName(string(ns.Key)))
|
||||
|
||||
preHandler := ns.StatePreHandler(w.streamRun)
|
||||
preHandler := statePreHandler(ns, w.streamRun)
|
||||
if preHandler != nil {
|
||||
opts = append(opts, preHandler)
|
||||
}
|
||||
|
||||
postHandler := ns.StatePostHandler(w.streamRun)
|
||||
postHandler := statePostHandler(ns, w.streamRun)
|
||||
if postHandler != nil {
|
||||
opts = append(opts, postHandler)
|
||||
}
|
||||
@@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
w.entry = wNode
|
||||
}
|
||||
|
||||
outputPortCount, hasExceptionPort := ns.OutputPortCount()
|
||||
if outputPortCount > 1 || hasExceptionPort {
|
||||
bMapping, err := w.resolveBranch(key, outputPortCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := w.schema.GetBranch(ns.Key)
|
||||
if b != nil {
|
||||
if b.OnlyException() {
|
||||
_ = w.AddBranch(string(key), b.GetExceptionBranch())
|
||||
} else {
|
||||
bb, ok := ns.Configs.(schema.BranchBuilder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
|
||||
}
|
||||
|
||||
branch, err := ns.GetBranch(bMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
br, err := b.GetFullBranch(ctx, bb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = w.AddBranch(string(key), branch)
|
||||
_ = w.AddBranch(string(key), br)
|
||||
}
|
||||
}
|
||||
|
||||
return deps.inputsForParent, nil
|
||||
@@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
|
||||
return w.workflow.Compile(ctx, opts...)
|
||||
}
|
||||
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*NodeSchema)
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
|
||||
for _, n := range cNode.Children {
|
||||
innerNodes[n.Key] = n
|
||||
}
|
||||
|
||||
// trim the connections, only keep the connections that are related to the inner workflow
|
||||
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
|
||||
innerConnections := make([]*Connection, 0)
|
||||
innerConnections := make([]*schema.Connection, 0)
|
||||
for i := range w.schema.Connections {
|
||||
conn := w.schema.Connections[i]
|
||||
if _, ok := innerNodes[conn.FromNode]; ok {
|
||||
@@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
|
||||
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
|
||||
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
|
||||
// we will not try to drill, instead, just take value using a.b.c.
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
|
||||
for nKey, fms := range d.inputs {
|
||||
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
|
||||
continue
|
||||
@@ -638,55 +643,6 @@ type variableInfo struct {
|
||||
toPath compose.FieldPath
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
|
||||
m := make([]map[string]bool, portCount)
|
||||
var exception map[string]bool
|
||||
|
||||
for _, conn := range w.connections {
|
||||
if conn.FromNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if conn.FromPort == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if *conn.FromPort == "default" { // default condition
|
||||
if m[portCount-1] == nil {
|
||||
m[portCount-1] = make(map[string]bool)
|
||||
}
|
||||
m[portCount-1][string(conn.ToNode)] = true
|
||||
} else if *conn.FromPort == "branch_error" {
|
||||
if exception == nil {
|
||||
exception = make(map[string]bool)
|
||||
}
|
||||
exception[string(conn.ToNode)] = true
|
||||
} else {
|
||||
if !strings.HasPrefix(*conn.FromPort, "branch_") {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
|
||||
}
|
||||
|
||||
index := (*conn.FromPort)[7:]
|
||||
i, err := strconv.Atoi(index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
|
||||
}
|
||||
if i < 0 || i >= portCount {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
|
||||
}
|
||||
if m[i] == nil {
|
||||
m[i] = make(map[string]bool)
|
||||
}
|
||||
m[i][string(conn.ToNode)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return &BranchMapping{
|
||||
Normal: m,
|
||||
Exception: exception,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
|
||||
var (
|
||||
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
@@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
@@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
continue
|
||||
}
|
||||
|
||||
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
|
||||
if inputFull == nil {
|
||||
@@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
|
||||
}
|
||||
}
|
||||
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
firstNodesInInnerWorkflow := true
|
||||
for _, conn := range connMap {
|
||||
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
|
||||
firstNodesInInnerWorkflow = false
|
||||
break
|
||||
@@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
continue
|
||||
}
|
||||
|
||||
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
fromNodeKey = compose.START
|
||||
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
||||
variableInfos []*variableInfo
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
||||
swp.Source.Ref.FromPath, swp.Path)
|
||||
}
|
||||
|
||||
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
} else { // indirect dependency
|
||||
|
||||
@@ -23,9 +23,10 @@ import (
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
|
||||
func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
|
||||
*Workflow, error) {
|
||||
sc.Init()
|
||||
ns := sc.GetNode(nodeKey)
|
||||
@@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod
|
||||
schema: sc,
|
||||
fromNode: true,
|
||||
streamRun: false, // single node run can only invoke
|
||||
requireCheckpoint: sc.requireCheckPoint,
|
||||
requireCheckpoint: sc.RequireCheckpoint(),
|
||||
input: ns.InputTypes,
|
||||
output: ns.OutputTypes,
|
||||
terminatePlan: vo.ReturnVariables,
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@@ -42,7 +43,7 @@ type WorkflowRunner struct {
|
||||
basic *entity.WorkflowBasic
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
schema *WorkflowSchema
|
||||
schema *schema2.WorkflowSchema
|
||||
streamWriter *schema.StreamWriter[*entity.Message]
|
||||
config vo.ExecuteConfig
|
||||
|
||||
@@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
|
||||
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
|
||||
options := &workflowRunOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
|
||||
@@ -1,336 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type WorkflowSchema struct {
|
||||
Nodes []*NodeSchema `json:"nodes"`
|
||||
Connections []*Connection `json:"connections"`
|
||||
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
|
||||
|
||||
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
|
||||
|
||||
nodeMap map[vo.NodeKey]*NodeSchema // won't serialize this
|
||||
compositeNodes []*CompositeNode // won't serialize this
|
||||
requireCheckPoint bool // won't serialize this
|
||||
requireStreaming bool
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
FromNode vo.NodeKey `json:"from_node"`
|
||||
ToNode vo.NodeKey `json:"to_node"`
|
||||
FromPort *string `json:"from_port,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Connection) ID() string {
|
||||
if c.FromPort != nil {
|
||||
return fmt.Sprintf("%s:%s:%v", c.FromNode, c.ToNode, *c.FromPort)
|
||||
}
|
||||
return fmt.Sprintf("%v:%v", c.FromNode, c.ToNode)
|
||||
}
|
||||
|
||||
type CompositeNode struct {
|
||||
Parent *NodeSchema
|
||||
Children []*NodeSchema
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) Init() {
|
||||
w.once.Do(func() {
|
||||
w.nodeMap = make(map[vo.NodeKey]*NodeSchema)
|
||||
for _, node := range w.Nodes {
|
||||
w.nodeMap[node.Key] = node
|
||||
}
|
||||
|
||||
w.doGetCompositeNodes()
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
if node.requireCheckpoint() {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
w.requireStreaming = w.doRequireStreaming()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetNode(key vo.NodeKey) *NodeSchema {
|
||||
return w.nodeMap[key]
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetAllNodes() map[vo.NodeKey]*NodeSchema {
|
||||
return w.nodeMap // TODO: needs to calculate node count separately, considering batch mode nodes
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
|
||||
if w.compositeNodes == nil {
|
||||
w.compositeNodes = w.doGetCompositeNodes()
|
||||
}
|
||||
|
||||
return w.compositeNodes
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
||||
if w.Hierarchy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build parent to children mapping
|
||||
parentToChildren := make(map[vo.NodeKey][]*NodeSchema)
|
||||
for childKey, parentKey := range w.Hierarchy {
|
||||
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
|
||||
if childSchema := w.nodeMap[childKey]; childSchema != nil {
|
||||
parentToChildren[parentKey] = append(parentToChildren[parentKey], childSchema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create composite nodes
|
||||
for parentKey, children := range parentToChildren {
|
||||
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
|
||||
cNodes = append(cNodes, &CompositeNode{
|
||||
Parent: parentSchema,
|
||||
Children: children,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return cNodes
|
||||
}
|
||||
|
||||
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
myParents, myParentExists := n[nodeKey]
|
||||
theirParents, theirParentExists := n[otherNodeKey]
|
||||
|
||||
if !myParentExists && !theirParentExists {
|
||||
return true
|
||||
}
|
||||
|
||||
if !myParentExists || !theirParentExists {
|
||||
return false
|
||||
}
|
||||
|
||||
return myParents == theirParents
|
||||
}
|
||||
|
||||
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
_, myParentExists := n[nodeKey]
|
||||
_, theirParentExists := n[otherNodeKey]
|
||||
|
||||
return myParentExists && !theirParentExists
|
||||
}
|
||||
|
||||
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
theirParent, theirParentExists := n[otherNodeKey]
|
||||
|
||||
return theirParentExists && theirParent == nodeKey
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) IsEqual(other *WorkflowSchema) bool {
|
||||
otherConnectionsMap := make(map[string]bool, len(other.Connections))
|
||||
for _, connection := range other.Connections {
|
||||
otherConnectionsMap[connection.ID()] = true
|
||||
}
|
||||
connectionsMap := make(map[string]bool, len(other.Connections))
|
||||
for _, connection := range w.Connections {
|
||||
connectionsMap[connection.ID()] = true
|
||||
}
|
||||
if !maps.Equal(otherConnectionsMap, connectionsMap) {
|
||||
return false
|
||||
}
|
||||
otherNodeMap := make(map[vo.NodeKey]*NodeSchema, len(other.Nodes))
|
||||
for _, node := range other.Nodes {
|
||||
otherNodeMap[node.Key] = node
|
||||
}
|
||||
nodeMap := make(map[vo.NodeKey]*NodeSchema, len(w.Nodes))
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
nodeMap[node.Key] = node
|
||||
}
|
||||
|
||||
if !maps.EqualFunc(otherNodeMap, nodeMap, func(node *NodeSchema, other *NodeSchema) bool {
|
||||
if node.Name != other.Name {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.Configs, other.Configs) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.InputTypes, other.InputTypes) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.InputSources, other.InputSources) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(node.OutputTypes, other.OutputTypes) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.OutputSources, other.OutputSources) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.ExceptionConfigs, other.ExceptionConfigs) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.SubWorkflowBasic, other.SubWorkflowBasic) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
}) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) NodeCount() int32 {
|
||||
return int32(len(w.Nodes) - len(w.GeneratedNodes))
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doRequireStreaming() bool {
|
||||
producers := make(map[vo.NodeKey]bool)
|
||||
consumers := make(map[vo.NodeKey]bool)
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if _, ok := sps[entity.Stream]; ok {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
|
||||
producers[node.Key] = true
|
||||
}
|
||||
}
|
||||
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
consumers[node.Key] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(producers) == 0 || len(consumers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Build data-flow graph from InputSources
|
||||
adj := make(map[vo.NodeKey]map[vo.NodeKey]struct{})
|
||||
for _, node := range w.Nodes {
|
||||
for _, source := range node.InputSources {
|
||||
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
|
||||
if _, ok := adj[source.Source.Ref.FromNodeKey]; !ok {
|
||||
adj[source.Source.Ref.FromNodeKey] = make(map[vo.NodeKey]struct{})
|
||||
}
|
||||
adj[source.Source.Ref.FromNodeKey][node.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each producer, traverse the graph to see if it can reach a consumer
|
||||
for p := range producers {
|
||||
q := []vo.NodeKey{p}
|
||||
visited := make(map[vo.NodeKey]bool)
|
||||
visited[p] = true
|
||||
|
||||
for len(q) > 0 {
|
||||
curr := q[0]
|
||||
q = q[1:]
|
||||
|
||||
if consumers[curr] {
|
||||
return true
|
||||
}
|
||||
|
||||
for neighbor := range adj[curr] {
|
||||
if !visited[neighbor] {
|
||||
visited[neighbor] = true
|
||||
q = append(q, neighbor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
|
||||
// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
|
||||
// then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
|
||||
// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
|
||||
// then, look up the NodeTypeMeta's ExecutableMeta info and see if it requires fan-in stream merge.
|
||||
if !w.requireStreaming {
|
||||
return nil
|
||||
}
|
||||
|
||||
fanInNodes := make(map[vo.NodeKey]bool)
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
var predecessor *vo.NodeKey
|
||||
for _, source := range node.InputSources {
|
||||
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
|
||||
if predecessor != nil {
|
||||
fanInNodes[node.Key] = true
|
||||
break
|
||||
}
|
||||
predecessor = &source.Source.Ref.FromNodeKey
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fanInConfigs := make(map[string]compose.FanInMergeConfig)
|
||||
for nodeKey := range fanInNodes {
|
||||
if m := entity.NodeMetaByNodeType(w.GetNode(nodeKey).Type); m != nil {
|
||||
if m.StreamSourceEOFAware {
|
||||
fanInConfigs[string(nodeKey)] = compose.FanInMergeConfig{
|
||||
StreamMergeWithSourceEOF: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fanInConfigs
|
||||
}
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
@@ -41,7 +42,7 @@ type invokableWorkflow struct {
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
@@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
sc *schema2.WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &invokableWorkflow{
|
||||
@@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
||||
return "", err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range i.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
@@ -190,7 +191,7 @@ type streamableWorkflow struct {
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
@@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
sc *schema2.WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &streamableWorkflow{
|
||||
@@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range s.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
|
||||
Reference in New Issue
Block a user