refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -32,8 +32,11 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
@@ -48,24 +51,147 @@ const (
|
||||
type Config struct {
|
||||
MergeStrategy MergeStrategy
|
||||
GroupLen map[string]int
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
NodeKey vo.NodeKey
|
||||
InputSources []*vo.FieldInfo
|
||||
GroupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
config *Config
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
c.MergeStrategy = FirstNotNullValue
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
tInfo := &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: make(map[string]*vo.TypeInfo),
|
||||
}
|
||||
ns.SetInputType(group.Name, tInfo)
|
||||
for ii, v := range group.Variables {
|
||||
name := strconv.Itoa(ii)
|
||||
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[name] = valueTypeInfo
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
length := len(group.Variables)
|
||||
groupToLen[group.Name] = length
|
||||
}
|
||||
|
||||
groupOrder := make([]string, 0, len(groupToLen))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
groupOrder = append(groupOrder, group.Name)
|
||||
}
|
||||
|
||||
c.GroupLen = groupToLen
|
||||
c.GroupOrder = groupOrder
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if c.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
|
||||
}
|
||||
if cfg.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
|
||||
|
||||
return &VariableAggregator{
|
||||
groupLen: c.GroupLen,
|
||||
fullSources: ns.FullSources,
|
||||
nodeKey: ns.Key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
return &VariableAggregator{config: cfg}, nil
|
||||
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return schema2.FieldNotStream, nil // variables or static values
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode := sc.GetNode(fromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
if err != nil {
|
||||
return schema2.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == schema2.FieldMaybeStream {
|
||||
return schema2.FieldMaybeStream, nil
|
||||
} else if subStreamType == schema2.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldMaybeStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
groupLen map[string]int
|
||||
fullSources map[string]*schema2.SourceInfo
|
||||
nodeKey vo.NodeKey
|
||||
groupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
|
||||
@@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
||||
|
||||
result := make(map[string]any)
|
||||
groupToChoice := make(map[string]int)
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
for i := 0; i < length; i++ {
|
||||
if value, ok := in[group][i]; ok {
|
||||
if value != nil {
|
||||
@@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
@@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
_ *schema.StreamReader[map[string]any], err error) {
|
||||
inStream := streamInputConverter(input)
|
||||
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolvesSources from ctx cache.")
|
||||
}
|
||||
@@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
|
||||
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
|
||||
for group, choice := range groupToChoice {
|
||||
if choice != -1 {
|
||||
item := groupToItems[group][choice]
|
||||
if _, ok := item.(stream); ok {
|
||||
groupChoiceToStreamType[group] = nodes.FieldIsStream
|
||||
groupChoiceToStreamType[group] = schema2.FieldIsStream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
@@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
// - if an element is not stream, actually receive from the stream to check if it's non-nil
|
||||
|
||||
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
groupToItems[group] = make([]any, length)
|
||||
groupToCurrentIndex[group] = math.MaxInt
|
||||
for i := 0; i < length; i++ {
|
||||
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
|
||||
if fType == nodes.FieldSkipped {
|
||||
if fType == schema2.FieldSkipped {
|
||||
groupToItems[group][i] = skipped{}
|
||||
continue
|
||||
}
|
||||
if fType == nodes.FieldIsStream {
|
||||
if fType == schema2.FieldIsStream {
|
||||
groupToItems[group][i] = stream{}
|
||||
if ci, _ := groupToCurrentIndex[group]; i < ci {
|
||||
groupToCurrentIndex[group] = i
|
||||
@@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
}
|
||||
|
||||
allDone := func() bool {
|
||||
for group := range v.config.GroupLen {
|
||||
for group := range v.groupLen {
|
||||
_, ok := groupToChoice[group]
|
||||
if !ok {
|
||||
return false
|
||||
@@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
alreadyDone := allDone()
|
||||
if alreadyDone { // all groups have made their choices, no need to actually read input streams
|
||||
result := make(map[string]any, len(v.config.GroupLen))
|
||||
result := make(map[string]any, len(v.groupLen))
|
||||
allSkip := true
|
||||
for group := range groupToChoice {
|
||||
choice := groupToChoice[group]
|
||||
@@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
if allSkip { // no need to convert input streams for the output, because all groups are skipped
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
return schema.StreamReaderFromArray([]map[string]any{result}), nil
|
||||
@@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -416,26 +542,12 @@ type vaCallbackInput struct {
|
||||
Variables []any `json:"variables"`
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
type streamMarkerType string
|
||||
|
||||
const streamMarker streamMarkerType = "<Stream Data...>"
|
||||
|
||||
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolved_sources from ctx cache")
|
||||
}
|
||||
@@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
||||
|
||||
merged := make([]vaCallbackInput, 0, len(in))
|
||||
|
||||
groupLen := v.config.GroupLen
|
||||
groupLen := v.groupLen
|
||||
|
||||
for groupName, vars := range in {
|
||||
orderedVars := make([]any, groupLen[groupName])
|
||||
for index := range vars {
|
||||
orderedVars[index] = vars[index]
|
||||
if len(resolvedSources) > 0 {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
|
||||
// replace the streams with streamMarker,
|
||||
// because we won't read, save to execution history, or display these streams to user
|
||||
orderedVars[index] = streamMarker
|
||||
@@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get dynamic stream types from ctx cache")
|
||||
}
|
||||
@@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
|
||||
|
||||
newOut := maps.Clone(output)
|
||||
for k := range output {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
|
||||
newOut[k] = streamMarker
|
||||
}
|
||||
}
|
||||
@@ -594,3 +706,15 @@ func init() {
|
||||
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
|
||||
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user