feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
181
backend/domain/workflow/internal/compose/branch.go
Normal file
181
backend/domain/workflow/internal/compose/branch.go
Normal file
@@ -0,0 +1,181 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) OutputPortCount() (int, bool) {
|
||||
var hasExceptionPort bool
|
||||
if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
|
||||
*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
hasExceptionPort = true
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
|
||||
if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
|
||||
return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
|
||||
} else {
|
||||
return 2, hasExceptionPort
|
||||
}
|
||||
}
|
||||
return 1, hasExceptionPort
|
||||
case entity.NodeTypeIntentDetector:
|
||||
intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
|
||||
return len(intents) + 1, hasExceptionPort
|
||||
default:
|
||||
return 1, hasExceptionPort
|
||||
}
|
||||
}
|
||||
|
||||
type BranchMapping struct {
|
||||
Normal []map[string]bool
|
||||
Exception map[string]bool
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBranch = "default"
|
||||
BranchFmt = "branch_%d"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
|
||||
if bMapping == nil {
|
||||
return nil, errors.New("no branch mapping")
|
||||
}
|
||||
|
||||
endNodes := make(map[string]bool)
|
||||
for i := range bMapping.Normal {
|
||||
for k := range bMapping.Normal[i] {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
if bMapping.Exception != nil {
|
||||
for k := range bMapping.Exception {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
choice := in[selector.SelectKey].(int)
|
||||
if choice < 0 || choice > len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
|
||||
}
|
||||
|
||||
choices := make(map[string]bool, len((bMapping.Normal)[choice]))
|
||||
for k := range (bMapping.Normal)[choice] {
|
||||
choices[k] = true
|
||||
}
|
||||
|
||||
return choices, nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf := s.Configs.(map[string]any)
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
|
||||
choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
if optionID.(string) == "other" {
|
||||
return (bMapping.Normal)[len(bMapping.Normal)-1], nil
|
||||
}
|
||||
|
||||
if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
|
||||
optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
|
||||
}
|
||||
|
||||
if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[optionIDInt], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
|
||||
|
||||
case entity.NodeTypeIntentDetector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
|
||||
}
|
||||
|
||||
// Intent detector the node default branch uses classificationId=0. But currently scene, the implementation uses default as the last element of the array.
|
||||
// Therefore, when classificationId=0, it needs to be converted into the node corresponding to the last index of the array.
|
||||
// Other options also need to reduce the index by 1.
|
||||
id, err := cast.ToInt64E(classificationId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realID := id - 1
|
||||
|
||||
if realID >= int64(len(bMapping.Normal)) {
|
||||
return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
|
||||
}
|
||||
|
||||
if realID < 0 {
|
||||
realID = int64(len(bMapping.Normal)) - 1
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[realID], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
default:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
}
|
||||
194
backend/domain/workflow/internal/compose/callbacks.go
Normal file
194
backend/domain/workflow/internal/compose/callbacks.go
Normal file
@@ -0,0 +1,194 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
type selectorCallbackField struct {
|
||||
Key string `json:"key"`
|
||||
Type vo.DataType `json:"type"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type selectorCondition struct {
|
||||
Left selectorCallbackField `json:"left"`
|
||||
Operator vo.OperatorType `json:"operator"`
|
||||
Right *selectorCallbackField `json:"right,omitempty"`
|
||||
}
|
||||
|
||||
type selectorBranch struct {
|
||||
Conditions []*selectorCondition `json:"conditions"`
|
||||
Logic vo.LogicType `json:"logic"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
return func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
count := len(config)
|
||||
|
||||
output := make([]*selectorBranch, count)
|
||||
|
||||
for _, source := range s.InputSources {
|
||||
targetPath := source.Path
|
||||
if len(targetPath) == 2 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: []*selectorCondition{
|
||||
{
|
||||
Operator: config[index].Single.ToCanvasOperatorType(),
|
||||
},
|
||||
},
|
||||
Logic: selector.ClauseRelationAND.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[1] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[1] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[0].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
} else if len(targetPath) == 3 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multi := config[index].Multi
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: make([]*selectorCondition, len(multi.Clauses)),
|
||||
Logic: multi.Relation.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
clauseIndexStr := targetPath[1]
|
||||
clauseIndex, err := strconv.Atoi(clauseIndexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clause := multi.Clauses[clauseIndex]
|
||||
|
||||
if output[index].Conditions[clauseIndex] == nil {
|
||||
output[index].Conditions[clauseIndex] = &selectorCondition{
|
||||
Operator: clause.ToCanvasOperatorType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[2] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[2] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"branches": output}, nil
|
||||
}
|
||||
}
|
||||
335
backend/domain/workflow/internal/compose/designate_option.go
Normal file
335
backend/domain/workflow/internal/compose/designate_option.go
Normal file
@@ -0,0 +1,335 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context, []einoCompose.Option, error) {
|
||||
var (
|
||||
wb = r.basic
|
||||
exeCfg = r.config
|
||||
executeID = r.executeID
|
||||
workflowSC = r.schema
|
||||
eventChan = r.eventChan
|
||||
resumedEvent = r.interruptEvent
|
||||
sw = r.streamWriter
|
||||
)
|
||||
|
||||
const tokenCallbackKey = "token_callback_key"
|
||||
|
||||
if wb.AppID != nil && exeCfg.AppID == nil {
|
||||
exeCfg.AppID = wb.AppID
|
||||
}
|
||||
|
||||
rootHandler := execute.NewRootWorkflowHandler(
|
||||
wb,
|
||||
executeID,
|
||||
workflowSC.requireCheckPoint,
|
||||
eventChan,
|
||||
resumedEvent,
|
||||
exeCfg,
|
||||
workflowSC.NodeCount())
|
||||
|
||||
opts := []einoCompose.Option{einoCompose.WithCallbacks(rootHandler)}
|
||||
|
||||
for key := range workflowSC.GetAllNodes() {
|
||||
ns := workflowSC.GetAllNodes()[key]
|
||||
|
||||
var nodeOpt einoCompose.Option
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
|
||||
} else if ns.Type != entity.NodeTypeLambda {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
|
||||
}
|
||||
|
||||
if parent, ok := workflowSC.Hierarchy[key]; !ok { // top level nodes, just add the node handler
|
||||
opts = append(opts, nodeOpt)
|
||||
if ns.Type == entity.NodeTypeSubWorkflow {
|
||||
subOpts, err := r.designateOptionsForSubWorkflow(ctx,
|
||||
rootHandler.(*execute.WorkflowHandler),
|
||||
ns,
|
||||
string(key))
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
}
|
||||
opts = append(opts, subOpts...)
|
||||
} else if ns.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
}
|
||||
|
||||
opts = append(opts, llmNodeOpts...)
|
||||
}
|
||||
} else {
|
||||
parent := workflowSC.GetAllNodes()[parent]
|
||||
opts = append(opts, WrapOpt(nodeOpt, parent.Key))
|
||||
if ns.Type == entity.NodeTypeSubWorkflow {
|
||||
subOpts, err := r.designateOptionsForSubWorkflow(ctx,
|
||||
rootHandler.(*execute.WorkflowHandler),
|
||||
ns,
|
||||
string(key))
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
}
|
||||
for _, subO := range subOpts {
|
||||
opts = append(opts, WrapOpt(subO, parent.Key))
|
||||
}
|
||||
} else if ns.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, ns, eventChan, sw)
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
}
|
||||
for _, subO := range llmNodeOpts {
|
||||
opts = append(opts, WrapOpt(subO, parent.Key))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workflowSC.requireCheckPoint {
|
||||
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
|
||||
}
|
||||
|
||||
if !ctxcache.HasKey(ctx, tokenCallbackKey) {
|
||||
opts = append(opts, einoCompose.WithCallbacks(execute.GetTokenCallbackHandler()))
|
||||
ctx = ctxcache.Init(ctx)
|
||||
ctxcache.Store(ctx, tokenCallbackKey, true)
|
||||
}
|
||||
|
||||
return ctx, opts, nil
|
||||
}
|
||||
|
||||
func nodeCallbackOption(key vo.NodeKey, name string, eventChan chan *execute.Event, resumeEvent *entity.InterruptEvent,
|
||||
terminatePlan *vo.TerminatePlan) einoCompose.Option {
|
||||
return einoCompose.WithCallbacks(execute.NewNodeHandler(string(key), name, eventChan, resumeEvent, terminatePlan)).DesignateNode(string(key))
|
||||
}
|
||||
|
||||
func WrapOpt(opt einoCompose.Option, parentNodeKey vo.NodeKey) einoCompose.Option {
|
||||
return einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(parentNodeKey))
|
||||
}
|
||||
|
||||
func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index int) einoCompose.Option {
|
||||
return einoCompose.WithLambdaOption(nodes.WithOptsForIndexed(index, opt)).DesignateNode(string(parentNodeKey))
|
||||
}
|
||||
|
||||
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
parentHandler *execute.WorkflowHandler,
|
||||
ns *NodeSchema,
|
||||
pathPrefix ...string) (opts []einoCompose.Option, err error) {
|
||||
var (
|
||||
resumeEvent = r.interruptEvent
|
||||
eventChan = r.eventChan
|
||||
sw = r.streamWriter
|
||||
)
|
||||
subHandler := execute.NewSubWorkflowHandler(
|
||||
parentHandler,
|
||||
ns.SubWorkflowBasic,
|
||||
resumeEvent,
|
||||
ns.SubWorkflowSchema.NodeCount(),
|
||||
)
|
||||
|
||||
opts = append(opts, WrapOpt(einoCompose.WithCallbacks(subHandler), ns.Key))
|
||||
|
||||
workflowSC := ns.SubWorkflowSchema
|
||||
for key := range workflowSC.GetAllNodes() {
|
||||
subNS := workflowSC.GetAllNodes()[key]
|
||||
fullPath := append(slices.Clone(pathPrefix), string(subNS.Key))
|
||||
|
||||
var nodeOpt einoCompose.Option
|
||||
if subNS.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
|
||||
} else {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
|
||||
}
|
||||
|
||||
if parent, ok := workflowSC.Hierarchy[key]; !ok { // top level nodes, just add the node handler
|
||||
opts = append(opts, WrapOpt(nodeOpt, ns.Key))
|
||||
if subNS.Type == entity.NodeTypeSubWorkflow {
|
||||
subOpts, err := r.designateOptionsForSubWorkflow(ctx,
|
||||
subHandler.(*execute.WorkflowHandler),
|
||||
subNS,
|
||||
fullPath...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range subOpts {
|
||||
opts = append(opts, WrapOpt(subO, ns.Key))
|
||||
}
|
||||
} else if subNS.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range llmNodeOpts {
|
||||
opts = append(opts, WrapOpt(subO, ns.Key))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
parent := workflowSC.GetAllNodes()[parent]
|
||||
opts = append(opts, WrapOpt(WrapOpt(nodeOpt, parent.Key), ns.Key))
|
||||
if subNS.Type == entity.NodeTypeSubWorkflow {
|
||||
subOpts, err := r.designateOptionsForSubWorkflow(ctx,
|
||||
subHandler.(*execute.WorkflowHandler),
|
||||
subNS,
|
||||
fullPath...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range subOpts {
|
||||
opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
|
||||
}
|
||||
} else if subNS.Type == entity.NodeTypeLLM {
|
||||
llmNodeOpts, err := llmToolCallbackOptions(ctx, subNS, eventChan, sw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, subO := range llmNodeOpts {
|
||||
opts = append(opts, WrapOpt(WrapOpt(subO, parent.Key), ns.Key))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
|
||||
sw *schema.StreamWriter[*entity.Message]) (
|
||||
opts []einoCompose.Option, err error) {
|
||||
// this is a LLM node.
|
||||
// check if it has any tools, if no tools, then no callback options needed
|
||||
// for each tool, extract the entity.FunctionInfo, create the ToolHandler, and add the callback option
|
||||
if ns.Type != entity.NodeTypeLLM {
|
||||
panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
|
||||
}
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
// TODO: try to avoid getting the workflow tool all over again
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, vo.WorkflowToolConfig{})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
funcInfo := entity.FunctionInfo{
|
||||
Name: tInfo.Name,
|
||||
Type: entity.WorkflowTool,
|
||||
WorkflowName: wfTool.GetWorkflow().Name,
|
||||
WorkflowTerminatePlan: wfTool.TerminatePlan(),
|
||||
APIID: wfID,
|
||||
APIName: wfTool.GetWorkflow().Name,
|
||||
PluginID: wfID,
|
||||
PluginName: wfTool.GetWorkflow().Name,
|
||||
}
|
||||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
if fcParams.PluginFCParam != nil {
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pluginID, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolInfoResponse, err := plugin.GetPluginService().GetPluginToolsInfo(ctx, &plugin.ToolsInfoRequest{
|
||||
PluginEntity: plugin.Entity{
|
||||
PluginID: pluginID,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolIDs: []int64{toolID},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
funcInfo := entity.FunctionInfo{
|
||||
Name: toolInfoResponse.ToolInfoList[toolID].ToolName,
|
||||
Type: entity.PluginTool,
|
||||
PluginID: pluginID,
|
||||
PluginName: toolInfoResponse.PluginName,
|
||||
APIID: toolID,
|
||||
APIName: p.ApiName,
|
||||
}
|
||||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sw != nil {
|
||||
toolMsgOpt := llm.WithToolWorkflowMessageWriter(sw)
|
||||
opt := einoCompose.WithLambdaOption(toolMsgOpt).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
314
backend/domain/workflow/internal/compose/field_fill.go
Normal file
314
backend/domain/workflow/internal/compose/field_fill.go
Normal file
@@ -0,0 +1,314 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// outputValueFiller will fill the output value with nil if the key is not present in the output map.
|
||||
// if a node emits stream as output, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
if len(s.OutputTypes) == 0 {
|
||||
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
newOutput := make(map[string]any)
|
||||
for k := range output {
|
||||
newOutput[k] = output[k]
|
||||
}
|
||||
|
||||
for k, tInfo := range s.OutputTypes {
|
||||
if err := FillIfNotRequired(tInfo, newOutput, k, FillNil, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return newOutput, nil
|
||||
}
|
||||
}
|
||||
|
||||
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
|
||||
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
if len(s.InputTypes) == 0 {
|
||||
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
return input, nil
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
newInput := make(map[string]any)
|
||||
for k := range input {
|
||||
newInput[k] = input[k]
|
||||
}
|
||||
|
||||
for k, tInfo := range s.InputTypes {
|
||||
if err := FillIfNotRequired(tInfo, newInput, k, FillZero, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return newInput, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
|
||||
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
|
||||
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
|
||||
newI := make(map[string]any)
|
||||
for k := range i {
|
||||
newI[k] = i[k]
|
||||
}
|
||||
|
||||
for k, tInfo := range s.InputTypes {
|
||||
if err := replaceNilWithZero(tInfo, newI, k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return newI, nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
|
||||
return schema.StreamReaderWithConvert(input, func(in map[string]any) (map[string]any, error) {
|
||||
return fn(ctx, in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type FillStrategy string
|
||||
|
||||
const (
|
||||
FillZero FillStrategy = "zero"
|
||||
FillNil FillStrategy = "nil"
|
||||
)
|
||||
|
||||
func FillIfNotRequired(tInfo *vo.TypeInfo, container map[string]any, k string, strategy FillStrategy, isInsideObject bool) error {
|
||||
v, ok := container[k]
|
||||
if ok {
|
||||
if len(tInfo.Properties) == 0 {
|
||||
if v == nil && strategy == FillZero {
|
||||
if isInsideObject {
|
||||
return nil
|
||||
}
|
||||
v = tInfo.Zero()
|
||||
container[k] = v
|
||||
return nil
|
||||
}
|
||||
|
||||
if v != nil && tInfo.Type == vo.DataTypeArray {
|
||||
val, ok := v.([]any)
|
||||
if !ok {
|
||||
valStr, ok := v.(string)
|
||||
if ok {
|
||||
err := sonic.UnmarshalString(valStr, &val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[k] = val
|
||||
} else {
|
||||
return fmt.Errorf("layer field %s is not a []any or string", k)
|
||||
}
|
||||
}
|
||||
elemTInfo := tInfo.ElemTypeInfo
|
||||
copiedVal := slices.Clone(val)
|
||||
container[k] = copiedVal
|
||||
for i := range copiedVal {
|
||||
if copiedVal[i] == nil {
|
||||
if strategy == FillZero {
|
||||
copiedVal[i] = elemTInfo.Zero()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(elemTInfo.Properties) > 0 {
|
||||
subContainer, ok := copiedVal[i].(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("map item under array %s is not map[string]any", k)
|
||||
}
|
||||
|
||||
newSubContainer := maps.Clone(subContainer)
|
||||
|
||||
for subK, subL := range elemTInfo.Properties {
|
||||
if err := FillIfNotRequired(subL, newSubContainer, subK, strategy, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
copiedVal[i] = newSubContainer
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
// recursively handle the layered object.
|
||||
subContainer, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
subContainerStr, ok := v.(string)
|
||||
if ok {
|
||||
subContainer = make(map[string]any)
|
||||
err := sonic.UnmarshalString(subContainerStr, &subContainer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[k] = subContainer
|
||||
} else {
|
||||
return fmt.Errorf("layer field %s is not a map[string]any or string", k)
|
||||
}
|
||||
}
|
||||
|
||||
newSubContainer := maps.Clone(subContainer)
|
||||
if newSubContainer == nil {
|
||||
newSubContainer = make(map[string]any)
|
||||
}
|
||||
for subK, subT := range tInfo.Properties {
|
||||
if err := FillIfNotRequired(subT, newSubContainer, subK, strategy, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
container[k] = newSubContainer
|
||||
}
|
||||
} else {
|
||||
if tInfo.Required {
|
||||
return fmt.Errorf("output field %s is required but not present", k)
|
||||
} else {
|
||||
var z any
|
||||
if strategy == FillZero {
|
||||
if !isInsideObject {
|
||||
z = tInfo.Zero()
|
||||
}
|
||||
}
|
||||
|
||||
container[k] = z
|
||||
// if it's an object, recursively handle the layeredFieldInfo.
|
||||
if len(tInfo.Properties) > 0 {
|
||||
z = make(map[string]any)
|
||||
container[k] = z
|
||||
subContainer := z.(map[string]any)
|
||||
for subK, subL := range tInfo.Properties {
|
||||
if err := FillIfNotRequired(subL, subContainer, subK, strategy, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceNilWithZero(tInfo *vo.TypeInfo, container map[string]any, k string) error {
|
||||
v, ok := container[k]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(tInfo.Properties) == 0 {
|
||||
if v == nil {
|
||||
v = tInfo.Zero()
|
||||
container[k] = v
|
||||
return nil
|
||||
}
|
||||
|
||||
if tInfo.Type == vo.DataTypeArray {
|
||||
val, ok := v.([]any)
|
||||
if !ok {
|
||||
valStr, ok := v.(string)
|
||||
if ok {
|
||||
err := sonic.UnmarshalString(valStr, &val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[k] = val
|
||||
} else {
|
||||
return fmt.Errorf("layer field %s is not a []any or string", k)
|
||||
}
|
||||
}
|
||||
elemTInfo := tInfo.ElemTypeInfo
|
||||
copiedVal := slices.Clone(val)
|
||||
container[k] = copiedVal
|
||||
for i := range copiedVal {
|
||||
if copiedVal[i] == nil {
|
||||
copiedVal[i] = elemTInfo.Zero()
|
||||
continue
|
||||
}
|
||||
|
||||
if len(elemTInfo.Properties) > 0 {
|
||||
subContainer, ok := copiedVal[i].(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("map item under array %s is not map[string]any", k)
|
||||
}
|
||||
|
||||
newSubContainer := maps.Clone(subContainer)
|
||||
|
||||
for subK, subL := range elemTInfo.Properties {
|
||||
if err := replaceNilWithZero(subL, newSubContainer, subK); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
copiedVal[i] = newSubContainer
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
// recursively handle the layered object.
|
||||
subContainer, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
subContainerStr, ok := v.(string)
|
||||
if ok {
|
||||
subContainer = make(map[string]any)
|
||||
err := sonic.UnmarshalString(subContainerStr, &subContainer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
container[k] = subContainer
|
||||
} else {
|
||||
return fmt.Errorf("layer field %s is not a map[string]any or string", k)
|
||||
}
|
||||
}
|
||||
|
||||
newSubContainer := maps.Clone(subContainer)
|
||||
for subK, subT := range tInfo.Properties {
|
||||
if err := replaceNilWithZero(subT, newSubContainer, subK); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
container[k] = newSubContainer
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
300
backend/domain/workflow/internal/compose/field_fill_test.go
Normal file
300
backend/domain/workflow/internal/compose/field_fill_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func TestNodeSchema_OutputValueFiller(t *testing.T) {
|
||||
type fields struct {
|
||||
In map[string]any
|
||||
Outputs map[string]*vo.TypeInfo
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want map[string]any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "string field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "integer field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeNumber,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeBoolean,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "time field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeTime,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "object field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeObject,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeArray,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file field",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required field not present",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeString,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "is required but not present",
|
||||
},
|
||||
{
|
||||
name: "layered: object.string",
|
||||
fields: fields{
|
||||
In: map[string]any{
|
||||
"key": map[string]any{},
|
||||
},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": map[string]any{
|
||||
"sub_key": nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "layered: object.object",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key": {
|
||||
Type: vo.DataTypeObject,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": map[string]any{
|
||||
"sub_key": nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "layered: object.object.array",
|
||||
fields: fields{
|
||||
In: map[string]any{},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key2": {
|
||||
Type: vo.DataTypeArray,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": map[string]any{
|
||||
"sub_key": map[string]any{
|
||||
"sub_key2": nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key present",
|
||||
fields: fields{
|
||||
In: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "layered key present",
|
||||
fields: fields{
|
||||
In: map[string]any{
|
||||
"key": map[string]any{},
|
||||
},
|
||||
Outputs: map[string]*vo.TypeInfo{
|
||||
"key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"sub_key2": {
|
||||
Type: vo.DataTypeArray,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"key": map[string]any{
|
||||
"sub_key": map[string]any{
|
||||
"sub_key2": nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &NodeSchema{
|
||||
OutputTypes: tt.fields.Outputs,
|
||||
}
|
||||
|
||||
got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
|
||||
|
||||
if len(tt.wantErr) > 0 {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
791
backend/domain/workflow/internal/compose/node_runner.go
Normal file
791
backend/domain/workflow/internal/compose/node_runner.go
Normal file
@@ -0,0 +1,791 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type nodeRunConfig[O any] struct {
|
||||
nodeKey vo.NodeKey
|
||||
nodeName string
|
||||
nodeType entity.NodeType
|
||||
timeoutMS int64
|
||||
maxRetry int64
|
||||
errProcessType vo.ErrorProcessType
|
||||
dataOnErr func(ctx context.Context) map[string]any
|
||||
callbackEnabled bool
|
||||
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
streamPreProcessors []func(ctx context.Context,
|
||||
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any]
|
||||
callbackInputConverter func(context.Context, map[string]any) (map[string]any, error)
|
||||
callbackOutputConverter func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)
|
||||
init []func(context.Context) (context.Context, error)
|
||||
i compose.Invoke[map[string]any, map[string]any, O]
|
||||
s compose.Stream[map[string]any, map[string]any, O]
|
||||
t compose.Transform[map[string]any, map[string]any, O]
|
||||
}
|
||||
|
||||
func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
i compose.Invoke[map[string]any, map[string]any, O],
|
||||
s compose.Stream[map[string]any, map[string]any, O],
|
||||
t compose.Transform[map[string]any, map[string]any, O],
|
||||
opts *newNodeOptions) *nodeRunConfig[O] {
|
||||
meta := entity.NodeMetaByNodeType(ns.Type)
|
||||
|
||||
var (
|
||||
timeoutMS = meta.DefaultTimeoutMS
|
||||
maxRetry int64
|
||||
errProcessType = vo.ErrorProcessTypeThrow
|
||||
dataOnErr func(ctx context.Context) map[string]any
|
||||
)
|
||||
if ns.ExceptionConfigs != nil {
|
||||
timeoutMS = ns.ExceptionConfigs.TimeoutMS
|
||||
maxRetry = ns.ExceptionConfigs.MaxRetry
|
||||
if ns.ExceptionConfigs.ProcessType != nil {
|
||||
errProcessType = *ns.ExceptionConfigs.ProcessType
|
||||
}
|
||||
if len(ns.ExceptionConfigs.DataOnErr) > 0 {
|
||||
dataOnErr = func(ctx context.Context) map[string]any {
|
||||
return parseDefaultOutputOrFallback(ctx, ns.ExceptionConfigs.DataOnErr, ns.OutputTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preProcessors := []func(ctx context.Context, input map[string]any) (map[string]any, error){
|
||||
preTypeConverter(ns.InputTypes),
|
||||
keyFinishedMarkerTrimmer(),
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
preProcessors = append(preProcessors, ns.inputValueFiller())
|
||||
}
|
||||
|
||||
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
if meta.PostFillNil {
|
||||
postProcessors = append(postProcessors, ns.outputValueFiller())
|
||||
}
|
||||
|
||||
streamPreProcessors := []func(ctx context.Context,
|
||||
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any]{
|
||||
func(ctx context.Context, input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
|
||||
f := func(in map[string]any) (map[string]any, error) {
|
||||
return preTypeConverter(ns.InputTypes)(ctx, in)
|
||||
}
|
||||
return schema.StreamReaderWithConvert(input, f)
|
||||
},
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
|
||||
}
|
||||
|
||||
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
|
||||
current, exceeded := execute.IncrAndCheckExecutedNodes(ctx)
|
||||
if exceeded {
|
||||
return nil, fmt.Errorf("exceeded max executed node count: %d, current: %d", execute.GetStaticConfig().MaxNodeCountPerExecution, current)
|
||||
}
|
||||
return ctx, nil
|
||||
})
|
||||
|
||||
return &nodeRunConfig[O]{
|
||||
nodeKey: ns.Key,
|
||||
nodeName: ns.Name,
|
||||
nodeType: ns.Type,
|
||||
timeoutMS: timeoutMS,
|
||||
maxRetry: maxRetry,
|
||||
errProcessType: errProcessType,
|
||||
dataOnErr: dataOnErr,
|
||||
callbackEnabled: meta.CallbackEnabled,
|
||||
preProcessors: preProcessors,
|
||||
postProcessors: postProcessors,
|
||||
streamPreProcessors: streamPreProcessors,
|
||||
callbackInputConverter: opts.callbackInputConverter,
|
||||
callbackOutputConverter: opts.callbackOutputConverter,
|
||||
init: opts.init,
|
||||
i: i,
|
||||
s: s,
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func newNodeRunConfigWOOpt(ns *NodeSchema,
|
||||
i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
s compose.StreamWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any],
|
||||
opts *newNodeOptions) *nodeRunConfig[any] {
|
||||
var (
|
||||
iWO compose.Invoke[map[string]any, map[string]any, any]
|
||||
sWO compose.Stream[map[string]any, map[string]any, any]
|
||||
tWO compose.Transform[map[string]any, map[string]any, any]
|
||||
)
|
||||
|
||||
if i != nil {
|
||||
iWO = func(ctx context.Context, in map[string]any, _ ...any) (out map[string]any, err error) {
|
||||
return i(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
if s != nil {
|
||||
sWO = func(ctx context.Context, in map[string]any, _ ...any) (out *schema.StreamReader[map[string]any], err error) {
|
||||
return s(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
|
||||
return t(ctx, input)
|
||||
}
|
||||
}
|
||||
|
||||
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts)
|
||||
}
|
||||
|
||||
type newNodeOptions struct {
|
||||
callbackInputConverter func(context.Context, map[string]any) (map[string]any, error)
|
||||
callbackOutputConverter func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)
|
||||
init []func(context.Context) (context.Context, error)
|
||||
}
|
||||
|
||||
type newNodeOption func(*newNodeOptions)
|
||||
|
||||
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackInputConverter = f
|
||||
}
|
||||
}
|
||||
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackOutputConverter = f
|
||||
}
|
||||
}
|
||||
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.init = append(opts.init, f)
|
||||
}
|
||||
}
|
||||
|
||||
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return newNodeRunConfig(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
|
||||
}
|
||||
|
||||
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return newNodeRunConfig(ns, i, s, nil, options).toNode()
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
if nc.i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
ctx, runner := newNodeRunner(ctx, nc)
|
||||
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = runner.onEnd(ctx, output)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errOutput, hasErrOutput := runner.onError(ctx, err)
|
||||
if hasErrOutput {
|
||||
output = errOutput
|
||||
err = nil
|
||||
if output, err = runner.postProcess(ctx, output); err != nil {
|
||||
logs.CtxErrorf(ctx, "postProcess failed after returning error output: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, i := range runner.init {
|
||||
if ctx, err = i(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if input, err = runner.preProcess(ctx, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ctx, err = runner.onStart(ctx, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if output, err = runner.invoke(ctx, input, opts...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.postProcess(ctx, output)
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) stream() func(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
if nc.s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
ctx, runner := newNodeRunner(ctx, nc)
|
||||
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
output, err = runner.onEndStream(ctx, output)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errOutput, hasErrOutput := runner.onError(ctx, err)
|
||||
if hasErrOutput {
|
||||
output = schema.StreamReaderFromArray([]map[string]any{errOutput})
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, i := range runner.init {
|
||||
if ctx, err = i(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if input, err = runner.preProcess(ctx, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ctx, err = runner.onStart(ctx, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.stream(ctx, input, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
if nc.t == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
ctx, runner := newNodeRunner(ctx, nc)
|
||||
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
output, err = runner.onEndStream(ctx, output)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
errOutput, hasErrOutput := runner.onError(ctx, err)
|
||||
if hasErrOutput {
|
||||
output = schema.StreamReaderFromArray([]map[string]any{errOutput})
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, i := range runner.init {
|
||||
if ctx, err = i(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range runner.streamPreProcessors {
|
||||
input = p(ctx, input)
|
||||
}
|
||||
|
||||
if ctx, input, err = runner.onStartStream(ctx, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner.transform(ctx, input, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) toNode() *Node {
|
||||
var opts []compose.LambdaOpt
|
||||
opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
|
||||
|
||||
if nc.callbackEnabled {
|
||||
opts = append(opts, compose.WithLambdaCallbackEnable(true))
|
||||
}
|
||||
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
|
||||
}
|
||||
|
||||
return &Node{Lambda: l}
|
||||
}
|
||||
|
||||
type nodeRunner[O any] struct {
|
||||
*nodeRunConfig[O]
|
||||
interrupted bool
|
||||
cancelFn context.CancelFunc
|
||||
}
|
||||
|
||||
func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.Context, *nodeRunner[O]) {
|
||||
runner := &nodeRunner[O]{
|
||||
nodeRunConfig: cfg,
|
||||
}
|
||||
|
||||
if cfg.timeoutMS > 0 {
|
||||
ctx, runner.cancelFn = context.WithTimeout(ctx, time.Duration(cfg.timeoutMS)*time.Millisecond)
|
||||
}
|
||||
|
||||
return ctx, runner
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, nil
|
||||
}
|
||||
if r.callbackInputConverter != nil {
|
||||
convertedInput, err := r.callbackInputConverter(ctx, input)
|
||||
if err != nil {
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
return ctx, err
|
||||
}
|
||||
ctx = callbacks.OnStart(ctx, convertedInput)
|
||||
} else {
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
|
||||
context.Context, *schema.StreamReader[map[string]any], error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, input, nil
|
||||
}
|
||||
|
||||
if r.callbackInputConverter != nil {
|
||||
copied := input.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
|
||||
return func(in map[string]any) (map[string]any, error) {
|
||||
return r.callbackInputConverter(ctx, in)
|
||||
}
|
||||
}
|
||||
callbackS := schema.StreamReaderWithConvert(copied[0], realConverter(ctx))
|
||||
newCtx, unused := callbacks.OnStartWithStreamInput(ctx, callbackS)
|
||||
unused.Close()
|
||||
return newCtx, copied[1], nil
|
||||
}
|
||||
|
||||
newCtx, newInput := callbacks.OnStartWithStreamInput(ctx, input)
|
||||
return newCtx, newInput, nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) preProcess(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
|
||||
for _, preProcessor := range r.preProcessors {
|
||||
if preProcessor == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
input, err = preProcessor(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return input, nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) postProcess(ctx context.Context, output map[string]any) (_ map[string]any, err error) {
|
||||
for _, postProcessor := range r.postProcessors {
|
||||
if postProcessor == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
output, err = postProcessor(ctx, output)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) invoke(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
output, err = r.i(ctx, input, opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
exeCtx.CurrentRetryCount++
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) stream(ctx context.Context, input map[string]any, opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
var n int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
output, err = r.s(ctx, input, opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
exeCtx.CurrentRetryCount++
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...O) (output *schema.StreamReader[map[string]any], err error) {
|
||||
if r.maxRetry == 0 {
|
||||
return r.t(ctx, input, opts...)
|
||||
}
|
||||
|
||||
copied := input.Copy(int(r.maxRetry))
|
||||
|
||||
var n int64
|
||||
defer func() {
|
||||
for i := n + 1; i < r.maxRetry; i++ {
|
||||
copied[i].Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
output, err = r.t(ctx, copied[n], opts...)
|
||||
if err != nil {
|
||||
if _, ok := compose.IsInterruptRerunError(err); ok { // interrupt, won't retry
|
||||
r.interrupted = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs.CtxErrorf(ctx, "[invoke] node %s ID %s failed on %d attempt, err: %v", r.nodeName, r.nodeKey, n, err)
|
||||
if r.maxRetry > n {
|
||||
n++
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
exeCtx.CurrentRetryCount++
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
output["isSuccess"] = true
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
convertedOutput, err := r.callbackOutputConverter(ctx, output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, convertedOutput)
|
||||
} else {
|
||||
_ = callbacks.OnEnd(ctx, output)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
|
||||
*schema.StreamReader[map[string]any], error) {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
|
||||
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return output, nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
copied := output.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
return func(in map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
return r.callbackOutputConverter(ctx, in)
|
||||
}
|
||||
}
|
||||
callbackS := schema.StreamReaderWithConvert(copied[0], realConverter(ctx))
|
||||
_, unused := callbacks.OnEndWithStreamOutput(ctx, callbackS)
|
||||
unused.Close()
|
||||
|
||||
return copied[1], nil
|
||||
}
|
||||
|
||||
_, newOutput := callbacks.OnEndWithStreamOutput(ctx, output)
|
||||
return newOutput, nil
|
||||
}
|
||||
|
||||
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
|
||||
if r.interrupted {
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var sErr vo.WorkflowError
|
||||
if !errors.As(err, &sErr) {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
sErr = vo.NodeTimeoutErr
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
sErr = vo.CancelErr
|
||||
} else {
|
||||
sErr = vo.WrapError(errno.ErrWorkflowExecuteFail, err, errorx.KV("cause", vo.UnwrapRootErr(err).Error()))
|
||||
}
|
||||
}
|
||||
|
||||
code := int(sErr.Code())
|
||||
msg := sErr.Msg()
|
||||
|
||||
switch r.errProcessType {
|
||||
case vo.ErrorProcessTypeDefault:
|
||||
d := r.dataOnErr(ctx)
|
||||
d["errorBody"] = map[string]any{
|
||||
"errorMessage": msg,
|
||||
"errorCode": code,
|
||||
}
|
||||
d["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: d,
|
||||
RawOutput: d,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
}
|
||||
return d, true
|
||||
case vo.ErrorProcessTypeExceptionBranch:
|
||||
s := make(map[string]any)
|
||||
s["errorBody"] = map[string]any{
|
||||
"errorMessage": msg,
|
||||
"errorCode": code,
|
||||
}
|
||||
s["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: s,
|
||||
RawOutput: s,
|
||||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
}
|
||||
return s, true
|
||||
default:
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, sErr)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseDefaultOutput(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, e := nodes.ConvertInputs(ctx, result, schema_)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert output warnings: %v", *ws)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func parseDefaultOutputOrFallback(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) map[string]any {
|
||||
result, err := parseDefaultOutput(ctx, data, schema_)
|
||||
if err != nil {
|
||||
fallback := make(map[string]any, len(schema_))
|
||||
for k, v := range schema_ {
|
||||
if v.Type == vo.DataTypeString {
|
||||
fallback[k] = data
|
||||
continue
|
||||
}
|
||||
fallback[k] = v.Zero()
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func preTypeConverter(inTypes map[string]*vo.TypeInfo) func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
out, ws, err := nodes.ConvertInputs(ctx, in, inTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
return out, err
|
||||
}
|
||||
}
|
||||
|
||||
func trimKeyFinishedMarker(ctx context.Context, in map[string]any) (map[string]any, bool, error) {
|
||||
var (
|
||||
newIn map[string]any
|
||||
trimmed bool
|
||||
)
|
||||
for k, v := range in {
|
||||
if vStr, ok := v.(string); ok {
|
||||
if strings.HasSuffix(vStr, nodes.KeyIsFinished) {
|
||||
if newIn == nil {
|
||||
newIn = maps.Clone(in)
|
||||
}
|
||||
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
|
||||
newIn[k] = vStr
|
||||
trimmed = true
|
||||
}
|
||||
} else if vMap, ok := v.(map[string]any); ok {
|
||||
newMap, subTrimmed, err := trimKeyFinishedMarker(ctx, vMap)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if subTrimmed {
|
||||
if newIn == nil {
|
||||
newIn = maps.Clone(in)
|
||||
}
|
||||
newIn[k] = newMap
|
||||
trimmed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if trimmed {
|
||||
return newIn, true, nil
|
||||
}
|
||||
|
||||
return in, false, nil
|
||||
}
|
||||
|
||||
func keyFinishedMarkerTrimmer() func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
out, _, err := trimKeyFinishedMarker(ctx, in)
|
||||
return out, err
|
||||
}
|
||||
}
|
||||
580
backend/domain/workflow/internal/compose/node_schema.go
Normal file
580
backend/domain/workflow/internal/compose/node_schema.go
Normal file
@@ -0,0 +1,580 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
)
|
||||
|
||||
type NodeSchema struct {
|
||||
Key vo.NodeKey `json:"key"`
|
||||
Name string `json:"name"`
|
||||
|
||||
Type entity.NodeType `json:"type"`
|
||||
|
||||
// Configs are node specific configurations with pre-defined config key and config value.
|
||||
// Will not participate in request-time field mapping, nor as node's static values.
|
||||
// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
|
||||
Configs any `json:"configs,omitempty"`
|
||||
|
||||
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
|
||||
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
|
||||
|
||||
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
|
||||
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop
|
||||
|
||||
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
|
||||
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
|
||||
|
||||
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
|
||||
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
|
||||
|
||||
Lambda *compose.Lambda // not serializable, used for internal test.
|
||||
}
|
||||
|
||||
type ExceptionConfig struct {
|
||||
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
|
||||
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
|
||||
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
|
||||
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
|
||||
}
|
||||
|
||||
type StreamConfig struct {
|
||||
// whether this node has the ability to produce genuine streaming output.
|
||||
// not include nodes that only passes stream down as they receives them
|
||||
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
|
||||
// whether this node prioritize streaming input over none-streaming input.
|
||||
// not include nodes that can accept both and does not have preference.
|
||||
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
|
||||
sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
|
||||
if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeLambda:
|
||||
if s.Lambda == nil {
|
||||
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
|
||||
}
|
||||
|
||||
return &Node{Lambda: s.Lambda}, nil
|
||||
case entity.NodeTypeLLM:
|
||||
conf, err := s.ToLLMConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := llm.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeSelector:
|
||||
conf := s.ToSelectorConfig()
|
||||
|
||||
sl, err := selector.NewSelector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeBatch:
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
|
||||
}
|
||||
|
||||
conf, err := s.ToBatchConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := batch.NewBatch(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
|
||||
case entity.NodeTypeVariableAggregator:
|
||||
conf, err := s.ToVariableAggregatorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableaggregator.NewVariableAggregator(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, va.Invoke, va.Transform,
|
||||
withCallbackInputConverter(va.ToCallbackInput),
|
||||
withCallbackOutputConverter(va.ToCallbackOutput),
|
||||
withInit(va.Init)), nil
|
||||
case entity.NodeTypeTextProcessor:
|
||||
conf, err := s.ToTextProcessorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, err := textprocessor.NewTextProcessor(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, tp.Invoke), nil
|
||||
case entity.NodeTypeHTTPRequester:
|
||||
conf, err := s.ToHTTPRequesterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hr, err := httprequester.NewHTTPRequester(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeContinue:
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
case entity.NodeTypeBreak:
|
||||
b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, b.DoBreak), nil
|
||||
case entity.NodeTypeVariableAssigner:
|
||||
handler := variable.GetVariableHandler()
|
||||
|
||||
conf, err := s.ToVariableAssignerConfig(handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
va, err := variableassigner.NewVariableAssigner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeVariableAssignerWithinLoop:
|
||||
conf, err := s.ToVariableAssignerInLoopConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeLoop:
|
||||
conf, err := s.ToLoopConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := loop.NewLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf, err := s.ToQAConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qA, err := qa.NewQuestionAnswer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeInputReceiver:
|
||||
conf, err := s.ToInputReceiverConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputR, err := receiver.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeEntry:
|
||||
conf, err := s.ToEntryConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e, err := entry.NewEntry(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, e.Invoke), nil
|
||||
case entity.NodeTypeExit:
|
||||
terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
if terminalPlan == vo.ReturnVariables {
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
if in == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
}
|
||||
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeDatabaseCustomSQL:
|
||||
conf, err := s.ToDatabaseCustomSQLConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlER, err := database.NewCustomSQL(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, sqlER.Execute), nil
|
||||
case entity.NodeTypeDatabaseQuery:
|
||||
conf, err := s.ToDatabaseQueryConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, err := database.NewQuery(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseInsert:
|
||||
conf, err := s.ToDatabaseInsertConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insert, err := database.NewInsert(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseUpdate:
|
||||
conf, err := s.ToDatabaseUpdateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
update, err := database.NewUpdate(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseDelete:
|
||||
conf, err := s.ToDatabaseDeleteConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
del, err := database.NewDelete(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
|
||||
case entity.NodeTypeKnowledgeIndexer:
|
||||
conf, err := s.ToKnowledgeIndexerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, w.Store), nil
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
conf, err := s.ToKnowledgeRetrieveConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Retrieve), nil
|
||||
case entity.NodeTypeKnowledgeDeleter:
|
||||
conf, err := s.ToKnowledgeDeleterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Delete), nil
|
||||
case entity.NodeTypeCodeRunner:
|
||||
conf, err := s.ToCodeRunnerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := code.NewCodeRunner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
|
||||
case entity.NodeTypePlugin:
|
||||
conf, err := s.ToPluginConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := plugin.NewPlugin(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeCreateConversation:
|
||||
conf, err := s.ToCreateConversationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewCreateConversation(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Create), nil
|
||||
case entity.NodeTypeMessageList:
|
||||
conf, err := s.ToMessageListConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewMessageList(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.List), nil
|
||||
case entity.NodeTypeClearMessage:
|
||||
conf, err := s.ToClearMessageConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewClearMessage(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Clear), nil
|
||||
case entity.NodeTypeIntentDetector:
|
||||
conf, err := s.ToIntentDetectorConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := intentdetector.NewIntentDetector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := subworkflow.NewSubWorkflow(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
|
||||
case entity.NodeTypeJsonSerialization:
|
||||
conf, err := s.ToJsonSerializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := json.NewJsonSerializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, js.Invoke), nil
|
||||
case entity.NodeTypeJsonDeserialization:
|
||||
conf, err := s.ToJsonDeserializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jd, err := json.NewJsonDeserializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
|
||||
default:
|
||||
panic("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableUserQuery() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.Type != entity.NodeTypeEntry {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.OutputSources) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, source := range s.OutputSources {
|
||||
fieldPath := source.Path
|
||||
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableChatHistory() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
|
||||
case entity.NodeTypeLLM:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
case entity.NodeTypeIntentDetector:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsRefGlobalVariable() bool {
|
||||
for _, source := range s.InputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, source := range s.OutputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *NodeSchema) requireCheckpoint() bool {
|
||||
if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeLLM {
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil && fcParams.WorkflowFCParam != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeSubWorkflow {
|
||||
s.SubWorkflowSchema.Init()
|
||||
if s.SubWorkflowSchema.requireCheckPoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
939
backend/domain/workflow/internal/compose/state.go
Normal file
939
backend/domain/workflow/internal/compose/state.go
Normal file
@@ -0,0 +1,939 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
Answers map[vo.NodeKey][]string `json:"answers,omitempty"`
|
||||
Questions map[vo.NodeKey][]*qa.Question `json:"questions,omitempty"`
|
||||
Inputs map[vo.NodeKey]map[string]any `json:"inputs,omitempty"`
|
||||
NodeExeContexts map[vo.NodeKey]*execute.Context `json:"-"`
|
||||
WorkflowExeContext *execute.Context `json:"-"`
|
||||
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
|
||||
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
|
||||
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
|
||||
AppVariableStore *variableassigner.AppVariables `json:"variable_app_store,omitempty"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
_ = compose.RegisterSerializableType[*State]("schema_state")
|
||||
_ = compose.RegisterSerializableType[[]*qa.Question]("qa_question_list")
|
||||
_ = compose.RegisterSerializableType[qa.Question]("qa_question")
|
||||
_ = compose.RegisterSerializableType[vo.NodeKey]("node_key")
|
||||
_ = compose.RegisterSerializableType[*execute.Context]("exe_context")
|
||||
_ = compose.RegisterSerializableType[execute.RootCtx]("root_ctx")
|
||||
_ = compose.RegisterSerializableType[*execute.SubWorkflowCtx]("sub_workflow_ctx")
|
||||
_ = compose.RegisterSerializableType[*execute.NodeCtx]("node_ctx")
|
||||
_ = compose.RegisterSerializableType[*execute.BatchInfo]("batch_info")
|
||||
_ = compose.RegisterSerializableType[*execute.TokenCollector]("token_collector")
|
||||
_ = compose.RegisterSerializableType[entity.NodeType]("node_type")
|
||||
_ = compose.RegisterSerializableType[*entity.InterruptEvent]("interrupt_event")
|
||||
_ = compose.RegisterSerializableType[workflow2.EventType]("workflow_event_type")
|
||||
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
|
||||
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
|
||||
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
|
||||
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
|
||||
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
|
||||
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
|
||||
_ = compose.RegisterSerializableType[*entity.ToolInterruptEvent]("tool_interrupt_event")
|
||||
_ = compose.RegisterSerializableType[vo.ExecuteConfig]("execute_config")
|
||||
_ = compose.RegisterSerializableType[vo.ExecuteMode]("execute_mode")
|
||||
_ = compose.RegisterSerializableType[vo.TaskType]("task_type")
|
||||
_ = compose.RegisterSerializableType[vo.SyncPattern]("sync_pattern")
|
||||
_ = compose.RegisterSerializableType[vo.Locator]("wf_locator")
|
||||
_ = compose.RegisterSerializableType[vo.BizType]("biz_type")
|
||||
_ = compose.RegisterSerializableType[*variableassigner.AppVariables]("app_variables")
|
||||
}
|
||||
|
||||
func (s *State) SetAppVariableValue(key string, value any) {
|
||||
s.AppVariableStore.Set(key, value)
|
||||
}
|
||||
|
||||
func (s *State) GetAppVariableValue(key string) (any, bool) {
|
||||
return s.AppVariableStore.Get(key)
|
||||
}
|
||||
|
||||
func (s *State) AddQuestion(nodeKey vo.NodeKey, question *qa.Question) {
|
||||
s.Questions[nodeKey] = append(s.Questions[nodeKey], question)
|
||||
}
|
||||
|
||||
func (s *State) AddAnswer(nodeKey vo.NodeKey, answer string) {
|
||||
s.Answers[nodeKey] = append(s.Answers[nodeKey], answer)
|
||||
}
|
||||
|
||||
func (s *State) GetQuestionsAndAnswers(nodeKey vo.NodeKey) ([]*qa.Question, []string) {
|
||||
return s.Questions[nodeKey], s.Answers[nodeKey]
|
||||
}
|
||||
|
||||
func (s *State) GetNodeCtx(key vo.NodeKey) (*execute.Context, bool, error) {
|
||||
c, ok := s.NodeExeContexts[key]
|
||||
if ok {
|
||||
return c, true, nil
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *State) SetNodeCtx(key vo.NodeKey, value *execute.Context) error {
|
||||
s.NodeExeContexts[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) GetWorkflowCtx() (*execute.Context, bool, error) {
|
||||
if s.WorkflowExeContext == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
return s.WorkflowExeContext, true, nil
|
||||
}
|
||||
|
||||
func (s *State) SetWorkflowCtx(value *execute.Context) error {
|
||||
s.WorkflowExeContext = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) GetInterruptEvent(nodeKey vo.NodeKey) (*entity.InterruptEvent, bool, error) {
|
||||
if v, ok := s.InterruptEvents[nodeKey]; ok {
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *State) SetInterruptEvent(nodeKey vo.NodeKey, value *entity.InterruptEvent) error {
|
||||
s.InterruptEvents[nodeKey] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) DeleteInterruptEvent(nodeKey vo.NodeKey) error {
|
||||
delete(s.InterruptEvents, nodeKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) GetNestedWorkflowState(key vo.NodeKey) (*nodes.NestedWorkflowState, bool, error) {
|
||||
if v, ok := s.NestedWorkflowStates[key]; ok {
|
||||
return v, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
func (s *State) SaveNestedWorkflowState(key vo.NodeKey, value *nodes.NestedWorkflowState) error {
|
||||
s.NestedWorkflowStates[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int) {
|
||||
s.GroupChoices[nodeKey] = groupToChoice
|
||||
}
|
||||
|
||||
func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
|
||||
return s.GroupChoices[nodeKey]
|
||||
}
|
||||
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
choice, ok := choices[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if choice == -1 { // this group picks none of the elements
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
sInfos, ok := s.SourceInfos[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
groupInfo, ok := sInfos[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if groupInfo.SubSources == nil {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
}
|
||||
|
||||
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
}
|
||||
|
||||
if subInfo.FieldType != nodes.FieldMaybeStream {
|
||||
return subInfo.FieldType, nil
|
||||
}
|
||||
|
||||
if len(subInfo.FromNodeKey) == 0 {
|
||||
panic("subInfo is maybe stream, but from node key is empty")
|
||||
}
|
||||
|
||||
if len(subInfo.FromPath) > 1 || len(subInfo.FromPath) == 0 {
|
||||
panic("subInfo is maybe stream, but from path is more than 1 segments or is empty")
|
||||
}
|
||||
|
||||
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
|
||||
}
|
||||
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
|
||||
result := make(map[string]nodes.FieldStreamType)
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for group := range choices {
|
||||
t, err := s.GetDynamicStreamType(nodeKey, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[group] = t
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *State) SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error {
|
||||
if _, ok := s.ToolInterruptEvents[llmNodeKey]; !ok {
|
||||
s.ToolInterruptEvents[llmNodeKey] = make(map[string]*entity.ToolInterruptEvent)
|
||||
}
|
||||
s.ToolInterruptEvents[llmNodeKey][toolCallID] = ie
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *State) GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error) {
|
||||
return s.ToolInterruptEvents[llmNodeKey], nil
|
||||
}
|
||||
|
||||
func (s *State) ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error) {
|
||||
resumeData, ok := s.LLMToResumeData[llmNodeKey]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("resume data not found for llm node %s", llmNodeKey)
|
||||
}
|
||||
delete(s.ToolInterruptEvents[llmNodeKey], toolCallID)
|
||||
delete(s.LLMToResumeData, llmNodeKey)
|
||||
return resumeData, nil
|
||||
}
|
||||
|
||||
func (s *State) NodeExecuted(key vo.NodeKey) bool {
|
||||
if key == compose.START {
|
||||
return true
|
||||
}
|
||||
_, ok := s.ExecutedNodes[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func GenState() compose.GenLocalState[*State] {
|
||||
return func(ctx context.Context) (state *State) {
|
||||
var parentState *State
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, s *State) error {
|
||||
parentState = s
|
||||
return nil
|
||||
})
|
||||
|
||||
var appVariableStore *variableassigner.AppVariables
|
||||
if parentState == nil {
|
||||
appVariableStore = variableassigner.NewAppVariables()
|
||||
} else {
|
||||
appVariableStore = parentState.AppVariableStore
|
||||
}
|
||||
|
||||
return &State{
|
||||
Answers: make(map[vo.NodeKey][]string),
|
||||
Questions: make(map[vo.NodeKey][]*qa.Question),
|
||||
Inputs: make(map[vo.NodeKey]map[string]any),
|
||||
NodeExeContexts: make(map[vo.NodeKey]*execute.Context),
|
||||
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
|
||||
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
|
||||
ExecutedNodes: make(map[vo.NodeKey]bool),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
|
||||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
AppVariableStore: appVariableStore,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePreHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
|
||||
)
|
||||
|
||||
if s.Type == entity.NodeTypeQuestionAnswer {
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
// even on first execution before any interruption, the input could be empty
|
||||
// so we need to check if we have stored any questions in state, to decide whether this is the first execution
|
||||
isFirst := false
|
||||
if _, ok := state.Questions[s.Key]; !ok {
|
||||
isFirst = true
|
||||
}
|
||||
|
||||
if isFirst {
|
||||
state.Inputs[s.Key] = in
|
||||
return in, nil
|
||||
}
|
||||
|
||||
out := make(map[string]any)
|
||||
for k, v := range state.Inputs[s.Key] {
|
||||
out[k] = v
|
||||
}
|
||||
|
||||
out[qa.QuestionsKey] = state.Questions[s.Key]
|
||||
out[qa.AnswersKey] = state.Answers[s.Key]
|
||||
return out, nil
|
||||
})
|
||||
} else if s.Type == entity.NodeTypeInputReceiver {
|
||||
// InputReceiver node's only input is set by StateModifier when resuming
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
if userInput, ok := state.Inputs[s.Key]; ok && len(userInput) > 0 {
|
||||
return userInput, nil
|
||||
}
|
||||
return in, nil
|
||||
})
|
||||
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
|
||||
state.Inputs[s.Key] = in
|
||||
return in, nil
|
||||
}
|
||||
out := make(map[string]any)
|
||||
for k, v := range state.Inputs[s.Key] {
|
||||
out[k] = v
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
|
||||
if len(handlers) > 0 || !stream {
|
||||
handlerForVars := s.statePreHandlerForVars()
|
||||
if handlerForVars != nil {
|
||||
handlers = append(handlers, handlerForVars)
|
||||
}
|
||||
stateHandler := func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
var err error
|
||||
for _, h := range handlers {
|
||||
in, err = h(ctx, in, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return in, nil
|
||||
}
|
||||
return compose.WithStatePreHandler(stateHandler)
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeVariableAggregator {
|
||||
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
|
||||
handlerForVars := s.streamStatePreHandlerForVars()
|
||||
if handlerForVars != nil {
|
||||
streamHandlers = append(streamHandlers, handlerForVars)
|
||||
}
|
||||
|
||||
/*handlerForStreamSource := s.streamStatePreHandlerForStreamSources()
|
||||
if handlerForStreamSource != nil {
|
||||
streamHandlers = append(streamHandlers, handlerForStreamSource)
|
||||
}*/
|
||||
|
||||
if len(streamHandlers) > 0 {
|
||||
streamHandler := func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
var err error
|
||||
for _, h := range streamHandlers {
|
||||
in, err = h(ctx, in, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
return compose.WithStreamStatePreHandler(streamHandler)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
if input.Source.Ref != nil && input.Source.Ref.VariableType != nil {
|
||||
vars = append(vars, input)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
varStoreHandler := variable.GetVariableHandler()
|
||||
intermediateVarStore := &nodes.ParentIntermediateStore{}
|
||||
|
||||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
out := make(map[string]any)
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
var v any
|
||||
var err error
|
||||
switch *input.Source.Ref.VariableType {
|
||||
case vo.ParentIntermediate:
|
||||
v, err = intermediateVarStore.Get(ctx, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalSystem, vo.GlobalUser:
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes.SetMapValue(out, input.Path, v)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
if input.Source.Ref != nil && input.Source.Ref.VariableType != nil {
|
||||
vars = append(vars, input)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
varStoreHandler := variable.GetVariableHandler()
|
||||
intermediateVarStore := &nodes.ParentIntermediateStore{}
|
||||
|
||||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
exeCfg = execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
)
|
||||
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
var v any
|
||||
var err error
|
||||
switch *input.Source.Ref.VariableType {
|
||||
case vo.ParentIntermediate:
|
||||
v, err = intermediateVarStore.Get(ctx, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalSystem, vo.GlobalUser:
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes.SetMapValue(variables, input.Path, v)
|
||||
}
|
||||
|
||||
variablesStream := schema.StreamReaderFromArray([]map[string]any{variables})
|
||||
|
||||
return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{in, variablesStream}), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// if it does not have source info, do not add this pre handler
|
||||
if s.Configs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
|
||||
return nil
|
||||
case entity.NodeTypeExit:
|
||||
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
if terminatePlan != vo.ReturnVariables {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
|
||||
}
|
||||
|
||||
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
if len(sourceInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
// check the node's input sources, if it does not have any streaming sources, no need to add pre handler
|
||||
// if one input is a stream, then in the pre handler, will trim the KeyIsFinished suffix.
|
||||
// if one input may be a stream, then in the pre handler, will resolve it first, then handle it.
|
||||
type resolvedStreamSource struct {
|
||||
intermediate bool
|
||||
mustBeStream bool
|
||||
subStreamSources map[string]resolvedStreamSource
|
||||
}
|
||||
|
||||
var (
|
||||
anyStream bool
|
||||
checker func(source *nodes.SourceInfo) bool
|
||||
)
|
||||
checker = func(source *nodes.SourceInfo) bool {
|
||||
if source.FieldType != nodes.FieldNotStream {
|
||||
return true
|
||||
}
|
||||
for _, subSource := range source.SubSources {
|
||||
if subAnyStream := checker(subSource); subAnyStream {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
for _, source := range sourceInfo {
|
||||
if hasStream := checker(source); hasStream {
|
||||
anyStream = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !anyStream {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
resolved := map[string]resolvedStreamSource{}
|
||||
|
||||
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
if source.IsIntermediate {
|
||||
result = &resolvedStreamSource{
|
||||
intermediate: true,
|
||||
subStreamSources: map[string]resolvedStreamSource{},
|
||||
}
|
||||
for key, subSource := range source.SubSources {
|
||||
subResult, subE := resolver(*subSource)
|
||||
if subE != nil {
|
||||
return nil, subE
|
||||
}
|
||||
if subResult != nil {
|
||||
result.subStreamSources[key] = *subResult
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
streamType := source.FieldType
|
||||
if streamType == nodes.FieldMaybeStream {
|
||||
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if streamType == nodes.FieldNotStream {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result = &resolvedStreamSource{
|
||||
mustBeStream: true,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for key, source := range sourceInfo {
|
||||
result, err := resolver(*source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result != nil {
|
||||
resolved[key] = *result
|
||||
}
|
||||
}
|
||||
|
||||
var converter func(v any, resolvedSource resolvedStreamSource) any
|
||||
converter = func(v any, resolvedSource resolvedStreamSource) any {
|
||||
if resolvedSource.intermediate {
|
||||
vMap, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
panic("intermediate value is not map[string]any")
|
||||
}
|
||||
outMap := make(map[string]any, len(vMap))
|
||||
for k := range vMap {
|
||||
subResolvedSource, ok := resolvedSource.subStreamSources[k]
|
||||
if !ok { // not a stream field
|
||||
outMap[k] = vMap[k]
|
||||
continue
|
||||
}
|
||||
|
||||
subV := converter(vMap[k], subResolvedSource)
|
||||
outMap[k] = subV
|
||||
}
|
||||
|
||||
return outMap
|
||||
}
|
||||
|
||||
vStr, ok := v.(string)
|
||||
if !ok {
|
||||
panic("stream field is not string")
|
||||
}
|
||||
|
||||
return strings.TrimSuffix(vStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
streamConverter := func(inChunk map[string]any) (outChunk map[string]any, chunkErr error) {
|
||||
outChunk = make(map[string]any, len(inChunk))
|
||||
for k, v := range inChunk {
|
||||
if resolvedSource, ok := resolved[k]; !ok {
|
||||
outChunk[k] = v // not a stream field
|
||||
} else {
|
||||
vOut := converter(v, resolvedSource)
|
||||
outChunk[k] = vOut
|
||||
}
|
||||
}
|
||||
|
||||
return outChunk, nil
|
||||
}
|
||||
|
||||
return schema.StreamReaderWithConvert(in, streamConverter), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePostHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
|
||||
)
|
||||
|
||||
if stream {
|
||||
streamHandlers = append(streamHandlers, func(ctx context.Context, out *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
state.ExecutedNodes[s.Key] = true
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.streamStatePostHandlerForVars()
|
||||
if forVars != nil {
|
||||
streamHandlers = append(streamHandlers, forVars)
|
||||
}
|
||||
|
||||
streamHandler := func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
var err error
|
||||
for _, h := range streamHandlers {
|
||||
in, err = h(ctx, in, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
return compose.WithStreamStatePostHandler(streamHandler)
|
||||
}
|
||||
|
||||
handlers = append(handlers, func(ctx context.Context, out map[string]any, state *State) (map[string]any, error) {
|
||||
state.ExecutedNodes[s.Key] = true
|
||||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.statePostHandlerForVars()
|
||||
if forVars != nil {
|
||||
handlers = append(handlers, forVars)
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
var err error
|
||||
for _, h := range handlers {
|
||||
in, err = h(ctx, in, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return in, nil
|
||||
}
|
||||
|
||||
return compose.WithStatePostHandler(handler)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variable,
|
||||
// use the state's variableHandler to get the variables and set them to the output
|
||||
var vars []*vo.FieldInfo
|
||||
for _, output := range s.OutputSources {
|
||||
if output.Source.Ref != nil && output.Source.Ref.VariableType != nil {
|
||||
// intermediate vars are handled within nodes themselves
|
||||
if *output.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
continue
|
||||
}
|
||||
vars = append(vars, output)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
varStoreHandler := variable.GetVariableHandler()
|
||||
|
||||
return func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
opts := make([]variable.OptionFn, 0, 1)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
out := make(map[string]any)
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
var v any
|
||||
var err error
|
||||
switch *input.Source.Ref.VariableType {
|
||||
case vo.GlobalSystem, vo.GlobalUser:
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes.SetMapValue(out, input.Path, v)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
|
||||
var vars []*vo.FieldInfo
|
||||
for _, output := range s.OutputSources {
|
||||
if output.Source.Ref != nil && output.Source.Ref.VariableType != nil {
|
||||
// intermediate vars are handled within nodes themselves
|
||||
if *output.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
continue
|
||||
}
|
||||
vars = append(vars, output)
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
varStoreHandler := variable.GetVariableHandler()
|
||||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
var (
|
||||
variables = make(map[string]any)
|
||||
opts = make([]variable.OptionFn, 0, 1)
|
||||
)
|
||||
|
||||
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
|
||||
exeCfg := execute.GetExeCtx(ctx).RootCtx.ExeCfg
|
||||
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
|
||||
AgentID: exeCfg.AgentID,
|
||||
AppID: exeCfg.AppID,
|
||||
ConnectorID: exeCfg.ConnectorID,
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
|
||||
for _, input := range vars {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
var v any
|
||||
var err error
|
||||
switch *input.Source.Ref.VariableType {
|
||||
case vo.GlobalSystem, vo.GlobalUser:
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
case vo.GlobalAPP:
|
||||
var ok bool
|
||||
path := strings.Join(input.Source.Ref.FromPath, ".")
|
||||
if v, ok = state.GetAppVariableValue(path); !ok {
|
||||
v, err = varStoreHandler.Get(ctx, *input.Source.Ref.VariableType, input.Source.Ref.FromPath, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.SetAppVariableValue(path, v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid variable type: %v", *input.Source.Ref.VariableType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes.SetMapValue(variables, input.Path, v)
|
||||
}
|
||||
|
||||
variablesStream := schema.StreamReaderFromArray([]map[string]any{variables})
|
||||
|
||||
return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{in, variablesStream}), nil
|
||||
}
|
||||
}
|
||||
|
||||
func GenStateModifierByEventType(e entity.InterruptEventType,
|
||||
nodeKey vo.NodeKey,
|
||||
resumeData string,
|
||||
exeCfg vo.ExecuteConfig) (stateModifier compose.StateModifier) {
|
||||
// TODO: can we unify them all to a map[NodeKey]resumeData?
|
||||
switch e {
|
||||
case entity.InterruptEventInput:
|
||||
stateModifier = func(ctx context.Context, path compose.NodePath, state any) (err error) {
|
||||
if exeCfg.BizType == vo.BizTypeAgent {
|
||||
m := make(map[string]any)
|
||||
sList := strings.Split(resumeData, "\n")
|
||||
for _, s := range sList {
|
||||
firstColon := strings.Index(s, ":")
|
||||
k := s[:firstColon]
|
||||
v := s[firstColon+1:]
|
||||
m[k] = v
|
||||
}
|
||||
resumeData, err = sonic.MarshalString(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
input := map[string]any{
|
||||
receiver.ReceivedDataKey: resumeData,
|
||||
}
|
||||
state.(*State).Inputs[nodeKey] = input
|
||||
return nil
|
||||
}
|
||||
case entity.InterruptEventQuestion:
|
||||
stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*State).AddAnswer(nodeKey, resumeData)
|
||||
return nil
|
||||
}
|
||||
case entity.InterruptEventLLM:
|
||||
stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*State).LLMToResumeData[nodeKey] = resumeData
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unimplemented interrupt event type: %v", e))
|
||||
}
|
||||
|
||||
return stateModifier
|
||||
}
|
||||
292
backend/domain/workflow/internal/compose/stream.go
Normal file
292
backend/domain/workflow/internal/compose/stream.go
Normal file
@@ -0,0 +1,292 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// SetFullSources calculates REAL input sources for a node.
|
||||
// It may be different from a NodeSchema's InputSources because of the following reasons:
|
||||
// 1. a inner node under composite node may refer to a field from a node in its parent workflow,
|
||||
// this is instead routed to and sourced from the inner workflow's start node.
|
||||
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
|
||||
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
|
||||
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
|
||||
fullSource := make(map[string]*nodes.SourceInfo)
|
||||
var fieldInfos []vo.FieldInfo
|
||||
for _, s := range dep.staticValues {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: s.path,
|
||||
Source: vo.FieldSource{Val: s.val},
|
||||
})
|
||||
}
|
||||
|
||||
for _, v := range dep.variableInfos {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: v.toPath,
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
VariableType: &v.varType,
|
||||
FromPath: v.fromPath[1:],
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
for f := range dep.inputsFull {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: []string{""},
|
||||
Source: vo.FieldSource{Ref: &vo.Reference{
|
||||
FromNodeKey: f,
|
||||
FromPath: []string{""},
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
for f, ms := range dep.inputs {
|
||||
for _, m := range ms {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: m.ToPath(),
|
||||
Source: vo.FieldSource{Ref: &vo.Reference{
|
||||
FromNodeKey: f,
|
||||
FromPath: m.FromPath(),
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for f := range dep.inputsNoDirectDependencyFull {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: []string{""},
|
||||
Source: vo.FieldSource{Ref: &vo.Reference{
|
||||
FromNodeKey: f,
|
||||
FromPath: []string{""},
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
for f, ms := range dep.inputsNoDirectDependency {
|
||||
for _, m := range ms {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
Path: m.ToPath(),
|
||||
Source: vo.FieldSource{Ref: &vo.Reference{
|
||||
FromNodeKey: f,
|
||||
FromPath: m.FromPath(),
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for i := range fieldInfos {
|
||||
fInfo := fieldInfos[i]
|
||||
path := fInfo.Path
|
||||
currentSource := fullSource
|
||||
var (
|
||||
tInfo *vo.TypeInfo
|
||||
lastPath string
|
||||
)
|
||||
if len(path) > 1 {
|
||||
tInfo = s.InputTypes[path[0]]
|
||||
for j := 0; j < len(path)-1; j++ {
|
||||
if j > 0 {
|
||||
tInfo = tInfo.Properties[path[j]]
|
||||
}
|
||||
if current, ok := currentSource[path[j]]; !ok {
|
||||
currentSource[path[j]] = &nodes.SourceInfo{
|
||||
IsIntermediate: true,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
SubSources: make(map[string]*nodes.SourceInfo),
|
||||
}
|
||||
} else if !current.IsIntermediate {
|
||||
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
|
||||
}
|
||||
|
||||
currentSource = currentSource[path[j]].SubSources
|
||||
}
|
||||
|
||||
lastPath = path[len(path)-1]
|
||||
tInfo = tInfo.Properties[lastPath]
|
||||
} else {
|
||||
lastPath = path[0]
|
||||
tInfo = s.InputTypes[lastPath]
|
||||
}
|
||||
|
||||
// static values or variables
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
var (
|
||||
streamType nodes.FieldStreamType
|
||||
err error
|
||||
)
|
||||
if len(fromNodeKey) > 0 {
|
||||
if fromNodeKey == compose.START {
|
||||
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
|
||||
} else {
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: streamType,
|
||||
FromNodeKey: fromNodeKey,
|
||||
FromPath: fInfo.Source.Ref.FromPath,
|
||||
TypeInfo: tInfo,
|
||||
}
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)["FullSources"] = fullSource
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
|
||||
if s.Type == entity.NodeTypeExit {
|
||||
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
|
||||
if len(path) == 1 && path[0] == "output" {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
|
||||
subSC := s.SubWorkflowSchema
|
||||
subExit := subSC.GetNode(entity.ExitNodeKey)
|
||||
subStreamType, err := subExit.IsStreamingField(path, nil)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
return subStreamType, nil
|
||||
} else if s.Type == entity.NodeTypeVariableAggregator {
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range s.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range s.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == nodes.FieldMaybeStream {
|
||||
return nodes.FieldMaybeStream, nil
|
||||
} else if subStreamType == nodes.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldMaybeStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type != entity.NodeTypeLLM {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) != 1 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
outputs := s.OutputTypes
|
||||
if len(outputs) != 1 && len(outputs) != 2 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
var outputKey string
|
||||
for key, output := range outputs {
|
||||
if output.Type != vo.DataTypeString {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if key != "reasoning_content" {
|
||||
if len(outputKey) > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
outputKey = key
|
||||
}
|
||||
}
|
||||
|
||||
field := path[0]
|
||||
if field == "reasoning_content" || field == outputKey {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
349
backend/domain/workflow/internal/compose/test/batch_test.go
Normal file
349
backend/domain/workflow/internal/compose/test/batch_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
)
|
||||
|
||||
func TestBatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
lambda1 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
if in["index"].(int64) > 2 {
|
||||
return nil, fmt.Errorf("index= %d is too large", in["index"].(int64))
|
||||
}
|
||||
|
||||
out = make(map[string]any)
|
||||
out["output_1"] = fmt.Sprintf("%s_%v_%d", in["array_1"].(string), in["from_parent_wf"].(bool), in["index"].(int64))
|
||||
return out, nil
|
||||
}
|
||||
|
||||
lambda2 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
return map[string]any{"index": in["index"]}, nil
|
||||
}
|
||||
|
||||
lambda3 := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
t.Log(in["consumer_1"].(string), in["array_2"].(int64), in["static_source"].(string))
|
||||
return in, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"index"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"index"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"array_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"array_1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"from_parent_wf"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "parent_predecessor_1",
|
||||
FromPath: compose.FieldPath{"success"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
lambdaNode2 := &compose2.NodeSchema{
|
||||
Key: "index",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"index"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"index"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
Key: "consumer",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"consumer_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "lambda",
|
||||
FromPath: compose.FieldPath{"output_1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"array_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"array_2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"static_source"},
|
||||
Source: vo.FieldSource{
|
||||
Val: "this is a const",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "batch_node_key",
|
||||
Type: entity.NodeTypeBatch,
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"array_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"array_1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"array_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"array_2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{batch.ConcurrentSizeKey},
|
||||
Source: vo.FieldSource{
|
||||
Val: int64(2),
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{batch.MaxBatchSizeKey},
|
||||
Source: vo.FieldSource{
|
||||
Val: int64(5),
|
||||
},
|
||||
},
|
||||
},
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"array_1": {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
"array_2": {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"assembled_output_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "lambda",
|
||||
FromPath: compose.FieldPath{"output_1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"assembled_output_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "index",
|
||||
FromPath: compose.FieldPath{"index"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"assembled_output_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"assembled_output_1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"assembled_output_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "batch_node_key",
|
||||
FromPath: compose.FieldPath{"assembled_output_2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
parentLambda := func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
return map[string]any{"success": true}, nil
|
||||
}
|
||||
|
||||
parentLambdaNode := &compose2.NodeSchema{
|
||||
Key: "parent_predecessor_1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(parentLambda),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
parentLambdaNode,
|
||||
ns,
|
||||
exit,
|
||||
lambdaNode1,
|
||||
lambdaNode2,
|
||||
lambdaNode3,
|
||||
},
|
||||
Hierarchy: map[vo.NodeKey]vo.NodeKey{
|
||||
"lambda": "batch_node_key",
|
||||
"index": "batch_node_key",
|
||||
"consumer": "batch_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entity.EntryNodeKey,
|
||||
ToNode: "parent_predecessor_1",
|
||||
},
|
||||
{
|
||||
FromNode: "parent_predecessor_1",
|
||||
ToNode: "batch_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "batch_node_key",
|
||||
ToNode: "lambda",
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: "index",
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: "consumer",
|
||||
},
|
||||
{
|
||||
FromNode: "index",
|
||||
ToNode: "batch_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "consumer",
|
||||
ToNode: "batch_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "batch_node_key",
|
||||
ToNode: entity.ExitNodeKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{
|
||||
"array_1": []any{"a", "b", "c"},
|
||||
"array_2": []any{int64(1), int64(2), int64(3), int64(4)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"assembled_output_1": []any{"a_true_0", "b_true_1", "c_true_2"},
|
||||
"assembled_output_2": []any{int64(0), int64(1), int64(2)},
|
||||
}, out)
|
||||
|
||||
// input array is empty
|
||||
out, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"array_1": []any{},
|
||||
"array_2": []any{int64(1)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"assembled_output_1": []any{},
|
||||
"assembled_output_2": []any{},
|
||||
}, out)
|
||||
|
||||
// less than concurrency
|
||||
out, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"array_1": []any{"a"},
|
||||
"array_2": []any{int64(1), int64(2)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"assembled_output_1": []any{"a_true_0"},
|
||||
"assembled_output_2": []any{int64(0)},
|
||||
}, out)
|
||||
|
||||
// err by inner node
|
||||
_, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"array_1": []any{"a", "b", "c", "d", "e", "f"},
|
||||
"array_2": []any{int64(1), int64(2), int64(3), int64(4), int64(5), int64(6), int64(7)},
|
||||
})
|
||||
assert.ErrorContains(t, err, "is too large")
|
||||
}
|
||||
679
backend/domain/workflow/internal/compose/test/llm_test.go
Normal file
679
backend/domain/workflow/internal/compose/test/llm_test.go
Normal file
@@ -0,0 +1,679 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
)
|
||||
|
||||
func TestLLM(t *testing.T) {
|
||||
mockey.PatchConvey("test llm", t, func() {
|
||||
accessKey := os.Getenv("OPENAI_API_KEY")
|
||||
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||
modelName := os.Getenv("OPENAI_MODEL_NAME")
|
||||
var (
|
||||
openaiModel, deepSeekModel model2.BaseChatModel
|
||||
err error
|
||||
)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockModelManager := mockmodel.NewMockManager(ctrl)
|
||||
mockey.Mock(model.GetManager).Return(mockModelManager).Build()
|
||||
|
||||
if len(accessKey) > 0 && len(baseURL) > 0 && len(modelName) > 0 {
|
||||
openaiModel, err = openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
||||
APIKey: accessKey,
|
||||
ByAzure: true,
|
||||
BaseURL: baseURL,
|
||||
Model: modelName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
deepSeekModelName := os.Getenv("DEEPSEEK_MODEL_NAME")
|
||||
if len(accessKey) > 0 && len(baseURL) > 0 && len(deepSeekModelName) > 0 {
|
||||
deepSeekModel, err = deepseek.NewChatModel(context.Background(), &deepseek.ChatModelConfig{
|
||||
APIKey: accessKey,
|
||||
BaseURL: baseURL,
|
||||
Model: deepSeekModelName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, params *model.LLMParams) (model2.BaseChatModel, *modelmgr.Model, error) {
|
||||
if params.ModelName == modelName {
|
||||
return openaiModel, nil, nil
|
||||
} else if params.ModelName == deepSeekModelName {
|
||||
return deepSeekModel, nil, nil
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("invalid model name: %s", params.ModelName)
|
||||
}
|
||||
}).AnyTimes()
|
||||
|
||||
ctx := ctxcache.Init(context.Background())
|
||||
|
||||
t.Run("plain text output, non-streaming mode", func(t *testing.T) {
|
||||
if openaiModel == nil {
|
||||
defer func() {
|
||||
openaiModel = nil
|
||||
}()
|
||||
openaiModel = &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: "I don't know",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "{{sys_prompt}}",
|
||||
"UserPrompt": "{{query}}",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"sys_prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"sys_prompt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"query"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"sys_prompt": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
"query": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: llmNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
llmNode,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{
|
||||
"sys_prompt": "you are a helpful assistant",
|
||||
"query": "what's your name",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(out), 0)
|
||||
assert.Greater(t, len(out["output"].(string)), 0)
|
||||
})
|
||||
|
||||
t.Run("json output", func(t *testing.T) {
|
||||
if openaiModel == nil {
|
||||
defer func() {
|
||||
openaiModel = nil
|
||||
}()
|
||||
openaiModel = &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `{"country_name": "Russia", "area_size": 17075400}`,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?",
|
||||
"OutputFormat": llm.FormatJSON,
|
||||
"IgnoreException": true,
|
||||
"DefaultOutput": map[string]any{
|
||||
"country_name": "unknown",
|
||||
"area_size": int64(0),
|
||||
},
|
||||
"LLMParams": &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"country_name": {
|
||||
Type: vo.DataTypeString,
|
||||
Required: true,
|
||||
},
|
||||
"area_size": {
|
||||
Type: vo.DataTypeInteger,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"country_name"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: llmNode.Key,
|
||||
FromPath: compose.FieldPath{"country_name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"area_size"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: llmNode.Key,
|
||||
FromPath: compose.FieldPath{"area_size"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
llmNode,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, out["country_name"], "Russia")
|
||||
assert.Greater(t, out["area_size"], int64(1000))
|
||||
})
|
||||
|
||||
t.Run("markdown output", func(t *testing.T) {
|
||||
if openaiModel == nil {
|
||||
defer func() {
|
||||
openaiModel = nil
|
||||
}()
|
||||
openaiModel = &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `#Top 5 Largest Countries in the World ## 1. Russia 2. Canada 3. United States 4. Brazil 5. Japan`,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "list the top 5 largest countries in the world",
|
||||
"OutputFormat": llm.FormatMarkdown,
|
||||
"LLMParams": &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: llmNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
llmNode,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{})
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(out["output"].(string)), 0)
|
||||
})
|
||||
|
||||
t.Run("plain text output, streaming mode", func(t *testing.T) {
|
||||
// start -> fan out to openai LLM and deepseek LLM -> fan in to output emitter -> end
|
||||
if openaiModel == nil || deepSeekModel == nil {
|
||||
if openaiModel == nil {
|
||||
defer func() {
|
||||
openaiModel = nil
|
||||
}()
|
||||
openaiModel = &testutil.UTChatModel{
|
||||
StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
sr := schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I ",
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "don't know.",
|
||||
},
|
||||
})
|
||||
return sr, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if deepSeekModel == nil {
|
||||
defer func() {
|
||||
deepSeekModel = nil
|
||||
}()
|
||||
deepSeekModel = &testutil.UTChatModel{
|
||||
StreamResultProvider: func(_ int, in []*schema.Message) (*schema.StreamReader[*schema.Message], error) {
|
||||
sr := schema.StreamReaderFromArray([]*schema.Message{
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "I ",
|
||||
},
|
||||
{
|
||||
Role: schema.Assistant,
|
||||
Content: "don't know too.",
|
||||
},
|
||||
})
|
||||
return sr, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
openaiNode := &compose2.NodeSchema{
|
||||
Key: "openai_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "plan a 10 day family visit to China.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
deepseekNode := &compose2.NodeSchema{
|
||||
Key: "deepseek_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
"reasoning_content": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
emitterNode := &compose2.NodeSchema{
|
||||
Key: "emitter_node_key",
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Configs: map[string]any{
|
||||
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
|
||||
"Mode": nodes.Streaming,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"openai_output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: openaiNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"deepseek_output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: deepseekNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"deepseek_reasoning"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: deepseekNode.Key,
|
||||
FromPath: compose.FieldPath{"reasoning_content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"inputObj"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"inputObj"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"input2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"input2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.UseAnswerContent,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"openai_output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: openaiNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"deepseek_output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: deepseekNode.Key,
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"deepseek_reasoning"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: deepseekNode.Key,
|
||||
FromPath: compose.FieldPath{"reasoning_content"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
openaiNode,
|
||||
deepseekNode,
|
||||
emitterNode,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: openaiNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: openaiNode.Key,
|
||||
ToNode: emitterNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: deepseekNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: deepseekNode.Key,
|
||||
ToNode: emitterNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: emitterNode.Key,
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var fullOutput string
|
||||
|
||||
cbHandler := callbacks.NewHandlerBuilder().OnEndWithStreamOutputFn(
|
||||
func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||
defer output.Close()
|
||||
|
||||
for {
|
||||
chunk, e := output.Recv()
|
||||
if e != nil {
|
||||
if e == io.EOF {
|
||||
break
|
||||
}
|
||||
assert.NoError(t, e)
|
||||
}
|
||||
|
||||
s, ok := chunk.(map[string]any)
|
||||
assert.True(t, ok)
|
||||
|
||||
out := s["output"].(string)
|
||||
if out != nodes.KeyIsFinished {
|
||||
fmt.Print(s["output"])
|
||||
fullOutput += s["output"].(string)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx
|
||||
}).Build()
|
||||
|
||||
outStream, err := wf.Runner.Stream(ctx, map[string]any{
|
||||
"inputObj": map[string]any{
|
||||
"field1": "field1",
|
||||
"field2": 1.1,
|
||||
},
|
||||
"input2": 23.5,
|
||||
}, compose.WithCallbacks(cbHandler).DesignateNode(string(emitterNode.Key)))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(fullOutput, "prefix field1 23.5"))
|
||||
assert.True(t, strings.HasSuffix(fullOutput, "1.1 suffix"))
|
||||
outStream.Close()
|
||||
})
|
||||
})
|
||||
}
|
||||
495
backend/domain/workflow/internal/compose/test/loop_test.go
Normal file
495
backend/domain/workflow/internal/compose/test/loop_test.go
Normal file
@@ -0,0 +1,495 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
t.Run("by iteration", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->continue] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
index := in["index"].(int64)
|
||||
return map[string]any{"output": index}, nil
|
||||
}),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"index"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"index"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
continueNode := &compose2.NodeSchema{
|
||||
Key: "continueNode",
|
||||
Type: entity.NodeTypeContinue,
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByIteration,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{loop.Count},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "innerNode",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
loopNode,
|
||||
exit,
|
||||
innerNode,
|
||||
continueNode,
|
||||
},
|
||||
Hierarchy: map[vo.NodeKey]vo.NodeKey{
|
||||
"innerNode": "loop_node_key",
|
||||
"continueNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
},
|
||||
{
|
||||
FromNode: "innerNode",
|
||||
ToNode: "continueNode",
|
||||
},
|
||||
{
|
||||
FromNode: "continueNode",
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"count": int64(3),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": []any{int64(0), int64(1), int64(2)},
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("infinite", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->break] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
index := in["index"].(int64)
|
||||
return map[string]any{"output": index}, nil
|
||||
}),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"index"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"index"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
breakNode := &compose2.NodeSchema{
|
||||
Key: "breakNode",
|
||||
Type: entity.NodeTypeBreak,
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.Infinite,
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "innerNode",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
loopNode,
|
||||
exit,
|
||||
innerNode,
|
||||
breakNode,
|
||||
},
|
||||
Hierarchy: map[vo.NodeKey]vo.NodeKey{
|
||||
"innerNode": "loop_node_key",
|
||||
"breakNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
},
|
||||
{
|
||||
FromNode: "innerNode",
|
||||
ToNode: "breakNode",
|
||||
},
|
||||
{
|
||||
FromNode: "breakNode",
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": []any{int64(0)},
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("by array", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->variable_assign] -> end
|
||||
|
||||
innerNode := &compose2.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
item1 := in["item1"].(string)
|
||||
item2 := in["item2"].(string)
|
||||
count := in["count"].(int)
|
||||
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
|
||||
}),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"item1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"items1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"item2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"items2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"count"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assigner := &compose2.NodeSchema{
|
||||
Key: "assigner",
|
||||
Type: entity.NodeTypeVariableAssignerWithinLoop,
|
||||
Configs: []*variableassigner.Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"total"},
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"total"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "innerNode",
|
||||
FromPath: compose.FieldPath{"total"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "loop_node_key",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByArray,
|
||||
"IntermediateVars": map[string]*vo.TypeInfo{
|
||||
"count": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
},
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"items1": {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
},
|
||||
"items2": {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"items1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"items1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"items2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"items2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"count"},
|
||||
Source: vo.FieldSource{
|
||||
Val: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
loopNode,
|
||||
exit,
|
||||
innerNode,
|
||||
assigner,
|
||||
},
|
||||
Hierarchy: map[vo.NodeKey]vo.NodeKey{
|
||||
"innerNode": "loop_node_key",
|
||||
"assigner": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
},
|
||||
{
|
||||
FromNode: "innerNode",
|
||||
ToNode: "assigner",
|
||||
},
|
||||
{
|
||||
FromNode: "assigner",
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"items1": []any{"a", "b"},
|
||||
"items2": []any{"a1", "b1", "c1"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": 6,
|
||||
}, out)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,673 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
model2 "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
mockmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model/modelmock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestQuestionAnswer(t *testing.T) {
|
||||
mockey.PatchConvey("test qa", t, func() {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockModelManager := mockmodel.NewMockManager(ctrl)
|
||||
mockey.Mock(model.GetManager).Return(mockModelManager).Build()
|
||||
|
||||
accessKey := os.Getenv("OPENAI_API_KEY")
|
||||
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||
modelName := os.Getenv("OPENAI_MODEL_NAME")
|
||||
var (
|
||||
chatModel model2.BaseChatModel
|
||||
err error
|
||||
)
|
||||
|
||||
if len(accessKey) > 0 && len(baseURL) > 0 && len(modelName) > 0 {
|
||||
chatModel, err = openai.NewChatModel(context.Background(), &openai.ChatModelConfig{
|
||||
APIKey: accessKey,
|
||||
ByAzure: true,
|
||||
BaseURL: baseURL,
|
||||
Model: modelName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).AnyTimes()
|
||||
}
|
||||
|
||||
dsn := "root:root@tcp(127.0.0.1:3306)/opencoze?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
if os.Getenv("CI_JOB_NAME") != "" {
|
||||
dsn = strings.ReplaceAll(dsn, "127.0.0.1", "mysql")
|
||||
}
|
||||
db, err := gorm.Open(mysql.Open(dsn))
|
||||
assert.NoError(t, err)
|
||||
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start miniredis: %v", err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
redisClient := redis.NewClient(&redis.Options{
|
||||
Addr: s.Addr(),
|
||||
})
|
||||
|
||||
mockIDGen := mock.NewMockIDGenerator(ctrl)
|
||||
mockIDGen.EXPECT().GenID(gomock.Any()).Return(time.Now().UnixNano(), nil).AnyTimes()
|
||||
mockTos := storageMock.NewMockStorage(ctrl)
|
||||
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
|
||||
repo := repo2.NewRepository(mockIDGen, db, redisClient, mockTos,
|
||||
checkpoint.NewRedisStore(redisClient))
|
||||
mockey.Mock(workflow.GetRepository).Return(repo).Build()
|
||||
|
||||
t.Run("answer directly, no structured output", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"answer"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.UserResponseKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
|
||||
_, err = wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"query": "what's your name?",
|
||||
}, compose.WithCheckPointID(checkPointID))
|
||||
assert.Error(t, err)
|
||||
|
||||
info, existed := compose.ExtractInterruptInfo(err)
|
||||
assert.True(t, existed)
|
||||
assert.Equal(t, "what's your name?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
|
||||
|
||||
answer := "my name is eino"
|
||||
stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
|
||||
return nil
|
||||
}
|
||||
out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"answer": answer,
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("answer with fixed choices", func(t *testing.T) {
|
||||
if chatModel == nil {
|
||||
oneChatModel := &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: "-1",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.FixedChoices,
|
||||
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
|
||||
"LLMParams": &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"choice1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"choice1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"choice2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"choice2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"option_id"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.OptionIDKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"option_content"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.OptionContentKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
return out, nil
|
||||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
exit,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
FromPort: ptr.Of("branch_1"),
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: "lambda",
|
||||
FromPort: ptr.Of("default"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
|
||||
_, err = wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"query": "what's would you make in Coze?",
|
||||
"choice1": "make agent",
|
||||
"choice2": "make workflow",
|
||||
}, compose.WithCheckPointID(checkPointID))
|
||||
assert.Error(t, err)
|
||||
|
||||
info, existed := compose.ExtractInterruptInfo(err)
|
||||
assert.True(t, existed)
|
||||
assert.Equal(t, "what's would you make in Coze?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
|
||||
assert.Equal(t, "make agent", info.State.(*compose2.State).Questions[ns.Key][0].Choices[0])
|
||||
assert.Equal(t, "make workflow", info.State.(*compose2.State).Questions[ns.Key][0].Choices[1])
|
||||
|
||||
chosenContent := "I would make all kinds of stuff"
|
||||
stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], chosenContent)
|
||||
return nil
|
||||
}
|
||||
out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"option_id": "other",
|
||||
"option_content": chosenContent,
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("answer with dynamic choices", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.DynamicChoices,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{qa.DynamicChoicesKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"choices"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"option_id"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.OptionIDKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"option_content"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.OptionContentKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
return out, nil
|
||||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
exit,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: "lambda",
|
||||
FromPort: ptr.Of("default"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
|
||||
_, err = wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"query": "what's the capital city of China?",
|
||||
"choices": []any{"beijing", "shanghai"},
|
||||
}, compose.WithCheckPointID(checkPointID))
|
||||
assert.Error(t, err)
|
||||
|
||||
info, existed := compose.ExtractInterruptInfo(err)
|
||||
assert.True(t, existed)
|
||||
assert.Equal(t, "what's the capital city of China?", info.State.(*compose2.State).Questions[ns.Key][0].Question)
|
||||
assert.Equal(t, "beijing", info.State.(*compose2.State).Questions[ns.Key][0].Choices[0])
|
||||
assert.Equal(t, "shanghai", info.State.(*compose2.State).Questions[ns.Key][0].Choices[1])
|
||||
|
||||
chosenContent := "beijing"
|
||||
stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], chosenContent)
|
||||
return nil
|
||||
}
|
||||
out, err := wf.Runner.Invoke(context.Background(), nil, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"option_id": "A",
|
||||
"option_content": chosenContent,
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("answer directly, extract structured output", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
qaCount := 0
|
||||
if chatModel == nil {
|
||||
defer func() {
|
||||
chatModel = nil
|
||||
}()
|
||||
chatModel = &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(_ int, in []*schema.Message) (*schema.Message, error) {
|
||||
if qaCount == 1 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `{"question": "what's your age?"}`,
|
||||
}, nil
|
||||
} else if qaCount == 2 {
|
||||
return &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: `{"fields": {"name": "eino", "age": 1}}`,
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
"ExtractFromAnswer": true,
|
||||
"AdditionalSystemPromptTpl": "{{prompt}}",
|
||||
"MaxAnswerCount": 2,
|
||||
"LLMParams": &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"prompt"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"name": {
|
||||
Type: vo.DataTypeString,
|
||||
Required: true,
|
||||
},
|
||||
"age": {
|
||||
Type: vo.DataTypeInteger,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"name"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{"name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"age"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{"age"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{qa.UserResponseKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "qa_node_key",
|
||||
FromPath: compose.FieldPath{qa.UserResponseKey},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkPointID := fmt.Sprintf("%d", time.Now().Nanosecond())
|
||||
_, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"query": "what's your name?",
|
||||
"prompt": "You are a helpful assistant.",
|
||||
}, compose.WithCheckPointID(checkPointID))
|
||||
assert.Error(t, err)
|
||||
|
||||
info, existed := compose.ExtractInterruptInfo(err)
|
||||
assert.True(t, existed)
|
||||
assert.Equal(t, "what's your name?", info.State.(*compose2.State).Questions["qa_node_key"][0].Question)
|
||||
|
||||
qaCount++
|
||||
answer := "my name is eino"
|
||||
stateModifier := func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
|
||||
return nil
|
||||
}
|
||||
_, err = wf.Runner.Invoke(ctx, map[string]any{}, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
|
||||
assert.Error(t, err)
|
||||
info, existed = compose.ExtractInterruptInfo(err)
|
||||
assert.True(t, existed)
|
||||
|
||||
qaCount++
|
||||
answer = "my age is 1 years old"
|
||||
stateModifier = func(ctx context.Context, path compose.NodePath, state any) error {
|
||||
state.(*compose2.State).Answers[ns.Key] = append(state.(*compose2.State).Answers[ns.Key], answer)
|
||||
return nil
|
||||
}
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{}, compose.WithCheckPointID(checkPointID), compose.WithStateModifier(stateModifier))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
qa.UserResponseKey: answer,
|
||||
"name": "eino",
|
||||
"age": int64(1),
|
||||
}, out)
|
||||
})
|
||||
})
|
||||
}
|
||||
634
backend/domain/workflow/internal/compose/test/workflow_test.go
Normal file
634
backend/domain/workflow/internal/compose/test/workflow_test.go
Normal file
@@ -0,0 +1,634 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestAddSelector(t *testing.T) {
|
||||
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "lambda1",
|
||||
FromPath: compose.FieldPath{"lambda1"},
|
||||
},
|
||||
},
|
||||
Path: compose.FieldPath{"lambda1"},
|
||||
},
|
||||
{
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "lambda2",
|
||||
FromPath: compose.FieldPath{"lambda2"},
|
||||
},
|
||||
},
|
||||
Path: compose.FieldPath{"lambda2"},
|
||||
},
|
||||
{
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "lambda3",
|
||||
FromPath: compose.FieldPath{"lambda3"},
|
||||
},
|
||||
},
|
||||
Path: compose.FieldPath{"lambda3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lambda1 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"lambda1": "v1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
Key: "lambda1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
}
|
||||
|
||||
lambda2 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"lambda2": "v2",
|
||||
}, nil
|
||||
}
|
||||
|
||||
LambdaNode2 := &compose2.NodeSchema{
|
||||
Key: "lambda2",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
}
|
||||
|
||||
lambda3 := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"lambda3": "v3",
|
||||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
Key: "lambda3",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "selector",
|
||||
Type: entity.NodeTypeSelector,
|
||||
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
|
||||
{
|
||||
Single: ptr.Of(selector.OperatorEqual),
|
||||
},
|
||||
{
|
||||
Multi: &selector.MultiClauseSchema{
|
||||
Clauses: []*selector.Operator{
|
||||
ptr.Of(selector.OperatorGreater),
|
||||
ptr.Of(selector.OperatorIsTrue),
|
||||
},
|
||||
Relation: selector.ClauseRelationAND,
|
||||
},
|
||||
},
|
||||
}},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"key1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"0", selector.RightKey},
|
||||
Source: vo.FieldSource{
|
||||
Val: "value1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"1", "0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"key2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"1", "0", selector.RightKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"key3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"1", "1", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"key4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"0": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
selector.LeftKey: {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
selector.RightKey: {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
"1": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"0": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
selector.LeftKey: {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
selector.RightKey: {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
},
|
||||
"1": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
selector.LeftKey: {
|
||||
Type: vo.DataTypeBoolean,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
lambdaNode1,
|
||||
LambdaNode2,
|
||||
lambdaNode3,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "selector",
|
||||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: "lambda1",
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: "lambda2",
|
||||
FromPort: ptr.Of("branch_1"),
|
||||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: "lambda3",
|
||||
FromPort: ptr.Of("branch_1"),
|
||||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: exit.Key,
|
||||
FromPort: ptr.Of("default"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda1",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda2",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda3",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(ctx, map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": int64(2),
|
||||
"key3": int64(3),
|
||||
"key4": true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"lambda1": "v1",
|
||||
}, out)
|
||||
|
||||
out, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"key1": "value2",
|
||||
"key2": int64(3),
|
||||
"key3": int64(2),
|
||||
"key4": true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"lambda2": "v2",
|
||||
"lambda3": "v3",
|
||||
}, out)
|
||||
|
||||
out, err = wf.Runner.Invoke(ctx, map[string]any{
|
||||
"key1": "value2",
|
||||
"key2": int64(2),
|
||||
"key3": int64(3),
|
||||
"key4": true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{}, out)
|
||||
}
|
||||
|
||||
func TestVariableAggregator(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"Group1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "va",
|
||||
FromPath: compose.FieldPath{"Group1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"Group2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "va",
|
||||
FromPath: compose.FieldPath{"Group2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "va",
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Configs: map[string]any{
|
||||
"MergeStrategy": variableaggregator.FirstNotNullValue,
|
||||
"GroupToLen": map[string]int{
|
||||
"Group1": 1,
|
||||
"Group2": 1,
|
||||
},
|
||||
"GroupOrder": []string{
|
||||
"Group1",
|
||||
"Group2",
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"Group1", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"Group2", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Int1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"Group1": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"0": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
"Group2": {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"0": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"Group1": {
|
||||
Type: vo.DataTypeString,
|
||||
},
|
||||
"Group2": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ns,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "va",
|
||||
},
|
||||
{
|
||||
FromNode: "va",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"Str1": "str_v1",
|
||||
"Int1": int64(1),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"Group1": "str_v1",
|
||||
"Group2": int64(1),
|
||||
}, out)
|
||||
|
||||
out, err = wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"Str1": "str_v1",
|
||||
"Int1": nil,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"Group1": "str_v1",
|
||||
"Group2": nil,
|
||||
}, out)
|
||||
}
|
||||
|
||||
func TestTextProcessor(t *testing.T) {
|
||||
t.Run("split", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "tp",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.SplitText,
|
||||
"Separators": []string{"|"},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Str"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"Str": "a|b|c",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": []any{"a", "b", "c"},
|
||||
}, out)
|
||||
})
|
||||
|
||||
t.Run("concat", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"output"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: "tp",
|
||||
FromPath: compose.FieldPath{"output"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.ConcatText,
|
||||
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
|
||||
"ConcatChar": "\t",
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"String2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Str2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: compose.FieldPath{"String3"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromPath: compose.FieldPath{"Str3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ws.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
wf, err := compose2.NewWorkflow(ctx, ws)
|
||||
assert.NoError(t, err)
|
||||
|
||||
out, err := wf.Runner.Invoke(context.Background(), map[string]any{
|
||||
"Str1": true,
|
||||
"Str2": map[string]any{
|
||||
"f1": 1.0,
|
||||
},
|
||||
"Str3": map[string]any{
|
||||
"f2": []any{1, "a"},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"output": "true_1_a",
|
||||
}, out)
|
||||
})
|
||||
}
|
||||
667
backend/domain/workflow/internal/compose/to_node.go
Normal file
667
backend/domain/workflow/internal/compose/to_node.go
Normal file
@@ -0,0 +1,667 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
einomodel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
|
||||
return &entry.Config{
|
||||
DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
|
||||
OutputTypes: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
|
||||
llmConf := &llm.Config{
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs),
|
||||
OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
|
||||
InputFields: s.InputTypes,
|
||||
OutputFields: s.OutputTypes,
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
|
||||
if llmParams == nil {
|
||||
return nil, fmt.Errorf("llm node llmParams is required")
|
||||
}
|
||||
var (
|
||||
err error
|
||||
chatModel, fallbackM einomodel.BaseChatModel
|
||||
info, fallbackI *crossmodelmgr.Model
|
||||
modelWithInfo llm.ModelWithInfo
|
||||
)
|
||||
|
||||
chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaConfigs := s.ExceptionConfigs
|
||||
if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
|
||||
backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackM == nil {
|
||||
modelWithInfo = llm.NewModel(chatModel, info)
|
||||
} else {
|
||||
modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
|
||||
}
|
||||
llmConf.ChatModel = modelWithInfo
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
|
||||
workflowToolConfig := vo.WorkflowToolConfig{}
|
||||
if wf.FCSetting != nil {
|
||||
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
|
||||
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, workflowToolConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.Tools = append(llmConf.Tools, wfTool)
|
||||
if wfTool.TerminatePlan() == vo.UseAnswerContent {
|
||||
if llmConf.ToolsReturnDirectly == nil {
|
||||
llmConf.ToolsReturnDirectly = make(map[string]bool)
|
||||
}
|
||||
toolInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.ToolsReturnDirectly[toolInfo.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.PluginFCParam != nil {
|
||||
pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
|
||||
var (
|
||||
requestParameters []*workflow3.APIParameter
|
||||
responseParameters []*workflow3.APIParameter
|
||||
)
|
||||
if p.FCSetting != nil {
|
||||
requestParameters = p.FCSetting.RequestParameters
|
||||
responseParameters = p.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
if req, ok := pluginToolsInvokableReq[pid]; ok {
|
||||
req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
}
|
||||
} else {
|
||||
pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
|
||||
PluginEntity: crossplugin.Entity{
|
||||
PluginID: pid,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
|
||||
toolID: {
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
},
|
||||
},
|
||||
IsDraft: p.IsDraft,
|
||||
}
|
||||
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
|
||||
}
|
||||
}
|
||||
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
|
||||
for _, req := range pluginToolsInvokableReq {
|
||||
toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range toolMap {
|
||||
inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
|
||||
}
|
||||
}
|
||||
if len(inInvokableTools) > 0 {
|
||||
llmConf.Tools = inInvokableTools
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
|
||||
kwChatModel, err := knowledgeRecallChatModel(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
cfg := &llm.KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := totRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.RetrievalStrategy = &llm.RetrievalStrategy{
|
||||
RetrievalStrategy: &crossknowledge.RetrievalStrategy{
|
||||
TopK: ptr.Of(setting.TopK),
|
||||
MinScore: ptr.Of(setting.MinScore),
|
||||
SearchType: searchType,
|
||||
EnableNL2SQL: setting.UseNL2SQL,
|
||||
EnableQueryRewrite: setting.UseRewrite,
|
||||
EnableRerank: setting.UseRerank,
|
||||
},
|
||||
NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode),
|
||||
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
|
||||
}
|
||||
|
||||
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
|
||||
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
|
||||
kid, err := strconv.ParseInt(kw.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
|
||||
llmConf.KnowledgeRecallConfig = cfg
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return llmConf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
|
||||
return &selector.Config{
|
||||
Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
|
||||
conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey})
|
||||
if ok {
|
||||
out = append(out, selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
out = append(out, selector.Operants{Left: left})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := make([]*selector.Operants, 0)
|
||||
for j := range oneConf.Multi.Clauses {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
|
||||
}
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey})
|
||||
if ok {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left})
|
||||
}
|
||||
}
|
||||
out = append(out, selector.Operants{Multi: multiClause})
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
|
||||
conf := &batch.Config{
|
||||
BatchNodeKey: s.Key,
|
||||
InnerWorkflow: inner,
|
||||
Outputs: s.OutputSources,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
|
||||
return &variableaggregator.Config{
|
||||
MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy),
|
||||
GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
NodeKey: s.Key,
|
||||
InputSources: s.InputSources,
|
||||
GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
|
||||
converted = make(map[string]map[int]any)
|
||||
|
||||
for k, value := range in {
|
||||
m, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
panic(errors.New("value is not a map[string]any"))
|
||||
}
|
||||
converted[k] = make(map[int]any, len(m))
|
||||
for i, sv := range m {
|
||||
index, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err))
|
||||
}
|
||||
converted[k][index] = sv
|
||||
}
|
||||
}
|
||||
|
||||
return converted
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
|
||||
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = safego.NewPanicErr(r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
return s.variableAggregatorInputConverter(input), nil
|
||||
}
|
||||
return schema.StreamReaderWithConvert(in, converter)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
|
||||
return &textprocessor.Config{
|
||||
Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type),
|
||||
Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)),
|
||||
ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)),
|
||||
Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
|
||||
return &json.SerializationConfig{
|
||||
InputTypes: s.InputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
|
||||
return &json.DeserializationConfig{
|
||||
OutputFields: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
|
||||
return &httprequester.Config{
|
||||
URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
|
||||
AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
|
||||
BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
|
||||
Method: mustGetKey[string]("Method", s.Configs),
|
||||
Timeout: mustGetKey[time.Duration]("Timeout", s.Configs),
|
||||
RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs),
|
||||
MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
Handler: handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
|
||||
conf := &loop.Config{
|
||||
LoopNodeKey: s.Key,
|
||||
LoopType: mustGetKey[loop.Type]("LoopType", s.Configs),
|
||||
IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
|
||||
Outputs: s.OutputSources,
|
||||
Inner: inner,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
|
||||
conf := &qa.Config{
|
||||
QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs),
|
||||
AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
|
||||
ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
|
||||
FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs),
|
||||
ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
|
||||
MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs),
|
||||
AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
|
||||
OutputFields: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
if llmParams != nil {
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf.Model = m
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
|
||||
return &receiver.Config{
|
||||
OutputTypes: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
|
||||
conf := &emitter.Config{
|
||||
Template: getKeyOrZero[string]("Template", s.Configs),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
|
||||
return &database.CustomSQLConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
|
||||
return &database.QueryConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs),
|
||||
OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
|
||||
ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Limit: mustGetKey[int64]("Limit", s.Configs),
|
||||
Op: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
|
||||
|
||||
return &database.InsertConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Inserter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
|
||||
return &database.DeleteConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Deleter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
|
||||
|
||||
return &database.UpdateConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Updater: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
|
||||
return &knowledge.IndexerConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
|
||||
ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
|
||||
KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
|
||||
return &knowledge.RetrieveConfig{
|
||||
KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
|
||||
RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
|
||||
Retriever: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
|
||||
return &knowledge.DeleterConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
|
||||
return &plugin.Config{
|
||||
PluginID: mustGetKey[int64]("PluginID", s.Configs),
|
||||
ToolID: mustGetKey[int64]("ToolID", s.Configs),
|
||||
PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
|
||||
PluginService: crossplugin.GetPluginService(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
|
||||
return &code.Config{
|
||||
Code: mustGetKey[string]("Code", s.Configs),
|
||||
Language: mustGetKey[crosscode.Language]("Language", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Runner: crosscode.GetCodeRunner(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
|
||||
return &conversation.CreateConversationConfig{
|
||||
Creator: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
|
||||
return &conversation.ClearMessageConfig{
|
||||
Clearer: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
|
||||
return &conversation.MessageListConfig{
|
||||
Lister: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
|
||||
cfg := &intentdetector.Config{
|
||||
Intents: mustGetKey[[]string]("Intents", s.Configs),
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.ChatModel = m
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
|
||||
var opts []WorkflowOption
|
||||
opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, WithParentRequireCheckpoint())
|
||||
}
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subworkflow.Config{
|
||||
Runner: wf.Runner,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return crossknowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return crossknowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return crossknowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
// knowledgeRecallChatModel the chat model used by the knowledge base recall in the LLM node, not the user-configured model
|
||||
func knowledgeRecallChatModel(ctx context.Context) (einomodel.BaseChatModel, error) {
|
||||
defaultChatModelParma := &model.LLMParams{
|
||||
ModelName: "豆包·1.5·Pro·32k",
|
||||
ModelType: 1,
|
||||
Temperature: ptr.Of(0.5),
|
||||
MaxTokens: 4096,
|
||||
}
|
||||
m, _, err := model.GetManager().GetModel(ctx, defaultChatModelParma)
|
||||
return m, err
|
||||
}
|
||||
107
backend/domain/workflow/internal/compose/utils.go
Normal file
107
backend/domain/workflow/internal/compose/utils.go
Normal file
@@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func getKeyOrZero[T any](key string, cfg any) T {
|
||||
var zero T
|
||||
if cfg == nil {
|
||||
return zero
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if len(m) == 0 {
|
||||
return zero
|
||||
}
|
||||
|
||||
if v, ok := m[key]; ok {
|
||||
return v.(T)
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
func mustGetKey[T any](key string, cfg any) T {
|
||||
if cfg == nil {
|
||||
panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if _, ok := m[key]; !ok {
|
||||
panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
|
||||
}
|
||||
|
||||
v, ok := m[key].(T)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key])))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetConfigKV(key string, value any) {
|
||||
if s.Configs == nil {
|
||||
s.Configs = make(map[string]any)
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)[key] = value
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
|
||||
if s.InputTypes == nil {
|
||||
s.InputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.InputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
|
||||
s.InputSources = append(s.InputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
if s.OutputTypes == nil {
|
||||
s.OutputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.OutputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
|
||||
if s.Type != entity.NodeTypeSubWorkflow {
|
||||
return 0, "", false
|
||||
}
|
||||
|
||||
return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true
|
||||
}
|
||||
933
backend/domain/workflow/internal/compose/workflow.go
Normal file
933
backend/domain/workflow/internal/compose/workflow.go
Normal file
@@ -0,0 +1,933 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type workflow = compose.Workflow[map[string]any, map[string]any]
|
||||
|
||||
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
|
||||
*workflow
|
||||
hierarchy map[vo.NodeKey]vo.NodeKey
|
||||
connections []*Connection
|
||||
requireCheckpoint bool
|
||||
entry *compose.WorkflowNode
|
||||
inner bool
|
||||
fromNode bool // this workflow is constructed from a single node, without Entry or Exit nodes
|
||||
streamRun bool
|
||||
Runner compose.Runnable[map[string]any, map[string]any] // TODO: this will be unexported eventually
|
||||
input map[string]*vo.TypeInfo
|
||||
output map[string]*vo.TypeInfo
|
||||
terminatePlan vo.TerminatePlan
|
||||
schema *WorkflowSchema
|
||||
}
|
||||
|
||||
type workflowOptions struct {
|
||||
wfID int64
|
||||
idAsName bool
|
||||
parentRequireCheckpoint bool
|
||||
maxNodeCount int
|
||||
}
|
||||
|
||||
type WorkflowOption func(*workflowOptions)
|
||||
|
||||
func WithIDAsName(id int64) WorkflowOption {
|
||||
return func(opts *workflowOptions) {
|
||||
opts.wfID = id
|
||||
opts.idAsName = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithParentRequireCheckpoint() WorkflowOption {
|
||||
return func(opts *workflowOptions) {
|
||||
opts.parentRequireCheckpoint = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxNodeCount(c int) WorkflowOption {
|
||||
return func(opts *workflowOptions) {
|
||||
opts.maxNodeCount = c
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
sc.Init()
|
||||
|
||||
wf := &Workflow{
|
||||
workflow: compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
|
||||
hierarchy: sc.Hierarchy,
|
||||
connections: sc.Connections,
|
||||
schema: sc,
|
||||
}
|
||||
|
||||
wf.streamRun = sc.requireStreaming
|
||||
wf.requireCheckpoint = sc.requireCheckPoint
|
||||
|
||||
wfOpts := &workflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(wfOpts)
|
||||
}
|
||||
|
||||
if wfOpts.maxNodeCount > 0 {
|
||||
if sc.NodeCount() > int32(wfOpts.maxNodeCount) {
|
||||
return nil, fmt.Errorf("node count %d exceeds the limit: %d", sc.NodeCount(), wfOpts.maxNodeCount)
|
||||
}
|
||||
}
|
||||
|
||||
if wfOpts.parentRequireCheckpoint {
|
||||
wf.requireCheckpoint = true
|
||||
}
|
||||
|
||||
wf.input = sc.GetNode(entity.EntryNodeKey).OutputTypes
|
||||
|
||||
// even if the terminate plan is use answer content, this still will be 'input types' of exit node
|
||||
wf.output = sc.GetNode(entity.ExitNodeKey).InputTypes
|
||||
|
||||
// add all composite nodes with their inner workflow
|
||||
compositeNodes := sc.GetCompositeNodes()
|
||||
processedNodeKey := make(map[vo.NodeKey]struct{})
|
||||
for i := range compositeNodes {
|
||||
cNode := compositeNodes[i]
|
||||
if err := wf.AddCompositeNode(ctx, cNode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processedNodeKey[cNode.Parent.Key] = struct{}{}
|
||||
for _, child := range cNode.Children {
|
||||
processedNodeKey[child.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// add all nodes other than composite nodes and their children
|
||||
for _, ns := range sc.Nodes {
|
||||
if _, ok := processedNodeKey[ns.Key]; !ok {
|
||||
if err := wf.AddNode(ctx, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
|
||||
}
|
||||
}
|
||||
|
||||
var compileOpts []compose.GraphCompileOption
|
||||
if wf.requireCheckpoint {
|
||||
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow2.GetRepository()))
|
||||
}
|
||||
if wfOpts.idAsName {
|
||||
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
|
||||
}
|
||||
|
||||
fanInConfigs := sc.fanInMergeConfigs()
|
||||
if len(fanInConfigs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
|
||||
}
|
||||
|
||||
r, err := wf.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wf.Runner = r
|
||||
|
||||
return wf, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) AsyncRun(ctx context.Context, in map[string]any, opts ...compose.Option) {
|
||||
if w.streamRun {
|
||||
safego.Go(ctx, func() {
|
||||
_, _ = w.Runner.Stream(ctx, in, opts...)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
_, _ = w.Runner.Invoke(ctx, in, opts...)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Workflow) SyncRun(ctx context.Context, in map[string]any, opts ...compose.Option) (map[string]any, error) {
|
||||
return w.Runner.Invoke(ctx, in, opts...)
|
||||
}
|
||||
|
||||
func (w *Workflow) Inputs() map[string]*vo.TypeInfo {
|
||||
return w.input
|
||||
}
|
||||
|
||||
func (w *Workflow) Outputs() map[string]*vo.TypeInfo {
|
||||
return w.output
|
||||
}
|
||||
|
||||
func (w *Workflow) StreamRun() bool {
|
||||
return w.streamRun
|
||||
}
|
||||
|
||||
func (w *Workflow) TerminatePlan() vo.TerminatePlan {
|
||||
return w.terminatePlan
|
||||
}
|
||||
|
||||
type innerWorkflowInfo struct {
|
||||
inner compose.Runnable[map[string]any, map[string]any]
|
||||
carryOvers map[vo.NodeKey][]*compose.FieldMapping
|
||||
}
|
||||
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
|
||||
_, err := w.addNodeInternal(ctx, ns, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
|
||||
inner, err := w.getInnerWorkflow(ctx, cNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.addNodeInternal(ctx, cNode.Parent, inner)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
return w.addNodeInternal(ctx, cNode, nil)
|
||||
}
|
||||
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
key := ns.Key
|
||||
var deps *dependencyInfo
|
||||
|
||||
deps, err := w.resolveDependencies(key, ns.InputSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if inner != nil {
|
||||
if err = deps.merge(inner.carryOvers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var innerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
if inner != nil {
|
||||
innerWorkflow = inner.inner
|
||||
}
|
||||
|
||||
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []compose.GraphAddNodeOpt
|
||||
opts = append(opts, compose.WithNodeName(string(ns.Key)))
|
||||
|
||||
preHandler := ns.StatePreHandler(w.streamRun)
|
||||
if preHandler != nil {
|
||||
opts = append(opts, preHandler)
|
||||
}
|
||||
|
||||
postHandler := ns.StatePostHandler(w.streamRun)
|
||||
if postHandler != nil {
|
||||
opts = append(opts, postHandler)
|
||||
}
|
||||
|
||||
var wNode *compose.WorkflowNode
|
||||
if ins.Lambda != nil {
|
||||
wNode = w.AddLambdaNode(string(key), ins.Lambda, opts...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("node instance has no Lambda: %s", key)
|
||||
}
|
||||
|
||||
if err = deps.arrayDrillDown(w.schema.GetAllNodes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for fromNodeKey := range deps.inputsFull {
|
||||
wNode.AddInput(string(fromNodeKey))
|
||||
}
|
||||
|
||||
for fromNodeKey, fieldMappings := range deps.inputs {
|
||||
wNode.AddInput(string(fromNodeKey), fieldMappings...)
|
||||
}
|
||||
|
||||
for fromNodeKey := range deps.inputsNoDirectDependencyFull {
|
||||
wNode.AddInputWithOptions(string(fromNodeKey), nil, compose.WithNoDirectDependency())
|
||||
}
|
||||
|
||||
for fromNodeKey, fieldMappings := range deps.inputsNoDirectDependency {
|
||||
wNode.AddInputWithOptions(string(fromNodeKey), fieldMappings, compose.WithNoDirectDependency())
|
||||
}
|
||||
|
||||
for i := range deps.dependencies {
|
||||
wNode.AddDependency(string(deps.dependencies[i]))
|
||||
}
|
||||
|
||||
for i := range deps.staticValues {
|
||||
wNode.SetStaticValue(deps.staticValues[i].path, deps.staticValues[i].val)
|
||||
}
|
||||
|
||||
if ns.Type == entity.NodeTypeEntry {
|
||||
if w.entry != nil {
|
||||
return nil, errors.New("entry node already set")
|
||||
}
|
||||
w.entry = wNode
|
||||
}
|
||||
|
||||
outputPortCount, hasExceptionPort := ns.OutputPortCount()
|
||||
if outputPortCount > 1 || hasExceptionPort {
|
||||
bMapping, err := w.resolveBranch(key, outputPortCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch, err := ns.GetBranch(bMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = w.AddBranch(string(key), branch)
|
||||
}
|
||||
|
||||
return deps.inputsForParent, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOption) (compose.Runnable[map[string]any, map[string]any], error) {
|
||||
if !w.inner && !w.fromNode {
|
||||
if w.entry == nil {
|
||||
return nil, fmt.Errorf("entry node is not set")
|
||||
}
|
||||
|
||||
w.entry.AddInput(compose.START)
|
||||
w.End().AddInput(entity.ExitNodeKey)
|
||||
}
|
||||
|
||||
return w.workflow.Compile(ctx, opts...)
|
||||
}
|
||||
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*NodeSchema)
|
||||
for _, n := range cNode.Children {
|
||||
innerNodes[n.Key] = n
|
||||
}
|
||||
|
||||
// trim the connections, only keep the connections that are related to the inner workflow
|
||||
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
|
||||
innerConnections := make([]*Connection, 0)
|
||||
for i := range w.schema.Connections {
|
||||
conn := w.schema.Connections[i]
|
||||
if _, ok := innerNodes[conn.FromNode]; ok {
|
||||
innerConnections = append(innerConnections, conn)
|
||||
} else if _, ok := innerNodes[conn.ToNode]; ok {
|
||||
innerConnections = append(innerConnections, conn)
|
||||
}
|
||||
}
|
||||
|
||||
inner := &Workflow{
|
||||
workflow: compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
|
||||
hierarchy: w.hierarchy, // we keep the entire hierarchy because inner workflow nodes can refer to parent nodes' outputs
|
||||
connections: innerConnections,
|
||||
inner: true,
|
||||
requireCheckpoint: w.requireCheckpoint,
|
||||
schema: w.schema,
|
||||
}
|
||||
|
||||
carryOvers := make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
|
||||
for key := range innerNodes {
|
||||
inputsForParent, err := inner.addInnerNode(ctx, innerNodes[key])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for fromNodeKey, fieldMappings := range inputsForParent {
|
||||
if fromNodeKey == cNode.Parent.Key { // refer to parent itself, no need to carry over
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := carryOvers[fromNodeKey]; !ok {
|
||||
carryOvers[fromNodeKey] = make([]*compose.FieldMapping, 0)
|
||||
}
|
||||
|
||||
for _, fm := range fieldMappings {
|
||||
duplicate := false
|
||||
for _, existing := range carryOvers[fromNodeKey] {
|
||||
if fm.Equals(existing) {
|
||||
duplicate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !duplicate {
|
||||
carryOvers[fromNodeKey] = append(carryOvers[fromNodeKey], fieldMappings...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
endDeps, err := inner.resolveDependenciesAsParent(cNode.Parent.Key, cNode.Parent.OutputSources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve dependencies of parent node: %s failed: %w", cNode.Parent.Key, err)
|
||||
}
|
||||
|
||||
n := inner.End()
|
||||
|
||||
for fromNodeKey := range endDeps.inputsFull {
|
||||
n.AddInput(string(fromNodeKey))
|
||||
}
|
||||
|
||||
for fromNodeKey, fieldMappings := range endDeps.inputs {
|
||||
n.AddInput(string(fromNodeKey), fieldMappings...)
|
||||
}
|
||||
|
||||
for fromNodeKey := range endDeps.inputsNoDirectDependencyFull {
|
||||
n.AddInputWithOptions(string(fromNodeKey), nil, compose.WithNoDirectDependency())
|
||||
}
|
||||
|
||||
for fromNodeKey, fieldMappings := range endDeps.inputsNoDirectDependency {
|
||||
n.AddInputWithOptions(string(fromNodeKey), fieldMappings, compose.WithNoDirectDependency())
|
||||
}
|
||||
|
||||
for i := range endDeps.dependencies {
|
||||
n.AddDependency(string(endDeps.dependencies[i]))
|
||||
}
|
||||
|
||||
for i := range endDeps.staticValues {
|
||||
n.SetStaticValue(endDeps.staticValues[i].path, endDeps.staticValues[i].val)
|
||||
}
|
||||
|
||||
var opts []compose.GraphCompileOption
|
||||
if inner.requireCheckpoint {
|
||||
opts = append(opts, compose.WithCheckPointStore(workflow2.GetRepository()))
|
||||
}
|
||||
|
||||
r, err := inner.Compile(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &innerWorkflowInfo{
|
||||
inner: r,
|
||||
carryOvers: carryOvers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dependencyInfo struct {
|
||||
inputs map[vo.NodeKey][]*compose.FieldMapping
|
||||
inputsFull map[vo.NodeKey]struct{}
|
||||
dependencies []vo.NodeKey
|
||||
inputsNoDirectDependency map[vo.NodeKey][]*compose.FieldMapping
|
||||
inputsNoDirectDependencyFull map[vo.NodeKey]struct{}
|
||||
staticValues []*staticValue
|
||||
variableInfos []*variableInfo
|
||||
inputsForParent map[vo.NodeKey][]*compose.FieldMapping
|
||||
}
|
||||
|
||||
func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping) error {
|
||||
for nKey, fms := range mappings {
|
||||
if _, ok := d.inputsFull[nKey]; ok {
|
||||
return fmt.Errorf("duplicate input for node: %s", nKey)
|
||||
}
|
||||
|
||||
if _, ok := d.inputsNoDirectDependencyFull[nKey]; ok {
|
||||
return fmt.Errorf("duplicate input for node: %s", nKey)
|
||||
}
|
||||
|
||||
if currentFMS, ok := d.inputs[nKey]; ok {
|
||||
for i := range fms {
|
||||
fm := fms[i]
|
||||
duplicate := false
|
||||
for _, currentFM := range currentFMS {
|
||||
if fm.Equals(currentFM) {
|
||||
duplicate = true
|
||||
}
|
||||
}
|
||||
|
||||
if !duplicate {
|
||||
d.inputs[nKey] = append(d.inputs[nKey], fm)
|
||||
}
|
||||
}
|
||||
} else if currentFMS, ok = d.inputsNoDirectDependency[nKey]; ok {
|
||||
for i := range fms {
|
||||
fm := fms[i]
|
||||
duplicate := false
|
||||
for _, currentFM := range currentFMS {
|
||||
if fm.Equals(currentFM) {
|
||||
duplicate = true
|
||||
}
|
||||
}
|
||||
|
||||
if !duplicate {
|
||||
d.inputsNoDirectDependency[nKey] = append(d.inputsNoDirectDependency[nKey], fm)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
currentDependency := -1
|
||||
for i, depKey := range d.dependencies {
|
||||
if depKey == nKey {
|
||||
currentDependency = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if currentDependency >= 0 {
|
||||
d.dependencies = append(d.dependencies[:currentDependency], d.dependencies[currentDependency+1:]...)
|
||||
d.inputs[nKey] = append(d.inputs[nKey], fms...)
|
||||
} else {
|
||||
d.inputsNoDirectDependency[nKey] = append(d.inputsNoDirectDependency[nKey], fms...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// arrayDrillDown happens when the 'mapping from path' is taking fields from elements within arrays.
|
||||
// when this happens, we automatically takes the first element from any arrays along the 'from path'.
|
||||
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
|
||||
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
|
||||
// we will not try to drill, instead, just take value using a.b.c.
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
|
||||
for nKey, fms := range d.inputs {
|
||||
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
|
||||
continue
|
||||
}
|
||||
|
||||
var ot map[string]*vo.TypeInfo
|
||||
ots, ok := allNS[nKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("node not found: %s", nKey)
|
||||
}
|
||||
ot = ots.OutputTypes
|
||||
for i := range fms {
|
||||
fm := fms[i]
|
||||
newFM, err := arrayDrillDown(nKey, fm, ot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fms[i] = newFM
|
||||
}
|
||||
}
|
||||
|
||||
for nKey, fms := range d.inputsNoDirectDependency {
|
||||
if nKey == compose.START {
|
||||
continue
|
||||
}
|
||||
|
||||
var ot map[string]*vo.TypeInfo
|
||||
ots, ok := allNS[nKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("node not found: %s", nKey)
|
||||
}
|
||||
ot = ots.OutputTypes
|
||||
for i := range fms {
|
||||
fm := fms[i]
|
||||
newFM, err := arrayDrillDown(nKey, fm, ot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fms[i] = newFM
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func arrayDrillDown(nKey vo.NodeKey, fm *compose.FieldMapping, types map[string]*vo.TypeInfo) (*compose.FieldMapping, error) {
|
||||
fromPath := fm.FromPath()
|
||||
if len(fromPath) <= 1 { // no need to drill down
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
ct := types
|
||||
var arraySegIndexes []int
|
||||
for j := 0; j < len(fromPath)-1; j++ {
|
||||
p := fromPath[j]
|
||||
t, ok := ct[p]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("type info not found for path: %s", fm.FromPath()[:j+1])
|
||||
}
|
||||
|
||||
if t.Type == vo.DataTypeArray {
|
||||
arraySegIndexes = append(arraySegIndexes, j)
|
||||
if t.ElemTypeInfo.Type == vo.DataTypeObject {
|
||||
ct = t.ElemTypeInfo.Properties
|
||||
} else if j != len(fromPath)-1 {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] already found array of none obj, but still not last segment of path: %v",
|
||||
fromPath[:j+1])
|
||||
}
|
||||
} else if t.Type == vo.DataTypeObject {
|
||||
ct = t.Properties
|
||||
} else if j != len(fromPath)-1 {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] found non-array, non-obj type: %v, but still not last segment of path: %v",
|
||||
t.Type, fromPath[:j+1])
|
||||
}
|
||||
}
|
||||
|
||||
if len(arraySegIndexes) == 0 { // no arrays along from path
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
extractor := func(a any) (any, error) {
|
||||
for j := range fromPath {
|
||||
p := fromPath[j]
|
||||
m, ok := a.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from a non-map type:%T of path %s, "+
|
||||
"from node key: %v", a, fromPath[:j+1], nKey)
|
||||
}
|
||||
a, ok = m[p]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] field %s not found along from path: %s, "+
|
||||
"from node key: %v", p, fromPath[:j+1], nKey)
|
||||
}
|
||||
if slices.Contains(arraySegIndexes, j) { // this is an array needs drilling down
|
||||
arr, ok := a.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from a non-array type:%T of path %s, "+
|
||||
"from node key: %v", a, fromPath[:j+1], nKey)
|
||||
}
|
||||
|
||||
if len(arr) == 0 {
|
||||
return nil, fmt.Errorf("[arrayDrillDown] trying to drill down from an array of length 0: %s, "+
|
||||
"from node key: %v", fromPath[:j+1], nKey)
|
||||
}
|
||||
|
||||
a = arr[0]
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
newFM := compose.ToFieldPath(fm.ToPath(), compose.WithCustomExtractor(extractor))
|
||||
return newFM, nil
|
||||
}
|
||||
|
||||
type staticValue struct {
|
||||
val any
|
||||
path compose.FieldPath
|
||||
}
|
||||
|
||||
type variableInfo struct {
|
||||
varType vo.GlobalVarType
|
||||
fromPath compose.FieldPath
|
||||
toPath compose.FieldPath
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
|
||||
m := make([]map[string]bool, portCount)
|
||||
var exception map[string]bool
|
||||
|
||||
for _, conn := range w.connections {
|
||||
if conn.FromNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if conn.FromPort == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if *conn.FromPort == "default" { // default condition
|
||||
if m[portCount-1] == nil {
|
||||
m[portCount-1] = make(map[string]bool)
|
||||
}
|
||||
m[portCount-1][string(conn.ToNode)] = true
|
||||
} else if *conn.FromPort == "branch_error" {
|
||||
if exception == nil {
|
||||
exception = make(map[string]bool)
|
||||
}
|
||||
exception[string(conn.ToNode)] = true
|
||||
} else {
|
||||
if !strings.HasPrefix(*conn.FromPort, "branch_") {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
|
||||
}
|
||||
|
||||
index := (*conn.FromPort)[7:]
|
||||
i, err := strconv.Atoi(index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
|
||||
}
|
||||
if i < 0 || i >= portCount {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
|
||||
}
|
||||
if m[i] == nil {
|
||||
m[i] = make(map[string]bool)
|
||||
}
|
||||
m[i][string(conn.ToNode)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return &BranchMapping{
|
||||
Normal: m,
|
||||
Exception: exception,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
|
||||
var (
|
||||
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
inputFull map[vo.NodeKey]struct{}
|
||||
dependencies []vo.NodeKey
|
||||
inputsNoDirectDependency = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
inputsNoDirectDependencyFull map[vo.NodeKey]struct{}
|
||||
staticValues []*staticValue
|
||||
variableInfos []*variableInfo
|
||||
|
||||
// inputsForParent contains all the field mappings from any nodes of the parent workflow
|
||||
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
connMap[conn.FromNode] = *conn
|
||||
}
|
||||
|
||||
for _, swp := range sourceWithPaths {
|
||||
if swp.Source.Val != nil {
|
||||
staticValues = append(staticValues, &staticValue{
|
||||
val: swp.Source.Val,
|
||||
path: swp.Path,
|
||||
})
|
||||
} else if swp.Source.Ref != nil {
|
||||
fromNode := swp.Source.Ref.FromNodeKey
|
||||
|
||||
if fromNode == n {
|
||||
return nil, fmt.Errorf("node %s cannot refer to itself, fromPath: %v, toPath: %v", n,
|
||||
swp.Source.Ref.FromPath, swp.Path)
|
||||
}
|
||||
|
||||
if swp.Source.Ref.VariableType != nil {
|
||||
// skip all variables, they are handled in state pre handler
|
||||
variableInfos = append(variableInfos, &variableInfo{
|
||||
varType: *swp.Source.Ref.VariableType,
|
||||
fromPath: swp.Source.Ref.FromPath,
|
||||
toPath: swp.Path,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
|
||||
if inputFull == nil {
|
||||
inputFull = make(map[vo.NodeKey]struct{})
|
||||
}
|
||||
inputFull[fromNode] = struct{}{}
|
||||
} else {
|
||||
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
|
||||
}
|
||||
} else { // indirect dependency
|
||||
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
|
||||
if inputsNoDirectDependencyFull == nil {
|
||||
inputsNoDirectDependencyFull = make(map[vo.NodeKey]struct{})
|
||||
}
|
||||
inputsNoDirectDependencyFull[fromNode] = struct{}{}
|
||||
} else {
|
||||
inputsNoDirectDependency[fromNode] = append(inputsNoDirectDependency[fromNode],
|
||||
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
|
||||
}
|
||||
}
|
||||
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
firstNodesInInnerWorkflow := true
|
||||
for _, conn := range connMap {
|
||||
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
|
||||
firstNodesInInnerWorkflow = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if firstNodesInInnerWorkflow { // one of the first nodes in sub workflow
|
||||
inputs[compose.START] = append(inputs[compose.START],
|
||||
compose.MapFieldPaths(
|
||||
// the START node of inner workflow will proxy for the fields required from parent workflow
|
||||
// the field path within START node is prepended by the parent node key
|
||||
joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)),
|
||||
swp.Path))
|
||||
} else { // not one of the first nodes in sub workflow, either succeeds other nodes or succeeds branches
|
||||
inputsNoDirectDependency[compose.START] = append(inputsNoDirectDependency[compose.START],
|
||||
compose.MapFieldPaths(
|
||||
// same as above, the START node of inner workflow proxies for the fields from parent workflow
|
||||
joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)),
|
||||
swp.Path))
|
||||
}
|
||||
|
||||
fieldMapping := compose.MapFieldPaths(swp.Source.Ref.FromPath,
|
||||
// our parent node will proxy for these field mappings, prepending the 'fromNode' to paths
|
||||
joinFieldPath(append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
added := false
|
||||
for _, existedFieldMapping := range inputsForParent[fromNode] {
|
||||
if existedFieldMapping.Equals(fieldMapping) {
|
||||
added = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
inputsForParent[fromNode] = append(inputsForParent[fromNode], fieldMapping)
|
||||
}
|
||||
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("inputField's Val and Ref are both nil. path= %v", swp.Path)
|
||||
}
|
||||
}
|
||||
|
||||
for fromNodeKey, conn := range connMap {
|
||||
if conn.FromPort != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
fromNodeKey = compose.START
|
||||
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := inputs[fromNodeKey]; !ok {
|
||||
if _, ok := inputsNoDirectDependency[fromNodeKey]; !ok {
|
||||
var hasFullInput, hasFullDataInput bool
|
||||
if inputFull != nil {
|
||||
if _, ok = inputFull[fromNodeKey]; ok {
|
||||
hasFullInput = true
|
||||
}
|
||||
}
|
||||
if inputsNoDirectDependencyFull != nil {
|
||||
if _, ok = inputsNoDirectDependencyFull[fromNodeKey]; ok {
|
||||
hasFullDataInput = true
|
||||
}
|
||||
}
|
||||
if !hasFullInput && !hasFullDataInput {
|
||||
dependencies = append(dependencies, fromNodeKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dependencyInfo{
|
||||
inputs: inputs,
|
||||
inputsFull: inputFull,
|
||||
dependencies: dependencies,
|
||||
inputsNoDirectDependency: inputsNoDirectDependency,
|
||||
inputsNoDirectDependencyFull: inputsNoDirectDependencyFull,
|
||||
staticValues: staticValues,
|
||||
variableInfos: variableInfos,
|
||||
inputsForParent: inputsForParent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const fieldPathSplitter = "#"
|
||||
|
||||
func joinFieldPath(f compose.FieldPath) compose.FieldPath {
|
||||
return []string{strings.Join(f, fieldPathSplitter)}
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
|
||||
var (
|
||||
// inputsFull and inputsNoDirectDependencyFull are NEVER used in this case,
|
||||
// because a composite node MUST use explicit field mappings from inner nodes as its output.
|
||||
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
dependencies []vo.NodeKey
|
||||
inputsNoDirectDependency = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
// although staticValues are not used for current composite nodes,
|
||||
// they may be used in the future, so we calculate them none the less.
|
||||
staticValues []*staticValue
|
||||
// variableInfos are normally handled in state pre handler, but in the case of composite node's output,
|
||||
// we need to handle them within composite node's state post handler,
|
||||
variableInfos []*variableInfo
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
continue
|
||||
}
|
||||
|
||||
connMap[conn.FromNode] = *conn
|
||||
}
|
||||
|
||||
for _, swp := range sourceWithPaths {
|
||||
if swp.Source.Ref == nil {
|
||||
staticValues = append(staticValues, &staticValue{
|
||||
val: swp.Source.Val,
|
||||
path: swp.Path,
|
||||
})
|
||||
} else if swp.Source.Ref != nil {
|
||||
if swp.Source.Ref.VariableType != nil {
|
||||
variableInfos = append(variableInfos, &variableInfo{
|
||||
varType: *swp.Source.Ref.VariableType,
|
||||
fromPath: swp.Source.Ref.FromPath,
|
||||
toPath: swp.Path,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
fromNode := swp.Source.Ref.FromNodeKey
|
||||
if fromNode == n {
|
||||
return nil, fmt.Errorf("node %s cannot refer to itself, fromPath= %v, toPath= %v", n,
|
||||
swp.Source.Ref.FromPath, swp.Path)
|
||||
}
|
||||
|
||||
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
} else { // indirect dependency
|
||||
inputsNoDirectDependency[fromNode] = append(inputsNoDirectDependency[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("composite node's output field's Val and Ref are both nil. path= %v", swp.Path)
|
||||
}
|
||||
}
|
||||
|
||||
for fromNodeKey, conn := range connMap {
|
||||
if conn.FromPort != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := inputs[fromNodeKey]; !ok {
|
||||
if _, ok := inputsNoDirectDependency[fromNodeKey]; !ok {
|
||||
dependencies = append(dependencies, fromNodeKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dependencyInfo{
|
||||
inputs: inputs,
|
||||
dependencies: dependencies,
|
||||
inputsNoDirectDependency: inputsNoDirectDependency,
|
||||
staticValues: staticValues,
|
||||
variableInfos: variableInfos,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
|
||||
*Workflow, error) {
|
||||
sc.Init()
|
||||
ns := sc.GetNode(nodeKey)
|
||||
|
||||
wf := &Workflow{
|
||||
workflow: compose.NewWorkflow[map[string]any, map[string]any](compose.WithGenLocalState(GenState())),
|
||||
hierarchy: sc.Hierarchy,
|
||||
connections: sc.Connections,
|
||||
schema: sc,
|
||||
fromNode: true,
|
||||
streamRun: false, // single node run can only invoke
|
||||
requireCheckpoint: sc.requireCheckPoint,
|
||||
input: ns.InputTypes,
|
||||
output: ns.OutputTypes,
|
||||
terminatePlan: vo.ReturnVariables,
|
||||
}
|
||||
|
||||
compositeNodes := sc.GetCompositeNodes()
|
||||
processedNodeKey := make(map[vo.NodeKey]struct{})
|
||||
for i := range compositeNodes {
|
||||
cNode := compositeNodes[i]
|
||||
if err := wf.AddCompositeNode(ctx, cNode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processedNodeKey[cNode.Parent.Key] = struct{}{}
|
||||
for _, child := range cNode.Children {
|
||||
processedNodeKey[child.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// add all nodes other than composite nodes and their children
|
||||
for _, ns := range sc.Nodes {
|
||||
if _, ok := processedNodeKey[ns.Key]; !ok {
|
||||
if err := wf.AddNode(ctx, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wf.End().AddInput(string(nodeKey))
|
||||
|
||||
var compileOpts []compose.GraphCompileOption
|
||||
compileOpts = append(compileOpts, opts...)
|
||||
if wf.requireCheckpoint {
|
||||
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow2.GetRepository()))
|
||||
}
|
||||
|
||||
r, err := wf.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wf.Runner = r
|
||||
|
||||
return wf, nil
|
||||
}
|
||||
292
backend/domain/workflow/internal/compose/workflow_run.go
Normal file
292
backend/domain/workflow/internal/compose/workflow_run.go
Normal file
@@ -0,0 +1,292 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type WorkflowRunner struct {
|
||||
basic *entity.WorkflowBasic
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
schema *WorkflowSchema
|
||||
streamWriter *schema.StreamWriter[*entity.Message]
|
||||
config vo.ExecuteConfig
|
||||
|
||||
executeID int64
|
||||
eventChan chan *execute.Event
|
||||
interruptEvent *entity.InterruptEvent
|
||||
}
|
||||
|
||||
type workflowRunOptions struct {
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
streamWriter *schema.StreamWriter[*entity.Message]
|
||||
rootTokenCollector *execute.TokenCollector
|
||||
}
|
||||
|
||||
type WorkflowRunnerOption func(*workflowRunOptions)
|
||||
|
||||
func WithInput(input string) WorkflowRunnerOption {
|
||||
return func(opts *workflowRunOptions) {
|
||||
opts.input = input
|
||||
}
|
||||
}
|
||||
func WithResumeReq(resumeReq *entity.ResumeRequest) WorkflowRunnerOption {
|
||||
return func(opts *workflowRunOptions) {
|
||||
opts.resumeReq = resumeReq
|
||||
}
|
||||
}
|
||||
func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOption {
|
||||
return func(opts *workflowRunOptions) {
|
||||
opts.streamWriter = sw
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
|
||||
options := &workflowRunOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &WorkflowRunner{
|
||||
basic: b,
|
||||
input: options.input,
|
||||
resumeReq: options.resumeReq,
|
||||
schema: sc,
|
||||
streamWriter: options.streamWriter,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *WorkflowRunner) Prepare(ctx context.Context) (
|
||||
context.Context,
|
||||
int64,
|
||||
[]einoCompose.Option,
|
||||
<-chan *execute.Event,
|
||||
error,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
executeID int64
|
||||
repo = wf.GetRepository()
|
||||
resumeReq = r.resumeReq
|
||||
wb = r.basic
|
||||
sc = r.schema
|
||||
sw = r.streamWriter
|
||||
config = r.config
|
||||
)
|
||||
|
||||
if r.resumeReq == nil {
|
||||
executeID, err = repo.GenID(ctx)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to generate workflow execute ID: %w", err)
|
||||
}
|
||||
} else {
|
||||
executeID = resumeReq.ExecuteID
|
||||
}
|
||||
|
||||
eventChan := make(chan *execute.Event)
|
||||
|
||||
var (
|
||||
interruptEvent *entity.InterruptEvent
|
||||
found bool
|
||||
)
|
||||
|
||||
if resumeReq != nil {
|
||||
interruptEvent, found, err = repo.GetFirstInterruptEvent(ctx, executeID)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, err
|
||||
}
|
||||
|
||||
if !found {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("interrupt event does not exist, id: %d", resumeReq.EventID)
|
||||
}
|
||||
|
||||
if interruptEvent.ID != resumeReq.EventID {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("interrupt event id mismatch, expect: %d, actual: %d", resumeReq.EventID, interruptEvent.ID)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
r.executeID = executeID
|
||||
r.eventChan = eventChan
|
||||
r.interruptEvent = interruptEvent
|
||||
|
||||
ctx, composeOpts, err := r.designateOptions(ctx)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, err
|
||||
}
|
||||
|
||||
if interruptEvent != nil {
|
||||
var stateOpt einoCompose.Option
|
||||
stateModifier := GenStateModifierByEventType(interruptEvent.EventType,
|
||||
interruptEvent.NodeKey, resumeReq.ResumeData, r.config)
|
||||
|
||||
if len(interruptEvent.NodePath) == 1 {
|
||||
// this interrupt event is within the top level workflow
|
||||
stateOpt = einoCompose.WithStateModifier(stateModifier)
|
||||
} else {
|
||||
currentI := len(interruptEvent.NodePath) - 2
|
||||
path := interruptEvent.NodePath[currentI]
|
||||
if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
|
||||
// this interrupt event is within a composite node
|
||||
indexStr := path[len(execute.InterruptEventIndexPrefix):]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
|
||||
}
|
||||
|
||||
currentI--
|
||||
parentNodeKey := interruptEvent.NodePath[currentI]
|
||||
stateOpt = einoCompose.WithLambdaOption(
|
||||
nodes.WithResumeIndex(index, stateModifier)).DesignateNode(parentNodeKey)
|
||||
} else { // this interrupt event is within a sub workflow
|
||||
subWorkflowNodeKey := interruptEvent.NodePath[currentI]
|
||||
stateOpt = einoCompose.WithLambdaOption(
|
||||
nodes.WithResumeIndex(0, stateModifier)).DesignateNode(subWorkflowNodeKey)
|
||||
}
|
||||
|
||||
for i := currentI - 1; i >= 0; i-- {
|
||||
path := interruptEvent.NodePath[i]
|
||||
if strings.HasPrefix(path, execute.InterruptEventIndexPrefix) {
|
||||
indexStr := path[len(execute.InterruptEventIndexPrefix):]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to parse index: %w", err)
|
||||
}
|
||||
|
||||
i--
|
||||
parentNodeKey := interruptEvent.NodePath[i]
|
||||
stateOpt = WrapOptWithIndex(stateOpt, vo.NodeKey(parentNodeKey), index)
|
||||
} else {
|
||||
stateOpt = WrapOpt(stateOpt, vo.NodeKey(path))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, stateOpt)
|
||||
|
||||
if interruptEvent.EventType == entity.InterruptEventQuestion {
|
||||
modifiedData, err := qa.AppendInterruptData(interruptEvent.InterruptData, resumeReq.ResumeData)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data: %w", err)
|
||||
}
|
||||
interruptEvent.InterruptData = modifiedData
|
||||
if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
|
||||
}
|
||||
} else if interruptEvent.EventType == entity.InterruptEventLLM &&
|
||||
interruptEvent.ToolInterruptEvent.EventType == entity.InterruptEventQuestion {
|
||||
modifiedData, err := qa.AppendInterruptData(interruptEvent.ToolInterruptEvent.InterruptData, resumeReq.ResumeData)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to append interrupt data for LLM node: %w", err)
|
||||
}
|
||||
interruptEvent.ToolInterruptEvent.InterruptData = modifiedData
|
||||
if err = repo.UpdateFirstInterruptEvent(ctx, executeID, interruptEvent); err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("failed to update interrupt event: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
success, currentStatus, err := repo.TryLockWorkflowExecution(ctx, executeID, resumeReq.EventID)
|
||||
if err != nil {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("try lock workflow execution unexpected err: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
return ctx, 0, nil, nil, fmt.Errorf("workflow execution lock failed, current status is %v, executeID: %d", currentStatus, executeID)
|
||||
}
|
||||
|
||||
logs.CtxInfof(ctx, "resuming with eventID: %d, executeID: %d, nodeKey: %s", interruptEvent.ID,
|
||||
executeID, interruptEvent.NodeKey)
|
||||
}
|
||||
|
||||
if interruptEvent == nil {
|
||||
var logID string
|
||||
logID, _ = ctx.Value("log-id").(string)
|
||||
|
||||
wfExec := &entity.WorkflowExecution{
|
||||
ID: executeID,
|
||||
WorkflowID: wb.ID,
|
||||
Version: wb.Version,
|
||||
SpaceID: wb.SpaceID,
|
||||
ExecuteConfig: config,
|
||||
Status: entity.WorkflowRunning,
|
||||
Input: ptr.Of(r.input),
|
||||
RootExecutionID: executeID,
|
||||
NodeCount: sc.NodeCount(),
|
||||
CurrentResumingEventID: ptr.Of(int64(0)),
|
||||
CommitID: wb.CommitID,
|
||||
LogID: logID,
|
||||
}
|
||||
|
||||
if err = repo.CreateWorkflowExecution(ctx, wfExec); err != nil {
|
||||
return ctx, 0, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, cancelFn := context.WithCancel(ctx)
|
||||
var timeoutFn context.CancelFunc
|
||||
if s := execute.GetStaticConfig(); s != nil {
|
||||
timeout := ternary.IFElse(config.TaskType == vo.TaskTypeBackground, s.BackgroundRunTimeout, s.ForegroundRunTimeout)
|
||||
if timeout > 0 {
|
||||
cancelCtx, timeoutFn = context.WithTimeout(cancelCtx, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx = execute.InitExecutedNodesCounter(cancelCtx)
|
||||
|
||||
lastEventChan := make(chan *execute.Event, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
logs.CtxErrorf(ctx, "panic when handling execute event: %v", safego.NewPanicErr(panicErr, debug.Stack()))
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// this goroutine should not use the cancelCtx because it needs to be alive to receive workflow cancel events
|
||||
lastEventChan <- execute.HandleExecuteEvent(ctx, executeID, eventChan, cancelFn, timeoutFn,
|
||||
repo, sw, config)
|
||||
close(lastEventChan)
|
||||
}()
|
||||
|
||||
return cancelCtx, executeID, composeOpts, lastEventChan, nil
|
||||
}
|
||||
336
backend/domain/workflow/internal/compose/workflow_schema.go
Normal file
336
backend/domain/workflow/internal/compose/workflow_schema.go
Normal file
@@ -0,0 +1,336 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type WorkflowSchema struct {
|
||||
Nodes []*NodeSchema `json:"nodes"`
|
||||
Connections []*Connection `json:"connections"`
|
||||
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
|
||||
|
||||
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
|
||||
|
||||
nodeMap map[vo.NodeKey]*NodeSchema // won't serialize this
|
||||
compositeNodes []*CompositeNode // won't serialize this
|
||||
requireCheckPoint bool // won't serialize this
|
||||
requireStreaming bool
|
||||
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
FromNode vo.NodeKey `json:"from_node"`
|
||||
ToNode vo.NodeKey `json:"to_node"`
|
||||
FromPort *string `json:"from_port,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Connection) ID() string {
|
||||
if c.FromPort != nil {
|
||||
return fmt.Sprintf("%s:%s:%v", c.FromNode, c.ToNode, *c.FromPort)
|
||||
}
|
||||
return fmt.Sprintf("%v:%v", c.FromNode, c.ToNode)
|
||||
}
|
||||
|
||||
type CompositeNode struct {
|
||||
Parent *NodeSchema
|
||||
Children []*NodeSchema
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) Init() {
|
||||
w.once.Do(func() {
|
||||
w.nodeMap = make(map[vo.NodeKey]*NodeSchema)
|
||||
for _, node := range w.Nodes {
|
||||
w.nodeMap[node.Key] = node
|
||||
}
|
||||
|
||||
w.doGetCompositeNodes()
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
if node.requireCheckpoint() {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
w.requireStreaming = w.doRequireStreaming()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetNode(key vo.NodeKey) *NodeSchema {
|
||||
return w.nodeMap[key]
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetAllNodes() map[vo.NodeKey]*NodeSchema {
|
||||
return w.nodeMap // TODO: needs to calculate node count separately, considering batch mode nodes
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
|
||||
if w.compositeNodes == nil {
|
||||
w.compositeNodes = w.doGetCompositeNodes()
|
||||
}
|
||||
|
||||
return w.compositeNodes
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
||||
if w.Hierarchy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build parent to children mapping
|
||||
parentToChildren := make(map[vo.NodeKey][]*NodeSchema)
|
||||
for childKey, parentKey := range w.Hierarchy {
|
||||
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
|
||||
if childSchema := w.nodeMap[childKey]; childSchema != nil {
|
||||
parentToChildren[parentKey] = append(parentToChildren[parentKey], childSchema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create composite nodes
|
||||
for parentKey, children := range parentToChildren {
|
||||
if parentSchema := w.nodeMap[parentKey]; parentSchema != nil {
|
||||
cNodes = append(cNodes, &CompositeNode{
|
||||
Parent: parentSchema,
|
||||
Children: children,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return cNodes
|
||||
}
|
||||
|
||||
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
myParents, myParentExists := n[nodeKey]
|
||||
theirParents, theirParentExists := n[otherNodeKey]
|
||||
|
||||
if !myParentExists && !theirParentExists {
|
||||
return true
|
||||
}
|
||||
|
||||
if !myParentExists || !theirParentExists {
|
||||
return false
|
||||
}
|
||||
|
||||
return myParents == theirParents
|
||||
}
|
||||
|
||||
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
_, myParentExists := n[nodeKey]
|
||||
_, theirParentExists := n[otherNodeKey]
|
||||
|
||||
return myParentExists && !theirParentExists
|
||||
}
|
||||
|
||||
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
theirParent, theirParentExists := n[otherNodeKey]
|
||||
|
||||
return theirParentExists && theirParent == nodeKey
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) IsEqual(other *WorkflowSchema) bool {
|
||||
otherConnectionsMap := make(map[string]bool, len(other.Connections))
|
||||
for _, connection := range other.Connections {
|
||||
otherConnectionsMap[connection.ID()] = true
|
||||
}
|
||||
connectionsMap := make(map[string]bool, len(other.Connections))
|
||||
for _, connection := range w.Connections {
|
||||
connectionsMap[connection.ID()] = true
|
||||
}
|
||||
if !maps.Equal(otherConnectionsMap, connectionsMap) {
|
||||
return false
|
||||
}
|
||||
otherNodeMap := make(map[vo.NodeKey]*NodeSchema, len(other.Nodes))
|
||||
for _, node := range other.Nodes {
|
||||
otherNodeMap[node.Key] = node
|
||||
}
|
||||
nodeMap := make(map[vo.NodeKey]*NodeSchema, len(w.Nodes))
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
nodeMap[node.Key] = node
|
||||
}
|
||||
|
||||
if !maps.EqualFunc(otherNodeMap, nodeMap, func(node *NodeSchema, other *NodeSchema) bool {
|
||||
if node.Name != other.Name {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.Configs, other.Configs) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.InputTypes, other.InputTypes) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.InputSources, other.InputSources) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(node.OutputTypes, other.OutputTypes) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.OutputSources, other.OutputSources) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.ExceptionConfigs, other.ExceptionConfigs) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(node.SubWorkflowBasic, other.SubWorkflowBasic) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
}) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) NodeCount() int32 {
|
||||
return int32(len(w.Nodes) - len(w.GeneratedNodes))
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doRequireStreaming() bool {
|
||||
producers := make(map[vo.NodeKey]bool)
|
||||
consumers := make(map[vo.NodeKey]bool)
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if _, ok := sps[entity.Stream]; ok {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
|
||||
producers[node.Key] = true
|
||||
}
|
||||
}
|
||||
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
consumers[node.Key] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(producers) == 0 || len(consumers) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Build data-flow graph from InputSources
|
||||
adj := make(map[vo.NodeKey]map[vo.NodeKey]struct{})
|
||||
for _, node := range w.Nodes {
|
||||
for _, source := range node.InputSources {
|
||||
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
|
||||
if _, ok := adj[source.Source.Ref.FromNodeKey]; !ok {
|
||||
adj[source.Source.Ref.FromNodeKey] = make(map[vo.NodeKey]struct{})
|
||||
}
|
||||
adj[source.Source.Ref.FromNodeKey][node.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For each producer, traverse the graph to see if it can reach a consumer
|
||||
for p := range producers {
|
||||
q := []vo.NodeKey{p}
|
||||
visited := make(map[vo.NodeKey]bool)
|
||||
visited[p] = true
|
||||
|
||||
for len(q) > 0 {
|
||||
curr := q[0]
|
||||
q = q[1:]
|
||||
|
||||
if consumers[curr] {
|
||||
return true
|
||||
}
|
||||
|
||||
for neighbor := range adj[curr] {
|
||||
if !visited[neighbor] {
|
||||
visited[neighbor] = true
|
||||
q = append(q, neighbor)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
|
||||
// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
|
||||
// then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
|
||||
// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
|
||||
// then, look up the NodeTypeMeta's ExecutableMeta info and see if it requires fan-in stream merge.
|
||||
if !w.requireStreaming {
|
||||
return nil
|
||||
}
|
||||
|
||||
fanInNodes := make(map[vo.NodeKey]bool)
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
var predecessor *vo.NodeKey
|
||||
for _, source := range node.InputSources {
|
||||
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
|
||||
if predecessor != nil {
|
||||
fanInNodes[node.Key] = true
|
||||
break
|
||||
}
|
||||
predecessor = &source.Source.Ref.FromNodeKey
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fanInConfigs := make(map[string]compose.FanInMergeConfig)
|
||||
for nodeKey := range fanInNodes {
|
||||
if m := entity.NodeMetaByNodeType(w.GetNode(nodeKey).Type); m != nil {
|
||||
if m.StreamSourceEOFAware {
|
||||
fanInConfigs[string(nodeKey)] = compose.FanInMergeConfig{
|
||||
StreamMergeWithSourceEOF: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fanInConfigs
|
||||
}
|
||||
333
backend/domain/workflow/internal/compose/workflow_tool.go
Normal file
333
backend/domain/workflow/internal/compose/workflow_tool.go
Normal file
@@ -0,0 +1,333 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
const answerKey = "output"
|
||||
|
||||
type invokableWorkflow struct {
|
||||
info *schema.ToolInfo
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
func NewInvokableWorkflow(info *schema.ToolInfo,
|
||||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &invokableWorkflow{
|
||||
info: info,
|
||||
invoke: invoke,
|
||||
terminatePlan: terminatePlan,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||
return i.info, nil
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
rInfo, allIEs := execute.GetResumeRequest(opts...)
|
||||
var (
|
||||
previouslyInterrupted bool
|
||||
callID = einoCompose.GetToolCallID(ctx)
|
||||
previousExecuteID int64
|
||||
)
|
||||
for interruptedCallID := range allIEs {
|
||||
if callID == interruptedCallID {
|
||||
previouslyInterrupted = true
|
||||
previousExecuteID = allIEs[interruptedCallID].ExecuteID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
|
||||
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
|
||||
return "", einoCompose.InterruptAndRerun
|
||||
}
|
||||
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
if rInfo != nil {
|
||||
runOpts = append(runOpts, WithResumeReq(rInfo))
|
||||
} else {
|
||||
runOpts = append(runOpts, WithInput(argumentsInJSON))
|
||||
}
|
||||
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
|
||||
runOpts = append(runOpts, WithStreamWriter(sw))
|
||||
}
|
||||
|
||||
var (
|
||||
cancelCtx context.Context
|
||||
executeID int64
|
||||
callOpts []einoCompose.Option
|
||||
in map[string]any
|
||||
err error
|
||||
ws *nodes.ConversionWarnings
|
||||
)
|
||||
|
||||
if rInfo == nil {
|
||||
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
for _, node := range i.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
if entryNode == nil {
|
||||
panic("entry node not found in tool workflow")
|
||||
}
|
||||
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(i.wfEntity.GetBasic(), i.sc, cfg, runOpts...).Prepare(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
out, err := i.invoke(cancelCtx, in, callOpts...)
|
||||
if err != nil {
|
||||
if _, ok := einoCompose.ExtractInterruptInfo(err); ok {
|
||||
firstIE, found, err := i.repo.GetFirstInterruptEvent(ctx, executeID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !found {
|
||||
return "", fmt.Errorf("interrupt event does not exist, wfExeID: %d", executeID)
|
||||
}
|
||||
|
||||
return "", einoCompose.NewInterruptAndRerunErr(&entity.ToolInterruptEvent{
|
||||
ToolCallID: einoCompose.GetToolCallID(ctx),
|
||||
ToolName: i.info.Name,
|
||||
ExecuteID: executeID,
|
||||
InterruptEvent: firstIE,
|
||||
})
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if i.terminatePlan == vo.ReturnVariables {
|
||||
return sonic.MarshalString(out)
|
||||
}
|
||||
|
||||
content, ok := out[answerKey]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no answer found when terminate plan is use answer content. out: %v", out)
|
||||
}
|
||||
|
||||
contentStr, ok := content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("answer content is not string. content: %v", content)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
|
||||
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
return contentStr, nil
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) TerminatePlan() vo.TerminatePlan {
|
||||
return i.terminatePlan
|
||||
}
|
||||
|
||||
func (i *invokableWorkflow) GetWorkflow() *entity.Workflow {
|
||||
return i.wfEntity
|
||||
}
|
||||
|
||||
type streamableWorkflow struct {
|
||||
info *schema.ToolInfo
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
func NewStreamableWorkflow(info *schema.ToolInfo,
|
||||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &streamableWorkflow{
|
||||
info: info,
|
||||
stream: stream,
|
||||
terminatePlan: terminatePlan,
|
||||
wfEntity: wfEntity,
|
||||
sc: sc,
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||
return s.info, nil
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) {
|
||||
rInfo, allIEs := execute.GetResumeRequest(opts...)
|
||||
var (
|
||||
previouslyInterrupted bool
|
||||
callID = einoCompose.GetToolCallID(ctx)
|
||||
previousExecuteID int64
|
||||
)
|
||||
for interruptedCallID := range allIEs {
|
||||
if callID == interruptedCallID {
|
||||
previouslyInterrupted = true
|
||||
previousExecuteID = allIEs[interruptedCallID].ExecuteID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if previouslyInterrupted && rInfo.ExecuteID != previousExecuteID {
|
||||
logs.Infof("previous interrupted call ID: %s, previous execute ID: %d, current execute ID: %d. Not resuming, interrupt immediately", callID, previousExecuteID, rInfo.ExecuteID)
|
||||
return nil, einoCompose.InterruptAndRerun
|
||||
}
|
||||
|
||||
cfg := execute.GetExecuteConfig(opts...)
|
||||
|
||||
var runOpts []WorkflowRunnerOption
|
||||
if rInfo != nil {
|
||||
runOpts = append(runOpts, WithResumeReq(rInfo))
|
||||
} else {
|
||||
runOpts = append(runOpts, WithInput(argumentsInJSON))
|
||||
}
|
||||
if sw := execute.GetIntermediateStreamWriter(opts...); sw != nil {
|
||||
runOpts = append(runOpts, WithStreamWriter(sw))
|
||||
}
|
||||
|
||||
var (
|
||||
cancelCtx context.Context
|
||||
executeID int64
|
||||
callOpts []einoCompose.Option
|
||||
in map[string]any
|
||||
err error
|
||||
ws *nodes.ConversionWarnings
|
||||
)
|
||||
|
||||
if rInfo == nil {
|
||||
if err = sonic.UnmarshalString(argumentsInJSON, &in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
for _, node := range s.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
if entryNode == nil {
|
||||
panic("entry node not found in tool workflow")
|
||||
}
|
||||
in, ws, err = nodes.ConvertInputs(ctx, in, entryNode.OutputTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, executeID, callOpts, _, err = NewWorkflowRunner(s.wfEntity.GetBasic(), s.sc, cfg, runOpts...).Prepare(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outStream, err := s.stream(cancelCtx, in, callOpts...)
|
||||
if err != nil {
|
||||
if _, ok := einoCompose.ExtractInterruptInfo(err); ok {
|
||||
firstIE, found, err := s.repo.GetFirstInterruptEvent(ctx, executeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("interrupt event does not exist, wfExeID: %d", executeID)
|
||||
}
|
||||
|
||||
return nil, einoCompose.NewInterruptAndRerunErr(&entity.ToolInterruptEvent{
|
||||
ToolCallID: einoCompose.GetToolCallID(ctx),
|
||||
ToolName: s.info.Name,
|
||||
ExecuteID: executeID,
|
||||
InterruptEvent: firstIE,
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return schema.StreamReaderWithConvert(outStream, func(in map[string]any) (string, error) {
|
||||
content, ok := in["output"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("no output found when stream plan is use output content. out: %v", in)
|
||||
}
|
||||
|
||||
contentStr, ok := content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("output content is not string. content: %v", content)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(contentStr, nodes.KeyIsFinished) {
|
||||
contentStr = strings.TrimSuffix(contentStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
return contentStr, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) TerminatePlan() vo.TerminatePlan {
|
||||
return s.terminatePlan
|
||||
}
|
||||
|
||||
func (s *streamableWorkflow) GetWorkflow() *entity.Workflow {
|
||||
return s.wfEntity
|
||||
}
|
||||
Reference in New Issue
Block a user