519 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			519 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 * Copyright 2025 coze-dev Authors
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 */
 | 
						|
 | 
						|
package loop
 | 
						|
 | 
						|
import (
	"context"
	"errors"
	"fmt"
	"math"
	"reflect"
	"sort"

	"github.com/cloudwego/eino/compose"

	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
	_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
 | 
						|
 | 
						|
type Loop struct {
 | 
						|
	outputs    map[string]*vo.FieldSource
 | 
						|
	outputVars map[string]string
 | 
						|
	inner      compose.Runnable[map[string]any, map[string]any]
 | 
						|
	nodeKey    vo.NodeKey
 | 
						|
 | 
						|
	loopType         Type
 | 
						|
	inputArrays      []string
 | 
						|
	intermediateVars map[string]*vo.TypeInfo
 | 
						|
}
 | 
						|
 | 
						|
type Config struct {
 | 
						|
	LoopType         Type
 | 
						|
	InputArrays      []string
 | 
						|
	IntermediateVars map[string]*vo.TypeInfo
 | 
						|
}
 | 
						|
 | 
						|
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
 | 
						|
	if n.Parent() != nil {
 | 
						|
		return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
 | 
						|
	}
 | 
						|
 | 
						|
	ns := &schema.NodeSchema{
 | 
						|
		Key:     vo.NodeKey(n.ID),
 | 
						|
		Type:    entity.NodeTypeLoop,
 | 
						|
		Name:    n.Data.Meta.Title,
 | 
						|
		Configs: c,
 | 
						|
	}
 | 
						|
 | 
						|
	loopType, err := toLoopType(n.Data.Inputs.LoopType)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	c.LoopType = loopType
 | 
						|
 | 
						|
	intermediateVars := make(map[string]*vo.TypeInfo)
 | 
						|
	for _, param := range n.Data.Inputs.VariableParameters {
 | 
						|
		tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		intermediateVars[param.Name] = tInfo
 | 
						|
 | 
						|
		ns.SetInputType(param.Name, tInfo)
 | 
						|
		sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		ns.AddInputSource(sources...)
 | 
						|
	}
 | 
						|
	c.IntermediateVars = intermediateVars
 | 
						|
 | 
						|
	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	for _, fieldInfo := range ns.OutputSources {
 | 
						|
		if fieldInfo.Source.Ref != nil {
 | 
						|
			if len(fieldInfo.Source.Ref.FromPath) == 1 {
 | 
						|
				if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
 | 
						|
					fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	loopCount := n.Data.Inputs.LoopCount
 | 
						|
	if loopCount != nil {
 | 
						|
		typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		ns.SetInputType(Count, typeInfo)
 | 
						|
 | 
						|
		sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		ns.AddInputSource(sources...)
 | 
						|
	}
 | 
						|
 | 
						|
	for key, tInfo := range ns.InputTypes {
 | 
						|
		if tInfo.Type != vo.DataTypeArray {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		c.InputArrays = append(c.InputArrays, key)
 | 
						|
	}
 | 
						|
 | 
						|
	return ns, nil
 | 
						|
}
 | 
						|
 | 
						|
func toLoopType(l vo.LoopType) (Type, error) {
 | 
						|
	switch l {
 | 
						|
	case vo.LoopTypeArray:
 | 
						|
		return ByArray, nil
 | 
						|
	case vo.LoopTypeCount:
 | 
						|
		return ByIteration, nil
 | 
						|
	case vo.LoopTypeInfinite:
 | 
						|
		return Infinite, nil
 | 
						|
	default:
 | 
						|
		return "", fmt.Errorf("unsupported loop type: %s", l)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Build validates c and constructs the runnable *Loop for ns.
//
// Preconditions:
//   - the inner workflow must be supplied through the build options;
//   - a ByArray loop needs at least one input array.
//
// Each output source of ns must be a single-segment path backed by a
// reference. A reference tagged vo.ParentIntermediate must name exactly one
// declared intermediate variable; that variable's final value becomes the
// output. Any other reference is recorded as a per-iteration output of the
// inner workflow.
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
	if c.LoopType == ByArray && len(c.InputArrays) == 0 {
		return nil, errors.New("input arrays is empty when loop type is ByArray")
	}

	options := schema.GetBuildOptions(opts...)
	if options.Inner == nil {
		return nil, errors.New("inner workflow is required for Loop Node")
	}

	loop := &Loop{
		outputs:          make(map[string]*vo.FieldSource),
		outputVars:       make(map[string]string),
		inputArrays:      c.InputArrays,
		nodeKey:          ns.Key,
		intermediateVars: c.IntermediateVars,
		inner:            options.Inner,
		loopType:         c.LoopType,
	}

	for _, info := range ns.OutputSources {
		if len(info.Path) != 1 {
			return nil, fmt.Errorf("invalid output path: %s", info.Path)
		}
		k := info.Path[0]

		// Both kinds of loop output are resolved through the reference (here for
		// intermediate variables, in Invoke for inner-workflow outputs), so it
		// must be checked before any field of it is read.
		ref := info.Source.Ref
		if ref == nil {
			return nil, fmt.Errorf("loop output %s does not reference any variable or node output", k)
		}

		if ref.VariableType != nil && *ref.VariableType == vo.ParentIntermediate {
			fromPath := ref.FromPath
			if len(fromPath) == 0 {
				return nil, fmt.Errorf("loop output refers to intermediate variable, but path is empty: %s", k)
			}

			if len(fromPath) > 1 {
				return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
			}

			if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
				return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
			}

			loop.outputVars[k] = fromPath[0]
			continue
		}

		loop.outputs[k] = &info.Source
	}

	return loop, nil
}
 | 
						|
 | 
						|
// Type selects how a loop node determines its number of iterations.
type Type string

// Supported loop types.
const (
	// ByArray runs once per element of the input arrays, up to the shortest one.
	ByArray Type = "by_array"
	// ByIteration runs the number of times given by the Count input.
	ByIteration Type = "by_iteration"
	// Infinite has no iteration bound of its own.
	Infinite Type = "infinite"
)

// Count is the input key carrying the iteration count of a ByIteration loop.
const Count = "loopCount"
 | 
						|
 | 
						|
// Invoke runs the inner workflow once per iteration and aggregates the results.
//
// The iteration bound comes from getMaxIter. For iteration i the inner workflow
// receives:
//   - every input value that is not an input array, the Count input or an
//     intermediate variable, carried over unchanged;
//   - "<nodeKey>#index" set to int64(i);
//   - "<nodeKey>#<arrayKey>" set to the i-th element of each input array. A
//     map[string]any element is additionally flattened into
//     "<nodeKey>#<arrayKey>#<field>..." keys.
//
// Intermediate variables, plus the break flag under _break.BreakKey, are
// attached to ctx via nodes.InitIntermediateVars. The loop stops after any
// iteration that leaves the flag set to true; presumably a break node inside
// the inner workflow sets it.
//
// If an iteration is interrupted, the loop stops and saves its progress as
// nested workflow state: done indexes, interrupt info, collected outputs and
// intermediate variable values. It then returns compose.InterruptAndRerun. On
// the rerun, that state is restored, finished iterations are skipped, and
// resume state modifiers are forwarded to the iteration they target.
//
// The result holds:
//   - for each inner-workflow output, the list of per-iteration values;
//   - for each intermediate-variable output, that variable's final value.
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
	out map[string]any, err error) {
	maxIter, err := l.getMaxIter(in)
	if err != nil {
		return nil, err
	}

	arrays := make(map[string][]any, len(l.inputArrays))
	for _, arrayKey := range l.inputArrays {
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
		if !ok {
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
		}

		// getMaxIter accepts any slice kind, so convert instead of asserting []any.
		arr, convErr := toLoopArray(a)
		if convErr != nil {
			return nil, fmt.Errorf("incoming array %s: %w", arrayKey, convErr)
		}
		arrays[arrayKey] = arr
	}

	options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)

	var (
		existingCState   *nodes.NestedWorkflowState
		intermediateVars map[string]*any
		output           map[string]any
		hasBreak         = any(false)
	)

	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
		var e error
		existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
		return e
	})
	if err != nil {
		return nil, err
	}

	if existingCState != nil {
		// Rerun after an interruption: continue from the saved outputs and
		// intermediate variable values.
		output = existingCState.FullOutput
		intermediateVars = make(map[string]*any, len(existingCState.IntermediateVars))
		for k, v := range existingCState.IntermediateVars {
			intermediateVars[k] = ptr.Of(v)
		}
	} else {
		output = make(map[string]any, len(l.outputs))
		for k := range l.outputs {
			output[k] = make([]any, 0)
		}

		intermediateVars = make(map[string]*any, len(l.intermediateVars))
		for varKey := range l.intermediateVars {
			v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
			if !ok {
				return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
			}
			intermediateVars[varKey] = &v
		}
	}
	// The break flag always starts unset, also when resuming.
	intermediateVars[_break.BreakKey] = &hasBreak

	ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)

	getIthInput := func(i int) (map[string]any, map[string]any, error) {
		input := make(map[string]any)
		for k, v := range in { // carry over other values
			if k == Count {
				continue
			}
			if _, ok := arrays[k]; ok {
				continue
			}
			if _, ok := intermediateVars[k]; ok {
				continue
			}
			input[k] = v
		}

		input[string(l.nodeKey)+"#index"] = int64(i)

		items := make(map[string]any, len(arrays))
		for arrayKey, arr := range arrays {
			// Only ByArray bounds maxIter by the array lengths; other loop types
			// may run past the end of an array input.
			if i >= len(arr) {
				return nil, nil, fmt.Errorf("loop index %d out of range for input array %s of length %d", i, arrayKey, len(arr))
			}

			ele := arr[i]
			items[arrayKey] = ele
			expandLoopItem(input, string(l.nodeKey)+"#"+arrayKey, ele)
		}

		return input, items, nil
	}

	setIthOutput := func(i int, taskOutput map[string]any) {
		for arrayKey, source := range l.outputs {
			fromValue, ok := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)}, source.Ref.FromPath...))
			if ok {
				output[arrayKey] = append(output[arrayKey].([]any), fromValue)
			}
		}
	}

	var (
		index2Done          = map[int]bool{}
		index2InterruptInfo = map[int]*compose.InterruptInfo{}
		resumed             = map[int]bool{}
	)

	for i := 0; i < maxIter; i++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err() // canceled by Eino workflow engine
		default:
		}

		if existingCState != nil {
			if existingCState.Index2Done[i] {
				continue // finished in a previous run
			}

			if existingCState.Index2InterruptInfo[i] != nil && len(options.GetResumeIndexes()) > 0 {
				if _, ok := options.GetResumeIndexes()[i]; !ok {
					// previously interrupted, but not resumed this time, should not happen
					panic("impossible")
				}
			}

			resumed[i] = true
		}

		input, items, err := getIthInput(i)
		if err != nil {
			return nil, err
		}

		subCtx, checkpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)

		ithOpts := options.GetOptsForNested()
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
		if checkpointID != "" {
			ithOpts = append(ithOpts, compose.WithCheckPointID(checkpointID))
		}

		if stateModifier, ok := options.GetResumeIndexes()[i]; ok {
			ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
		}

		taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
		if err != nil {
			info, ok := compose.ExtractInterruptInfo(err)
			if !ok {
				return nil, err
			}

			// Iterations run sequentially: stop at the first interruption.
			index2InterruptInfo[i] = info
			break
		}

		setIthOutput(i, taskOutput)
		index2Done[i] = true

		if hasBreak.(bool) {
			break
		}
	}

	// delete the interruptions that have been resumed
	for index := range resumed {
		delete(existingCState.Index2InterruptInfo, index)
	}

	compState := existingCState
	if compState == nil {
		compState = &nodes.NestedWorkflowState{
			Index2Done:          index2Done,
			Index2InterruptInfo: index2InterruptInfo,
			FullOutput:          output,
			IntermediateVars:    convertIntermediateVars(intermediateVars),
		}
	} else {
		// A restored state may carry nil maps; writing to them would panic.
		if compState.Index2Done == nil {
			compState.Index2Done = make(map[int]bool, len(index2Done))
		}
		for i, done := range index2Done {
			compState.Index2Done[i] = done
		}

		if compState.Index2InterruptInfo == nil {
			compState.Index2InterruptInfo = make(map[int]*compose.InterruptInfo, len(index2InterruptInfo))
		}
		for i, info := range index2InterruptInfo {
			compState.Index2InterruptInfo[i] = info
		}

		compState.FullOutput = output
		compState.IntermediateVars = convertIntermediateVars(intermediateVars)
	}

	if len(index2InterruptInfo) > 0 { // this invocation has new interruptions
		iEvent := &entity.InterruptEvent{
			NodeKey:             l.nodeKey,
			NodeType:            entity.NodeTypeLoop,
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
		}

		saveErr := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
				return e
			}
			return setter.SetInterruptEvent(l.nodeKey, iEvent)
		})
		if saveErr != nil {
			return nil, saveErr
		}

		return nil, compose.InterruptAndRerun
	}

	saveErr := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
		return setter.SaveNestedWorkflowState(l.nodeKey, compState)
	})
	if saveErr != nil {
		return nil, saveErr
	}

	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
		panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
	}

	for outputVarKey, intermediateVarKey := range l.outputVars {
		output[outputVarKey] = *(intermediateVars[intermediateVarKey])
	}

	return output, nil
}

// expandLoopItem stores val in input under prefix. When val is a
// map[string]any, it also stores every nested value, recursively, under
// prefix+"#"+field.
func expandLoopItem(input map[string]any, prefix string, val any) {
	input[prefix] = val
	if nested, ok := val.(map[string]any); ok {
		for k, v := range nested {
			expandLoopItem(input, prefix+"#"+k, v)
		}
	}
}

// toLoopArray returns a as a []any. A []any is returned as-is, any other
// slice type is copied element by element, and anything else (including nil)
// is an error.
func toLoopArray(a any) ([]any, error) {
	if arr, ok := a.([]any); ok {
		return arr, nil
	}

	rv := reflect.ValueOf(a)
	if rv.Kind() != reflect.Slice {
		return nil, fmt.Errorf("not a slice, actual type: %T", a)
	}

	arr := make([]any, rv.Len())
	for i := range arr {
		arr[i] = rv.Index(i).Interface()
	}
	return arr, nil
}
 | 
						|
 | 
						|
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
 | 
						|
	maxIter := math.MaxInt
 | 
						|
 | 
						|
	switch l.loopType {
 | 
						|
	case ByArray:
 | 
						|
		for _, arrayKey := range l.inputArrays {
 | 
						|
			a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
 | 
						|
			if !ok {
 | 
						|
				return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
 | 
						|
			}
 | 
						|
 | 
						|
			if reflect.TypeOf(a).Kind() != reflect.Slice {
 | 
						|
				return 0, fmt.Errorf("incoming array not a slice: %s. Actual type: %v", arrayKey, reflect.TypeOf(a))
 | 
						|
			}
 | 
						|
 | 
						|
			oneLen := reflect.ValueOf(a).Len()
 | 
						|
			if oneLen < maxIter {
 | 
						|
				maxIter = oneLen
 | 
						|
			}
 | 
						|
		}
 | 
						|
	case ByIteration:
 | 
						|
		iter, ok := nodes.TakeMapValue(in, compose.FieldPath{Count})
 | 
						|
		if !ok {
 | 
						|
			return 0, errors.New("incoming LoopCount not present in input when loop type is ByIteration")
 | 
						|
		}
 | 
						|
 | 
						|
		maxIter = int(iter.(int64))
 | 
						|
	case Infinite:
 | 
						|
	default:
 | 
						|
		return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
 | 
						|
	}
 | 
						|
 | 
						|
	return maxIter, nil
 | 
						|
}
 | 
						|
 | 
						|
// convertIntermediateVars dereferences every entry of vars into a plain map
// suitable for storing in the nested workflow state. A nil pointer is stored
// as a nil value instead of being dereferenced.
func convertIntermediateVars(vars map[string]*any) map[string]any {
	ret := make(map[string]any, len(vars))
	for k, v := range vars {
		if v == nil {
			ret[k] = nil
			continue
		}
		ret[k] = *v
	}
	return ret
}
 | 
						|
 | 
						|
// ToCallbackInput returns only the loop's input arrays from in. Keys absent
// from in are omitted. Judging by the name, this narrowed map is what
// callbacks see as the node input — confirm against the caller.
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
	result := make(map[string]any, len(l.inputArrays))
	for _, key := range l.inputArrays {
		value, present := in[key]
		if !present {
			continue
		}
		result[key] = value
	}
	return result, nil
}
 |