coze-studio/backend/domain/workflow/internal/nodes/batch/batch.go

443 lines
12 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package batch
import (
"context"
"errors"
"fmt"
"math"
"reflect"
"slices"
"sync"
"github.com/cloudwego/eino/compose"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// Batch is a workflow node that invokes an inner workflow once per index of its
// incoming arrays and gathers the per-index results into output arrays.
type Batch struct {
	// config is the configuration handed to NewBatch.
	config *Config

	// outputs maps each output variable name (the single path segment of a
	// configured output) to the source its per-index value is taken from.
	outputs map[string]*vo.FieldSource
}
// Config carries everything NewBatch needs to build a Batch node.
type Config struct {
	// BatchNodeKey identifies this batch node. It keys the nested workflow state
	// and prefixes the per-index entries injected into the inner workflow input.
	BatchNodeKey vo.NodeKey `json:"batch_node_key"`

	// InnerWorkflow is the runnable invoked once for every index.
	InnerWorkflow compose.Runnable[map[string]any, map[string]any]

	// InputArrays lists the input keys whose values are the arrays to iterate
	// over. At least one is required.
	InputArrays []string `json:"input_arrays"`

	// Outputs describes the output variables. At least one is required and each
	// must have a path of exactly one segment.
	Outputs []*vo.FieldInfo `json:"outputs"`
}
// NewBatch validates config and builds a Batch node from it.
//
// It returns an error when config is nil, when no input array or no output is
// configured, or when an output path does not consist of exactly one segment.
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is required")
	case len(config.InputArrays) == 0:
		return nil, errors.New("need to have at least one incoming array for batch")
	case len(config.Outputs) == 0:
		return nil, errors.New("need to have at least one output variable for batch")
	}

	// Index every output by its single path segment.
	sources := make(map[string]*vo.FieldSource)
	for _, info := range config.Outputs {
		if len(info.Path) != 1 {
			return nil, fmt.Errorf("invalid path %q", info.Path)
		}
		sources[info.Path[0]] = &info.Source
	}

	return &Batch{config: config, outputs: sources}, nil
}
// MaxBatchSizeKey is the input key holding the maximum number of iterations a
// batch run may perform.
const MaxBatchSizeKey = "batchSize"

// ConcurrentSizeKey is the input key holding the concurrency setting of a batch
// run.
const ConcurrentSizeKey = "concurrentSize"
// initOutput builds the initial output map: one []any of the given length per
// configured output variable, with every element nil until the corresponding
// index has produced a result.
func (b *Batch) initOutput(length int) map[string]any {
	out := make(map[string]any, len(b.outputs))
	for key := range b.outputs {
		// A plain make yields the same []any value the reflection-based
		// construction produced, without the reflect overhead.
		out[key] = make([]any, length)
	}
	return out
}
// Execute runs the inner workflow once for every index of the incoming arrays
// and gathers the per-index results into one output array per output variable.
//
// Input contract (all read from in):
//   - every key in Config.InputArrays must hold a non-nil slice; the number of
//     iterations is the length of the shortest one, capped by the batch size;
//   - MaxBatchSizeKey and ConcurrentSizeKey must hold int64 values; values <= 0
//     fall back to 100 and 10 respectively.
//
// When the number of iterations is below the concurrency setting, every index
// gets its own goroutine (index 0 runs on the calling goroutine); otherwise a
// pool of `concurrency` workers drains the indexes.
//
// Interrupt / resume: if an index run is interrupted, the progress so far is
// saved as nested workflow state under the batch node key and
// compose.InterruptAndRerun is returned. On a later call the saved state is
// loaded; indexes already done are skipped, and previously interrupted indexes
// are only re-run when they are among the resume indexes (if any are given).
//
// A non-interrupt error from any index cancels the remaining work and is
// returned. If the surrounding context was canceled, context.Canceled is
// returned.
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
	out map[string]any, err error) {
	arrays := make(map[string]any, len(b.config.InputArrays))
	minLen := math.MaxInt
	for _, arrayKey := range b.config.InputArrays {
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
		if !ok {
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
		}

		// reflect.TypeOf(nil) is a nil Type, so Kind() must not be called on it.
		if a == nil || reflect.TypeOf(a).Kind() != reflect.Slice {
			return nil, fmt.Errorf("incoming array not a slice: %s. Actual type: %v",
				arrayKey, reflect.TypeOf(a))
		}

		arrays[arrayKey] = a
		if oneLen := reflect.ValueOf(a).Len(); oneLen < minLen {
			minLen = oneLen
		}
	}

	// takeLimit reads a required int64 setting from the input. Non-positive
	// values select the fallback, so a negative setting can never reach slice or
	// channel allocation.
	takeLimit := func(key, label string, fallback int64) (int64, error) {
		raw, ok := nodes.TakeMapValue(in, compose.FieldPath{key})
		if !ok {
			return 0, fmt.Errorf("incoming %s not present in input: %v", label, in)
		}
		limit, ok := raw.(int64)
		if !ok {
			return 0, fmt.Errorf("incoming %s is not an int64: %v (%T)", label, raw, raw)
		}
		if limit <= 0 {
			return fallback, nil
		}
		return limit, nil
	}

	maxIter, err := takeLimit(MaxBatchSizeKey, "max iteration", 100)
	if err != nil {
		return nil, err
	}
	concurrency, err := takeLimit(ConcurrentSizeKey, "concurrency", 10)
	if err != nil {
		return nil, err
	}

	// Compare as int64 so a large limit is never truncated on 32-bit platforms.
	if int64(minLen) > maxIter {
		minLen = int(maxIter)
	}

	output := b.initOutput(minLen)
	if minLen == 0 {
		return output, nil
	}

	options := &nodes.NestedWorkflowOptions{}
	for _, opt := range opts {
		opt(options)
	}

	// Load the state saved by a previous, interrupted invocation (if any).
	var existingCState *nodes.NestedWorkflowState
	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
		var e error
		existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
		return e
	})
	if err != nil {
		return nil, err
	}
	if existingCState != nil {
		// Continue filling the output gathered before the interruption.
		output = existingCState.FullOutput
	}

	nodeKey := string(b.config.BatchNodeKey)

	// expand stores val under prefix and, when val is a map[string]any, stores
	// every nested entry under "prefix#key" recursively.
	var expand func(dst map[string]any, prefix string, val any)
	expand = func(dst map[string]any, prefix string, val any) {
		dst[prefix] = val
		if nested, ok := val.(map[string]any); ok {
			for k, v := range nested {
				expand(dst, prefix+"#"+k, v)
			}
		}
	}

	// getIthInput builds the inner workflow input for index i, plus the map of
	// the i-th element of every array (each wrapped in a one-element slice).
	getIthInput := func(i int) (map[string]any, map[string]any) {
		input := make(map[string]any, len(in))
		for k, v := range in { // carry over other values
			if k != MaxBatchSizeKey && k != ConcurrentSizeKey {
				input[k] = v
			}
		}
		input[nodeKey+"#index"] = int64(i)

		items := make(map[string]any, len(arrays))
		for arrayKey, array := range arrays {
			ele := reflect.ValueOf(array).Index(i).Interface()
			items[arrayKey] = []any{ele}
			expand(input, nodeKey+"#"+arrayKey, ele)
		}
		return input, items
	}

	// setIthOutput copies the values referenced by the configured outputs from
	// one index run's result into slot i of the matching output arrays.
	setIthOutput := func(i int, taskOutput map[string]any) error {
		for k, source := range b.outputs {
			fromValue, _ := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)},
				source.Ref.FromPath...))

			toArray, ok := nodes.TakeMapValue(output, compose.FieldPath{k})
			if !ok {
				return fmt.Errorf("key not present in outer workflow's output: %s", k)
			}

			// The output may come from restored state, so verify its shape
			// instead of panicking on a bad assertion or index.
			slots, ok := toArray.([]any)
			if !ok || i >= len(slots) {
				return fmt.Errorf("output %s cannot hold the result of index %d: %T", k, i, toArray)
			}
			slots[i] = fromValue
		}
		return nil
	}

	ctx, cancelFn := context.WithCancelCause(ctx)
	// Release the derived context on every return path. Before this point is
	// reached, the cause is only set by a failing index run or by the parent.
	defer cancelFn(nil)

	var (
		wg                  sync.WaitGroup
		mu                  sync.Mutex
		index2Done          = map[int]bool{}
		index2InterruptInfo = map[int]*compose.InterruptInfo{}
		resumed             = map[int]bool{}
	)

	ithTask := func(i int) {
		defer wg.Done()

		if existingCState != nil {
			if existingCState.Index2Done[i] {
				return
			}

			if existingCState.Index2InterruptInfo[i] != nil && len(options.GetResumeIndexes()) > 0 {
				if _, ok := options.GetResumeIndexes()[i]; !ok {
					// previously interrupted, but not resumed this time, skip
					return
				}
			}

			mu.Lock()
			resumed[i] = true
			mu.Unlock()
		}

		select {
		case <-ctx.Done():
			return // canceled by normal error, abort
		default:
		}

		mu.Lock()
		alreadyInterrupted := len(index2InterruptInfo) > 0
		mu.Unlock()
		if alreadyInterrupted { // already has interrupted index, abort
			return
		}

		input, items := getIthInput(i)
		subCtx, subCheckpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)

		ithOpts := slices.Clone(options.GetOptsForNested())
		mu.Lock()
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
		// Looking up a missing key (or an empty map) simply reports false.
		stateModifier, hasModifier := options.GetResumeIndexes()[i]
		mu.Unlock()

		if subCheckpointID != "" {
			logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
				i, b.config.BatchNodeKey, subCheckpointID)
			ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
		}

		if hasModifier {
			logs.CtxInfof(ctx, "has state modifier for ith run: %d, checkpointID: %s", i, subCheckpointID)
			ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
		}

		// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
		// the output then needs to be concatenated.
		taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
		if err != nil {
			info, ok := compose.ExtractInterruptInfo(err)
			if !ok {
				cancelFn(err)
				return
			}

			mu.Lock()
			index2InterruptInfo[i] = info
			mu.Unlock()
			return
		}

		if err = setIthOutput(i, taskOutput); err != nil {
			cancelFn(err)
			return
		}

		mu.Lock()
		index2Done[i] = true
		mu.Unlock()
	}

	wg.Add(minLen)
	if int64(minLen) < concurrency {
		// Fewer indexes than allowed concurrency: one goroutine per index, with
		// index 0 on the calling goroutine.
		for i := 1; i < minLen; i++ {
			go ithTask(i)
		}
		ithTask(0)
	} else {
		// concurrency <= minLen here, so the conversion to int cannot truncate.
		workers := int(concurrency)
		taskChan := make(chan int, workers)
		for w := 0; w < workers; w++ {
			safego.Go(ctx, func() {
				for i := range taskChan {
					ithTask(i)
				}
			})
		}
		for i := 0; i < minLen; i++ {
			taskChan <- i
		}
		close(taskChan)
	}

	wg.Wait()

	if cause := context.Cause(ctx); cause != nil {
		if errors.Is(cause, context.Canceled) {
			return nil, context.Canceled // canceled by Eino workflow engine
		}
		return nil, cause // normal error, just throw it out
	}

	// delete the interruptions that have been resumed
	// (resumed is only ever filled when existingCState is non-nil)
	for index := range resumed {
		delete(existingCState.Index2InterruptInfo, index)
	}

	compState := existingCState
	if compState == nil {
		compState = &nodes.NestedWorkflowState{
			Index2Done:          index2Done,
			Index2InterruptInfo: index2InterruptInfo,
			FullOutput:          output,
		}
	} else {
		// Merge this invocation's progress into the restored state. A restored
		// map may be nil; adopt the fresh map instead of writing into nil.
		if compState.Index2Done == nil {
			compState.Index2Done = index2Done
		} else {
			for i, done := range index2Done {
				compState.Index2Done[i] = done
			}
		}
		if compState.Index2InterruptInfo == nil {
			compState.Index2InterruptInfo = index2InterruptInfo
		} else {
			for i, info := range index2InterruptInfo {
				compState.Index2InterruptInfo[i] = info
			}
		}
		compState.FullOutput = output
	}

	if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
		iEvent := &entity.InterruptEvent{
			NodeKey:             b.config.BatchNodeKey,
			NodeType:            entity.NodeTypeBatch,
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
		}

		err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
				return e
			}
			return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
		})
		if err != nil {
			return nil, err
		}

		logs.CtxInfof(ctx, "save interruptEvent in state within batch: %v", iEvent)
		logs.CtxInfof(ctx, "save composite info in state within batch: %v", compState)
		return nil, compose.InterruptAndRerun
	}

	err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
		if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
			return e
		}

		if existingCState == nil {
			return nil
		}

		// although this invocation does not have new interruptions,
		// this batch node previously have interrupts yet to be resumed.
		// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
		return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
			NodeKey:             b.config.BatchNodeKey,
			NodeType:            entity.NodeTypeBatch,
			NestedInterruptInfo: existingCState.Index2InterruptInfo,
		})
	})
	if err != nil {
		return nil, err
	}
	logs.CtxInfof(ctx, "save composite info in state within batch: %v", compState)

	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
		logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
			"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
		return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
	}

	return output, nil
}
// ToCallbackInput reduces the node input to what is reported to callbacks: only
// the entries whose keys are configured input arrays. Configured keys that are
// absent from in are omitted. The returned error is always nil.
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
	keys := b.config.InputArrays
	picked := make(map[string]any, len(keys))
	for i := range keys {
		value, present := in[keys[i]]
		if !present {
			continue
		}
		picked[keys[i]] = value
	}
	return picked, nil
}