497 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			497 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						|
 | 
						|
package batch
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"math"
 | 
						|
	"reflect"
 | 
						|
	"slices"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"github.com/cloudwego/eino/compose"
 | 
						|
	"golang.org/x/exp/maps"
 | 
						|
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/logs"
 | 
						|
	"github.com/coze-dev/coze-studio/backend/pkg/safego"
 | 
						|
)
 | 
						|
 | 
						|
type Batch struct {
 | 
						|
	outputs       map[string]*vo.FieldSource
 | 
						|
	innerWorkflow compose.Runnable[map[string]any, map[string]any]
 | 
						|
	key           vo.NodeKey
 | 
						|
	inputArrays   []string
 | 
						|
}
 | 
						|
 | 
						|
// Config is the field-less configuration of a batch node. Its Adapt method
// turns a canvas node into a NodeSchema and its Build method creates the
// runtime Batch from that schema.
type Config struct{}
 | 
						|
 | 
						|
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
 | 
						|
	if n.Parent() != nil {
 | 
						|
		return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
 | 
						|
	}
 | 
						|
 | 
						|
	ns := &schema.NodeSchema{
 | 
						|
		Key:     vo.NodeKey(n.ID),
 | 
						|
		Type:    entity.NodeTypeBatch,
 | 
						|
		Name:    n.Data.Meta.Title,
 | 
						|
		Configs: c,
 | 
						|
	}
 | 
						|
 | 
						|
	batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
 | 
						|
		compose.FieldPath{MaxBatchSizeKey}, nil)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	ns.AddInputSource(batchSizeField...)
 | 
						|
	concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
 | 
						|
		compose.FieldPath{ConcurrentSizeKey}, nil)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	ns.AddInputSource(concurrentSizeField...)
 | 
						|
 | 
						|
	batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	ns.SetInputType(MaxBatchSizeKey, batchSizeType)
 | 
						|
	concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
 | 
						|
 | 
						|
	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return ns, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
 | 
						|
	var inputArrays []string
 | 
						|
	for key, tInfo := range ns.InputTypes {
 | 
						|
		if tInfo.Type != vo.DataTypeArray {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		inputArrays = append(inputArrays, key)
 | 
						|
	}
 | 
						|
 | 
						|
	if len(inputArrays) == 0 {
 | 
						|
		return nil, errors.New("need to have at least one incoming array for batch")
 | 
						|
	}
 | 
						|
 | 
						|
	if len(ns.OutputSources) == 0 {
 | 
						|
		return nil, errors.New("need to have at least one output variable for batch")
 | 
						|
	}
 | 
						|
 | 
						|
	bo := schema.GetBuildOptions(opts...)
 | 
						|
	if bo.Inner == nil {
 | 
						|
		return nil, errors.New("need to have inner workflow for batch")
 | 
						|
	}
 | 
						|
 | 
						|
	b := &Batch{
 | 
						|
		outputs:       make(map[string]*vo.FieldSource),
 | 
						|
		innerWorkflow: bo.Inner,
 | 
						|
		key:           ns.Key,
 | 
						|
		inputArrays:   inputArrays,
 | 
						|
	}
 | 
						|
 | 
						|
	for i := range ns.OutputSources {
 | 
						|
		source := ns.OutputSources[i]
 | 
						|
		path := source.Path
 | 
						|
		if len(path) != 1 {
 | 
						|
			return nil, fmt.Errorf("invalid path %q", path)
 | 
						|
		}
 | 
						|
 | 
						|
		// from which inner node's which field does the batch's output fields come from
 | 
						|
		b.outputs[path[0]] = &source.Source
 | 
						|
	}
 | 
						|
 | 
						|
	return b, nil
 | 
						|
}
 | 
						|
 | 
						|
// Input keys under which a batch node receives its sizing parameters.
const (
	// MaxBatchSizeKey is the input key holding the maximum number of
	// iterations the batch will run.
	MaxBatchSizeKey = "batchSize"

	// ConcurrentSizeKey is the input key holding how many iterations may
	// run at the same time.
	ConcurrentSizeKey = "concurrentSize"
)
 | 
						|
 | 
						|
func (b *Batch) initOutput(length int) map[string]any {
 | 
						|
	out := make(map[string]any, len(b.outputs))
 | 
						|
	for key := range b.outputs {
 | 
						|
		sliceType := reflect.TypeOf([]any{})
 | 
						|
		slice := reflect.New(sliceType).Elem()
 | 
						|
		slice.Set(reflect.MakeSlice(sliceType, length, length))
 | 
						|
		out[key] = slice.Interface()
 | 
						|
	}
 | 
						|
 | 
						|
	return out
 | 
						|
}
 | 
						|
 | 
						|
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
 | 
						|
	out map[string]any, err error) {
 | 
						|
	arrays := make(map[string]any, len(b.inputArrays))
 | 
						|
	minLen := math.MaxInt64
 | 
						|
	for _, arrayKey := range b.inputArrays {
 | 
						|
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
 | 
						|
		if !ok {
 | 
						|
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
 | 
						|
		}
 | 
						|
 | 
						|
		if reflect.TypeOf(a).Kind() != reflect.Slice {
 | 
						|
			return nil, fmt.Errorf("incoming array not a slice: %s. Actual type: %v",
 | 
						|
				arrayKey, reflect.TypeOf(a))
 | 
						|
		}
 | 
						|
 | 
						|
		arrays[arrayKey] = a
 | 
						|
 | 
						|
		oneLen := reflect.ValueOf(a).Len()
 | 
						|
		if oneLen < minLen {
 | 
						|
			minLen = oneLen
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	var maxIter, concurrency int64
 | 
						|
 | 
						|
	maxIterAny, ok := nodes.TakeMapValue(in, compose.FieldPath{MaxBatchSizeKey})
 | 
						|
	if !ok {
 | 
						|
		return nil, fmt.Errorf("incoming max iteration not present in input: %s", in)
 | 
						|
	}
 | 
						|
 | 
						|
	maxIter = maxIterAny.(int64)
 | 
						|
	if maxIter == 0 {
 | 
						|
		maxIter = 100
 | 
						|
	}
 | 
						|
 | 
						|
	concurrencyAny, ok := nodes.TakeMapValue(in, compose.FieldPath{ConcurrentSizeKey})
 | 
						|
	if !ok {
 | 
						|
		return nil, fmt.Errorf("incoming concurrency not present in input: %s", in)
 | 
						|
	}
 | 
						|
 | 
						|
	concurrency = concurrencyAny.(int64)
 | 
						|
	if concurrency == 0 {
 | 
						|
		concurrency = 10
 | 
						|
	}
 | 
						|
 | 
						|
	if minLen > int(maxIter) {
 | 
						|
		minLen = int(maxIter)
 | 
						|
	}
 | 
						|
 | 
						|
	output := b.initOutput(minLen)
 | 
						|
	if minLen == 0 {
 | 
						|
		return output, nil
 | 
						|
	}
 | 
						|
 | 
						|
	getIthInput := func(i int) (map[string]any, map[string]any, error) {
 | 
						|
		input := make(map[string]any)
 | 
						|
 | 
						|
		for k, v := range in { // carry over other values
 | 
						|
			if k != MaxBatchSizeKey && k != ConcurrentSizeKey {
 | 
						|
				input[k] = v
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		input[string(b.key)+"#index"] = int64(i)
 | 
						|
 | 
						|
		items := make(map[string]any)
 | 
						|
		for arrayKey, array := range arrays {
 | 
						|
			ele := reflect.ValueOf(array).Index(i).Interface()
 | 
						|
			items[arrayKey] = []any{ele}
 | 
						|
			currentKey := string(b.key) + "#" + arrayKey
 | 
						|
 | 
						|
			// Recursively expand map[string]any elements
 | 
						|
			var expand func(prefix string, val interface{})
 | 
						|
			expand = func(prefix string, val interface{}) {
 | 
						|
				input[prefix] = val
 | 
						|
				if nestedMap, ok := val.(map[string]any); ok {
 | 
						|
					for k, v := range nestedMap {
 | 
						|
						expand(prefix+"#"+k, v)
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
			expand(currentKey, ele)
 | 
						|
		}
 | 
						|
 | 
						|
		return input, items, nil
 | 
						|
	}
 | 
						|
 | 
						|
	setIthOutput := func(i int, taskOutput map[string]any) error {
 | 
						|
		for k, source := range b.outputs {
 | 
						|
			fromValue, _ := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)},
 | 
						|
				source.Ref.FromPath...))
 | 
						|
 | 
						|
			toArray, ok := nodes.TakeMapValue(output, compose.FieldPath{k})
 | 
						|
			if !ok {
 | 
						|
				return fmt.Errorf("key not present in outer workflow's output: %s", k)
 | 
						|
			}
 | 
						|
 | 
						|
			toArray.([]any)[i] = fromValue
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
 | 
						|
	var existingCState *nodes.NestedWorkflowState
 | 
						|
	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
 | 
						|
		var e error
 | 
						|
		existingCState, _, e = getter.GetNestedWorkflowState(b.key)
 | 
						|
		if e != nil {
 | 
						|
			return e
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if existingCState != nil {
 | 
						|
		output = existingCState.FullOutput
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, cancelFn := context.WithCancelCause(ctx)
 | 
						|
	var (
 | 
						|
		wg                  sync.WaitGroup
 | 
						|
		mu                  sync.Mutex
 | 
						|
		index2Done          = map[int]bool{}
 | 
						|
		index2InterruptInfo = map[int]*compose.InterruptInfo{}
 | 
						|
		resumed             = map[int]bool{}
 | 
						|
	)
 | 
						|
 | 
						|
	ithTask := func(i int) {
 | 
						|
		defer wg.Done()
 | 
						|
 | 
						|
		if existingCState != nil {
 | 
						|
			if existingCState.Index2Done[i] == true {
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			if existingCState.Index2InterruptInfo[i] != nil {
 | 
						|
				if len(options.GetResumeIndexes()) > 0 {
 | 
						|
					if _, ok := options.GetResumeIndexes()[i]; !ok {
 | 
						|
						// previously interrupted, but not resumed this time, skip
 | 
						|
						return
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			mu.Lock()
 | 
						|
			resumed[i] = true
 | 
						|
			mu.Unlock()
 | 
						|
		}
 | 
						|
 | 
						|
		select {
 | 
						|
		case <-ctx.Done():
 | 
						|
			return // canceled by normal error, abort
 | 
						|
		default:
 | 
						|
		}
 | 
						|
 | 
						|
		mu.Lock()
 | 
						|
		if len(index2InterruptInfo) > 0 { // already has interrupted index, abort
 | 
						|
			mu.Unlock()
 | 
						|
			return
 | 
						|
		}
 | 
						|
		mu.Unlock()
 | 
						|
 | 
						|
		input, items, err := getIthInput(i)
 | 
						|
		if err != nil {
 | 
						|
			cancelFn(err)
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		subCtx, subCheckpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)
 | 
						|
 | 
						|
		ithOpts := slices.Clone(options.GetOptsForNested())
 | 
						|
		mu.Lock()
 | 
						|
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
 | 
						|
		mu.Unlock()
 | 
						|
		if subCheckpointID != "" {
 | 
						|
			logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
 | 
						|
				i, b.key, subCheckpointID)
 | 
						|
			ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
 | 
						|
		}
 | 
						|
 | 
						|
		mu.Lock()
 | 
						|
		if len(options.GetResumeIndexes()) > 0 {
 | 
						|
			stateModifier, ok := options.GetResumeIndexes()[i]
 | 
						|
			mu.Unlock()
 | 
						|
			if ok {
 | 
						|
				fmt.Println("has state modifier for ith run: ", i, ", checkpointID: ", subCheckpointID)
 | 
						|
				ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			mu.Unlock()
 | 
						|
		}
 | 
						|
 | 
						|
		// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
 | 
						|
		// the output then needs to be concatenated.
 | 
						|
		taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
 | 
						|
		if err != nil {
 | 
						|
			info, ok := compose.ExtractInterruptInfo(err)
 | 
						|
			if !ok {
 | 
						|
				cancelFn(err)
 | 
						|
				return
 | 
						|
			}
 | 
						|
 | 
						|
			mu.Lock()
 | 
						|
			index2InterruptInfo[i] = info
 | 
						|
			mu.Unlock()
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		if err = setIthOutput(i, taskOutput); err != nil {
 | 
						|
			cancelFn(err)
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		mu.Lock()
 | 
						|
		index2Done[i] = true
 | 
						|
		mu.Unlock()
 | 
						|
	}
 | 
						|
 | 
						|
	wg.Add(minLen)
 | 
						|
	if minLen < int(concurrency) {
 | 
						|
		for i := 1; i < minLen; i++ {
 | 
						|
			go ithTask(i)
 | 
						|
		}
 | 
						|
		ithTask(0)
 | 
						|
	} else {
 | 
						|
		taskChan := make(chan int, concurrency)
 | 
						|
		for i := 0; i < int(concurrency); i++ {
 | 
						|
			safego.Go(ctx, func() {
 | 
						|
				for i := range taskChan {
 | 
						|
					ithTask(i)
 | 
						|
				}
 | 
						|
			})
 | 
						|
		}
 | 
						|
		for i := 0; i < minLen; i++ {
 | 
						|
			taskChan <- i
 | 
						|
		}
 | 
						|
		close(taskChan)
 | 
						|
	}
 | 
						|
 | 
						|
	wg.Wait()
 | 
						|
 | 
						|
	if context.Cause(ctx) != nil {
 | 
						|
		if errors.Is(context.Cause(ctx), context.Canceled) {
 | 
						|
			return nil, context.Canceled // canceled by Eino workflow engine
 | 
						|
		}
 | 
						|
		return nil, context.Cause(ctx) // normal error, just throw it out
 | 
						|
	}
 | 
						|
 | 
						|
	// delete the interruptions that have been resumed
 | 
						|
	for index := range resumed {
 | 
						|
		delete(existingCState.Index2InterruptInfo, index)
 | 
						|
	}
 | 
						|
 | 
						|
	compState := existingCState
 | 
						|
	if compState == nil {
 | 
						|
		compState = &nodes.NestedWorkflowState{
 | 
						|
			Index2Done:          index2Done,
 | 
						|
			Index2InterruptInfo: index2InterruptInfo,
 | 
						|
			FullOutput:          output,
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		for i := range index2Done {
 | 
						|
			compState.Index2Done[i] = index2Done[i]
 | 
						|
		}
 | 
						|
		for i := range index2InterruptInfo {
 | 
						|
			compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
 | 
						|
		}
 | 
						|
		compState.FullOutput = output
 | 
						|
	}
 | 
						|
 | 
						|
	if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
 | 
						|
		iEvent := &entity.InterruptEvent{
 | 
						|
			NodeKey:             b.key,
 | 
						|
			NodeType:            entity.NodeTypeBatch,
 | 
						|
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
 | 
						|
		}
 | 
						|
 | 
						|
		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
 | 
						|
			if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
 | 
						|
				return e
 | 
						|
			}
 | 
						|
 | 
						|
			return setter.SetInterruptEvent(b.key, iEvent)
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		fmt.Println("save interruptEvent in state within batch: ", iEvent)
 | 
						|
		fmt.Println("save composite info in state within batch: ", compState)
 | 
						|
 | 
						|
		return nil, compose.InterruptAndRerun
 | 
						|
	} else {
 | 
						|
		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
 | 
						|
			if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
 | 
						|
				return e
 | 
						|
			}
 | 
						|
 | 
						|
			if existingCState == nil {
 | 
						|
				return nil
 | 
						|
			}
 | 
						|
 | 
						|
			// although this invocation does not have new interruptions,
 | 
						|
			// this batch node previously have interrupts yet to be resumed.
 | 
						|
			// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
 | 
						|
			return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
 | 
						|
				NodeKey:             b.key,
 | 
						|
				NodeType:            entity.NodeTypeBatch,
 | 
						|
				NestedInterruptInfo: existingCState.Index2InterruptInfo,
 | 
						|
			})
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		fmt.Println("save composite info in state within batch: ", compState)
 | 
						|
	}
 | 
						|
 | 
						|
	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
 | 
						|
		logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
 | 
						|
			"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
 | 
						|
		return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
 | 
						|
	}
 | 
						|
 | 
						|
	return output, nil
 | 
						|
}
 | 
						|
 | 
						|
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
 | 
						|
	trimmed := make(map[string]any, len(b.inputArrays))
 | 
						|
	for _, arrayKey := range b.inputArrays {
 | 
						|
		if v, ok := in[arrayKey]; ok {
 | 
						|
			trimmed[arrayKey] = v
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return trimmed, nil
 | 
						|
}
 |