519 lines
14 KiB
Go
519 lines
14 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package loop
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
|
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
|
)
|
|
|
|
type Loop struct {
|
|
outputs map[string]*vo.FieldSource
|
|
outputVars map[string]string
|
|
inner compose.Runnable[map[string]any, map[string]any]
|
|
nodeKey vo.NodeKey
|
|
|
|
loopType Type
|
|
inputArrays []string
|
|
intermediateVars map[string]*vo.TypeInfo
|
|
}
|
|
|
|
type Config struct {
|
|
LoopType Type
|
|
InputArrays []string
|
|
IntermediateVars map[string]*vo.TypeInfo
|
|
}
|
|
|
|
// Adapt converts the canvas description of a loop node into a NodeSchema and
// fills c (loop type, intermediate variables, input arrays) along the way.
//
// A loop node must be top-level: an error is returned when n has a parent.
// Errors from loop-type resolution or from any input/output conversion are
// returned unchanged.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	if parent := n.Parent(); parent != nil {
		return nil, fmt.Errorf("loop node cannot have parent: %s", parent.ID)
	}

	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeLoop,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	lt, err := toLoopType(n.Data.Inputs.LoopType)
	if err != nil {
		return nil, err
	}
	c.LoopType = lt

	// Loop variables: each becomes both a typed input of the node and an
	// entry in IntermediateVars.
	loopVars := make(map[string]*vo.TypeInfo)
	for _, p := range n.Data.Inputs.VariableParameters {
		typeInfo, convErr := convert.CanvasBlockInputToTypeInfo(p.Input)
		if convErr != nil {
			return nil, convErr
		}
		loopVars[p.Name] = typeInfo
		ns.SetInputType(p.Name, typeInfo)

		srcs, convErr := convert.CanvasBlockInputToFieldInfo(p.Input, compose.FieldPath{p.Name}, nil)
		if convErr != nil {
			return nil, convErr
		}
		ns.AddInputSource(srcs...)
	}
	c.IntermediateVars = loopVars

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	// An output whose single-segment reference names a loop variable is
	// tagged so Build can expose the variable's final value instead of
	// collecting per-iteration results.
	for _, fi := range ns.OutputSources {
		ref := fi.Source.Ref
		if ref == nil || len(ref.FromPath) != 1 {
			continue
		}
		if _, isLoopVar := loopVars[ref.FromPath[0]]; isLoopVar {
			ref.VariableType = ptr.Of(vo.ParentIntermediate)
		}
	}

	if lc := n.Data.Inputs.LoopCount; lc != nil {
		countType, convErr := convert.CanvasBlockInputToTypeInfo(lc)
		if convErr != nil {
			return nil, convErr
		}
		ns.SetInputType(Count, countType)

		countSrcs, convErr := convert.CanvasBlockInputToFieldInfo(lc, compose.FieldPath{Count}, nil)
		if convErr != nil {
			return nil, convErr
		}
		ns.AddInputSource(countSrcs...)
	}

	// Every array-typed input that is not a loop variable is iterated over.
	for name, typeInfo := range ns.InputTypes {
		_, isLoopVar := loopVars[name]
		if typeInfo.Type == vo.DataTypeArray && !isLoopVar {
			c.InputArrays = append(c.InputArrays, name)
		}
	}

	return ns, nil
}
|
|
|
|
func toLoopType(l vo.LoopType) (Type, error) {
|
|
switch l {
|
|
case vo.LoopTypeArray:
|
|
return ByArray, nil
|
|
case vo.LoopTypeCount:
|
|
return ByIteration, nil
|
|
case vo.LoopTypeInfinite:
|
|
return Infinite, nil
|
|
default:
|
|
return "", fmt.Errorf("unsupported loop type: %s", l)
|
|
}
|
|
}
|
|
|
|
// Build validates c and creates the runnable *Loop for ns.
//
// The loop body must be supplied through the build options (Inner). A ByArray
// loop requires at least one input array. Every entry of ns.OutputSources must
// have a single-segment path and a reference source:
//   - references tagged vo.ParentIntermediate expose the final value of the
//     named intermediate variable (recorded in Loop.outputVars);
//   - all other references collect the referenced inner-node value from every
//     iteration (recorded in Loop.outputs).
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
	if c.LoopType == ByArray && len(c.InputArrays) == 0 {
		return nil, errors.New("input arrays is empty when loop type is ByArray")
	}

	options := schema.GetBuildOptions(opts...)
	if options.Inner == nil {
		return nil, errors.New("inner workflow is required for Loop Node")
	}

	loop := &Loop{
		outputs:          make(map[string]*vo.FieldSource),
		outputVars:       make(map[string]string),
		inputArrays:      c.InputArrays,
		nodeKey:          ns.Key,
		intermediateVars: c.IntermediateVars,
		inner:            options.Inner,
		loopType:         c.LoopType,
	}

	for _, info := range ns.OutputSources {
		if len(info.Path) != 1 {
			return nil, fmt.Errorf("invalid output path: %s", info.Path)
		}
		k := info.Path[0]

		// Loop outputs are always read from inner nodes or loop variables, so a
		// source without a reference cannot be served. Check this before touching
		// Ref instead of dereferencing a nil pointer.
		ref := info.Source.Ref
		if ref == nil {
			return nil, fmt.Errorf("loop output %s must reference a node output or an intermediate variable", k)
		}

		if ref.VariableType != nil && *ref.VariableType == vo.ParentIntermediate {
			fromPath := ref.FromPath
			// Exactly one segment is required: the variable name. An empty path
			// would otherwise panic on fromPath[0].
			if len(fromPath) != 1 {
				return nil, fmt.Errorf("loop output refers to intermediate variable, but path length is %d, want 1: %v", len(fromPath), fromPath)
			}

			if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
				return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
			}

			loop.outputVars[k] = fromPath[0]
			continue
		}

		loop.outputs[k] = &info.Source
	}

	return loop, nil
}
|
|
|
|
// Type determines how a loop node bounds its number of iterations.
type Type string

// Supported loop types.
const (
	// ByArray runs once per element, bounded by the shortest input array.
	ByArray Type = "by_array"
	// ByIteration runs the number of times given by the Count input.
	ByIteration Type = "by_iteration"
	// Infinite has no iteration bound; it stops only when a break is signalled
	// (or the run is interrupted or cancelled).
	Infinite Type = "infinite"
)

// Count is the input key that carries the iteration count of a ByIteration loop.
const Count = "loopCount"
|
|
|
|
// Invoke executes the loop body (l.inner) up to getMaxIter(in) times and
// returns the aggregated outputs.
//
// For iteration i the body receives:
//   - every entry of in except Count, the iterated arrays and the intermediate
//     variables;
//   - "<nodeKey>#index" = int64(i);
//   - "<nodeKey>#<arrayKey>" = the i-th element of each iterated array, plus a
//     "<nodeKey>#<arrayKey>#<field>..." entry for every field of nested
//     map[string]any elements.
//
// Intermediate variables are exposed to the body through
// nodes.InitIntermediateVars and shared across iterations. The loop stops after
// the current iteration once the _break.BreakKey variable becomes true.
//
// Outputs in l.outputs become slices holding the referenced inner value from
// every completed iteration. Outputs in l.outputVars hold the final value of
// their intermediate variable.
//
// Interrupt/resume: when an iteration is interrupted, progress is saved as a
// nodes.NestedWorkflowState, an InterruptEvent is recorded and
// compose.InterruptAndRerun is returned. On the re-run, completed iterations are
// skipped and the interrupted one is re-invoked with the state modifier supplied
// in the resume options.
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
	out map[string]any, err error) {
	maxIter, err := l.getMaxIter(in)
	if err != nil {
		return nil, err
	}

	arrays := make(map[string][]any, len(l.inputArrays))
	for _, arrayKey := range l.inputArrays {
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
		if !ok {
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
		}

		// Inputs are not guaranteed to be []any (getMaxIter accepts any slice
		// kind). Convert other slice types element by element instead of
		// panicking on a failed type assertion.
		if s, isAnySlice := a.([]any); isAnySlice {
			arrays[arrayKey] = s
			continue
		}
		rv := reflect.ValueOf(a)
		if rv.Kind() != reflect.Slice {
			return nil, fmt.Errorf("incoming array not a slice: %s. Actual type: %v", arrayKey, reflect.TypeOf(a))
		}
		converted := make([]any, rv.Len())
		for j := range converted {
			converted[j] = rv.Index(j).Interface()
		}
		arrays[arrayKey] = converted
	}

	options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)

	var (
		existingCState   *nodes.NestedWorkflowState
		intermediateVars map[string]*any
		output           map[string]any
		hasBreak         = any(false)
	)
	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
		var e error
		existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
		return e
	})
	if err != nil {
		return nil, err
	}

	if existingCState != nil {
		// Resuming: continue from the outputs and variable values saved at the interrupt.
		output = existingCState.FullOutput
		intermediateVars = make(map[string]*any, len(existingCState.IntermediateVars))
		for k := range existingCState.IntermediateVars {
			intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
		}
	} else {
		output = make(map[string]any, len(l.outputs))
		for k := range l.outputs {
			output[k] = make([]any, 0)
		}

		intermediateVars = make(map[string]*any, len(l.intermediateVars))
		for varKey := range l.intermediateVars {
			v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
			if !ok {
				return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
			}

			intermediateVars[varKey] = &v
		}
	}
	// The break flag starts false on every invocation, including resumes.
	intermediateVars[_break.BreakKey] = &hasBreak

	ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)

	getIthInput := func(i int) (map[string]any, map[string]any, error) {
		input := make(map[string]any)

		for k, v := range in { // carry over other values
			if k == Count {
				continue
			}
			if _, ok := arrays[k]; ok {
				continue
			}
			if _, ok := intermediateVars[k]; ok {
				continue
			}
			input[k] = v
		}

		input[string(l.nodeKey)+"#index"] = int64(i)

		// expand stores val under prefix and recursively flattens nested
		// map[string]any values into "<prefix>#<key>" entries.
		var expand func(prefix string, val any)
		expand = func(prefix string, val any) {
			input[prefix] = val
			if nestedMap, ok := val.(map[string]any); ok {
				for k, v := range nestedMap {
					expand(prefix+"#"+k, v)
				}
			}
		}

		items := make(map[string]any)
		for arrayKey, arr := range arrays {
			// For ByIteration/Infinite loops the iteration bound does not come
			// from the arrays, so a shorter array must be reported, not indexed.
			if i >= len(arr) {
				return nil, nil, fmt.Errorf("loop index %d out of range for input array %s of length %d", i, arrayKey, len(arr))
			}
			ele := arr[i]
			items[arrayKey] = ele
			expand(string(l.nodeKey)+"#"+arrayKey, ele)
		}

		return input, items, nil
	}

	// setIthOutput appends this iteration's referenced inner values to the
	// corresponding output slices.
	setIthOutput := func(taskOutput map[string]any) {
		for arrayKey := range l.outputs {
			source := l.outputs[arrayKey]
			fromValue, ok := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)}, source.Ref.FromPath...))
			if ok {
				output[arrayKey] = append(output[arrayKey].([]any), fromValue)
			}
		}
	}

	var (
		index2Done          = map[int]bool{}
		index2InterruptInfo = map[int]*compose.InterruptInfo{}
		resumed             = map[int]bool{}
	)

	for i := 0; i < maxIter; i++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err() // canceled by Eino workflow engine
		default:
		}

		if existingCState != nil {
			if existingCState.Index2Done[i] {
				continue // completed before the interrupt
			}

			if existingCState.Index2InterruptInfo[i] != nil && len(options.GetResumeIndexes()) > 0 {
				if _, ok := options.GetResumeIndexes()[i]; !ok {
					return nil, fmt.Errorf("loop iteration %d was interrupted previously but is not being resumed", i)
				}
			}

			resumed[i] = true
		}

		input, items, err := getIthInput(i)
		if err != nil {
			return nil, err
		}

		subCtx, checkpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)

		ithOpts := options.GetOptsForNested()
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)

		if checkpointID != "" {
			ithOpts = append(ithOpts, compose.WithCheckPointID(checkpointID))
		}

		if stateModifier, ok := options.GetResumeIndexes()[i]; ok {
			ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
		}

		taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
		if err != nil {
			info, ok := compose.ExtractInterruptInfo(err)
			if !ok {
				return nil, err
			}

			// Record the interrupt and stop; earlier iterations keep their results.
			index2InterruptInfo[i] = info
			break
		}

		setIthOutput(taskOutput)

		index2Done[i] = true

		if broke, _ := hasBreak.(bool); broke {
			break
		}
	}

	// delete the interruptions that have been resumed
	if existingCState != nil {
		for index := range resumed {
			delete(existingCState.Index2InterruptInfo, index)
		}
	}

	compState := existingCState
	if compState == nil {
		compState = &nodes.NestedWorkflowState{
			Index2Done:          index2Done,
			Index2InterruptInfo: index2InterruptInfo,
			FullOutput:          output,
			IntermediateVars:    convertIntermediateVars(intermediateVars),
		}
	} else {
		for i := range index2Done {
			compState.Index2Done[i] = index2Done[i]
		}
		for i := range index2InterruptInfo {
			compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
		}
		compState.FullOutput = output
		compState.IntermediateVars = convertIntermediateVars(intermediateVars)
	}

	if len(index2InterruptInfo) > 0 { // this invocation has new interruptions
		iEvent := &entity.InterruptEvent{
			NodeKey:             l.nodeKey,
			NodeType:            entity.NodeTypeLoop,
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
		}

		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
				return e
			}

			return setter.SetInterruptEvent(l.nodeKey, iEvent)
		})
		if err != nil {
			return nil, err
		}

		return nil, compose.InterruptAndRerun
	}

	if err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
		return setter.SaveNestedWorkflowState(l.nodeKey, compState)
	}); err != nil {
		return nil, err
	}

	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
		return nil, fmt.Errorf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo)
	}

	for outputVarKey, intermediateVarKey := range l.outputVars {
		output[outputVarKey] = *(intermediateVars[intermediateVarKey])
	}

	return output, nil
}
|
|
|
|
// getMaxIter returns the upper bound on the number of iterations for this
// invocation:
//   - ByArray: the length of the shortest array in l.inputArrays;
//   - ByIteration: the Count input, which may be any integer type or a finite
//     whole-number float (e.g. a number decoded from JSON);
//   - Infinite: math.MaxInt, so the loop ends only on a break, an interrupt or
//     context cancellation.
//
// It returns an error when a required input is missing or has an unusable type.
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
	maxIter := math.MaxInt

	switch l.loopType {
	case ByArray:
		for _, arrayKey := range l.inputArrays {
			a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
			if !ok {
				return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
			}

			// reflect.TypeOf(nil).Kind() panics (nil Type), whereas
			// reflect.ValueOf(nil).Kind() is Invalid and is reported as an error.
			rv := reflect.ValueOf(a)
			if rv.Kind() != reflect.Slice {
				return 0, fmt.Errorf("incoming array not a slice: %s. Actual type: %v", arrayKey, reflect.TypeOf(a))
			}

			if oneLen := rv.Len(); oneLen < maxIter {
				maxIter = oneLen
			}
		}
	case ByIteration:
		iter, ok := nodes.TakeMapValue(in, compose.FieldPath{Count})
		if !ok {
			return 0, errors.New("incoming LoopCount not present in input when loop type is ByIteration")
		}

		// Accept any numeric representation of a whole number rather than
		// asserting int64 only, so int, float64 or nil inputs do not panic.
		rv := reflect.ValueOf(iter)
		switch rv.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			maxIter = int(rv.Int())
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			// Counts that do not fit in int keep the math.MaxInt bound.
			if u := rv.Uint(); u < uint64(math.MaxInt) {
				maxIter = int(u)
			}
		case reflect.Float32, reflect.Float64:
			f := rv.Float()
			if math.IsNaN(f) || math.IsInf(f, 0) || f != math.Trunc(f) {
				return 0, fmt.Errorf("incoming LoopCount is not a whole number: %v", iter)
			}
			if f < float64(math.MaxInt) {
				maxIter = int(f)
			}
		default:
			return 0, fmt.Errorf("incoming LoopCount has unsupported type %T: %v", iter, iter)
		}
	case Infinite:
	default:
		return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
	}

	return maxIter, nil
}
|
|
|
|
// convertIntermediateVars snapshots the current value behind every pointer in
// vars into a new map of plain values. vars itself is not modified. A nil
// pointer in vars causes a panic.
func convertIntermediateVars(vars map[string]*any) map[string]any {
	snapshot := make(map[string]any, len(vars))
	for name, slot := range vars {
		value := *slot
		snapshot[name] = value
	}
	return snapshot
}
|
|
|
|
// ToCallbackInput reduces the node input to the iterated arrays only: keys of
// l.inputArrays that are present in in are copied, and everything else is
// dropped. It never returns an error.
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
	out := make(map[string]any, len(l.inputArrays))
	for idx := range l.inputArrays {
		key := l.inputArrays[idx]
		val, present := in[key]
		if !present {
			continue
		}
		out[key] = val
	}
	return out, nil
}
|