refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -30,50 +30,108 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type Batch struct {
|
||||
config *Config
|
||||
outputs map[string]*vo.FieldSource
|
||||
outputs map[string]*vo.FieldSource
|
||||
innerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
key vo.NodeKey
|
||||
inputArrays []string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
|
||||
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
type Config struct{}
|
||||
|
||||
InputArrays []string `json:"input_arrays"`
|
||||
Outputs []*vo.FieldInfo `json:"outputs"`
|
||||
}
|
||||
|
||||
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if len(config.InputArrays) == 0 {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeBatch,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
|
||||
compose.FieldPath{MaxBatchSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(batchSizeField...)
|
||||
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
|
||||
compose.FieldPath{ConcurrentSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(concurrentSizeField...)
|
||||
|
||||
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
|
||||
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
var inputArrays []string
|
||||
for key, tInfo := range ns.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
inputArrays = append(inputArrays, key)
|
||||
}
|
||||
|
||||
if len(inputArrays) == 0 {
|
||||
return nil, errors.New("need to have at least one incoming array for batch")
|
||||
}
|
||||
|
||||
if len(config.Outputs) == 0 {
|
||||
if len(ns.OutputSources) == 0 {
|
||||
return nil, errors.New("need to have at least one output variable for batch")
|
||||
}
|
||||
|
||||
b := &Batch{
|
||||
config: config,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
bo := schema.GetBuildOptions(opts...)
|
||||
if bo.Inner == nil {
|
||||
return nil, errors.New("need to have inner workflow for batch")
|
||||
}
|
||||
|
||||
for i := range config.Outputs {
|
||||
source := config.Outputs[i]
|
||||
b := &Batch{
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
innerWorkflow: bo.Inner,
|
||||
key: ns.Key,
|
||||
inputArrays: inputArrays,
|
||||
}
|
||||
|
||||
for i := range ns.OutputSources {
|
||||
source := ns.OutputSources[i]
|
||||
path := source.Path
|
||||
if len(path) != 1 {
|
||||
return nil, fmt.Errorf("invalid path %q", path)
|
||||
}
|
||||
|
||||
// from which inner node's which field does the batch's output fields come from
|
||||
b.outputs[path[0]] = &source.Source
|
||||
}
|
||||
|
||||
@@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
|
||||
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
|
||||
out map[string]any, err error) {
|
||||
arrays := make(map[string]any, len(b.config.InputArrays))
|
||||
arrays := make(map[string]any, len(b.inputArrays))
|
||||
minLen := math.MaxInt64
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
@@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
}
|
||||
}
|
||||
|
||||
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
|
||||
input[string(b.key)+"#index"] = int64(i)
|
||||
|
||||
items := make(map[string]any)
|
||||
for arrayKey, array := range arrays {
|
||||
ele := reflect.ValueOf(array).Index(i).Interface()
|
||||
items[arrayKey] = []any{ele}
|
||||
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
|
||||
currentKey := string(b.key) + "#" + arrayKey
|
||||
|
||||
// Recursively expand map[string]any elements
|
||||
var expand func(prefix string, val interface{})
|
||||
@@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
return nil
|
||||
}
|
||||
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
var existingCState *nodes.NestedWorkflowState
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
|
||||
var e error
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.key)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
@@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
mu.Unlock()
|
||||
if subCheckpointID != "" {
|
||||
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
|
||||
i, b.config.BatchNodeKey, subCheckpointID)
|
||||
i, b.key, subCheckpointID)
|
||||
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
|
||||
}
|
||||
|
||||
@@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
|
||||
// the output then needs to be concatenated.
|
||||
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
if err != nil {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
@@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
|
||||
}
|
||||
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
|
||||
return setter.SetInterruptEvent(b.key, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
// although this invocation does not have new interruptions,
|
||||
// this batch node previously have interrupts yet to be resumed.
|
||||
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: existingCState.Index2InterruptInfo,
|
||||
})
|
||||
@@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
|
||||
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
|
||||
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
|
||||
}
|
||||
|
||||
@@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
}
|
||||
|
||||
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(b.config.InputArrays))
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
trimmed := make(map[string]any, len(b.inputArrays))
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user