refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -32,8 +32,11 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -48,24 +51,147 @@ const (
type Config struct {
MergeStrategy MergeStrategy
GroupLen map[string]int
FullSources map[string]*nodes.SourceInfo
NodeKey vo.NodeKey
InputSources []*vo.FieldInfo
GroupOrder []string // the order the groups are declared in frontend canvas
}
type VariableAggregator struct {
config *Config
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeVariableAggregator,
Name: n.Data.Meta.Title,
Configs: c,
}
c.MergeStrategy = FirstNotNullValue
inputs := n.Data.Inputs
groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
for i := range inputs.VariableAggregator.MergeGroups {
group := inputs.VariableAggregator.MergeGroups[i]
tInfo := &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: make(map[string]*vo.TypeInfo),
}
ns.SetInputType(group.Name, tInfo)
for ii, v := range group.Variables {
name := strconv.Itoa(ii)
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
if err != nil {
return nil, err
}
tInfo.Properties[name] = valueTypeInfo
sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
length := len(group.Variables)
groupToLen[group.Name] = length
}
groupOrder := make([]string, 0, len(groupToLen))
for i := range inputs.VariableAggregator.MergeGroups {
group := inputs.VariableAggregator.MergeGroups[i]
groupOrder = append(groupOrder, group.Name)
}
c.GroupLen = groupToLen
c.GroupOrder = groupOrder
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if c.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
}
if cfg.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
return &VariableAggregator{
groupLen: c.GroupLen,
fullSources: ns.FullSources,
nodeKey: ns.Key,
}, nil
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
}
return &VariableAggregator{config: cfg}, nil
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range ns.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return schema2.FieldNotStream, nil // variables or static values
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode := sc.GetNode(fromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range ns.InputSources {
if fInfo.Path[0] == path[0] { // belong to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil {
return schema2.FieldNotStream, err
}
if subStreamType == schema2.FieldMaybeStream {
return schema2.FieldMaybeStream, nil
} else if subStreamType == schema2.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return schema2.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return schema2.FieldNotStream, nil
}
return schema2.FieldMaybeStream, nil
}
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
}
type VariableAggregator struct {
groupLen map[string]int
fullSources map[string]*schema2.SourceInfo
nodeKey vo.NodeKey
groupOrder []string // the order the groups are declared in frontend canvas
}
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
@@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
result := make(map[string]any)
groupToChoice := make(map[string]int)
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
for i := 0; i < length; i++ {
if value, ok := in[group][i]; ok {
if value != nil {
@@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
_ *schema.StreamReader[map[string]any], err error) {
inStream := streamInputConverter(input)
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolvesSources from ctx cache.")
}
@@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
defer func() {
if err == nil {
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
for group, choice := range groupToChoice {
if choice != -1 {
item := groupToItems[group][choice]
if _, ok := item.(stream); ok {
groupChoiceToStreamType[group] = nodes.FieldIsStream
groupChoiceToStreamType[group] = schema2.FieldIsStream
}
}
}
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
// - if an element is not stream, actually receive from the stream to check if it's non-nil
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
groupToItems[group] = make([]any, length)
groupToCurrentIndex[group] = math.MaxInt
for i := 0; i < length; i++ {
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
if fType == nodes.FieldSkipped {
if fType == schema2.FieldSkipped {
groupToItems[group][i] = skipped{}
continue
}
if fType == nodes.FieldIsStream {
if fType == schema2.FieldIsStream {
groupToItems[group][i] = stream{}
if ci, _ := groupToCurrentIndex[group]; i < ci {
groupToCurrentIndex[group] = i
@@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
allDone := func() bool {
for group := range v.config.GroupLen {
for group := range v.groupLen {
_, ok := groupToChoice[group]
if !ok {
return false
@@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
alreadyDone := allDone()
if alreadyDone { // all groups have made their choices, no need to actually read input streams
result := make(map[string]any, len(v.config.GroupLen))
result := make(map[string]any, len(v.groupLen))
allSkip := true
for group := range groupToChoice {
choice := groupToChoice[group]
@@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
if allSkip { // no need to convert input streams for the output, because all groups are skipped
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
return schema.StreamReaderFromArray([]map[string]any{result}), nil
@@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
@@ -416,26 +542,12 @@ type vaCallbackInput struct {
Variables []any `json:"variables"`
}
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
ctx = ctxcache.Init(ctx)
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}
type streamMarkerType string
const streamMarker streamMarkerType = "<Stream Data...>"
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolved_sources from ctx cache")
}
@@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
merged := make([]vaCallbackInput, 0, len(in))
groupLen := v.config.GroupLen
groupLen := v.groupLen
for groupName, vars := range in {
orderedVars := make([]any, groupLen[groupName])
for index := range vars {
orderedVars[index] = vars[index]
if len(resolvedSources) > 0 {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
// replace the streams with streamMarker,
// because we won't read, save to execution history, or display these streams to user
orderedVars[index] = streamMarker
@@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
}
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
if !ok {
panic("unable to get dynamic stream types from ctx cache")
}
@@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
newOut := maps.Clone(output)
for k := range output {
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
newOut[k] = streamMarker
}
}
@@ -594,3 +706,15 @@ func init() {
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
}
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}