refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -29,6 +29,8 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
@@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
*workflow
hierarchy map[vo.NodeKey]vo.NodeKey
connections []*Connection
connections []*schema.Connection
requireCheckpoint bool
entry *compose.WorkflowNode
inner bool
@@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
input map[string]*vo.TypeInfo
output map[string]*vo.TypeInfo
terminatePlan vo.TerminatePlan
schema *WorkflowSchema
schema *schema.WorkflowSchema
}
type workflowOptions struct {
@@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
}
}
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
sc.Init()
wf := &Workflow{
@@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
schema: sc,
}
wf.streamRun = sc.requireStreaming
wf.requireCheckpoint = sc.requireCheckPoint
wf.streamRun = sc.RequireStreaming()
wf.requireCheckpoint = sc.RequireCheckpoint()
wfOpts := &workflowOptions{}
for _, opt := range opts {
@@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
processedNodeKey[child.Key] = struct{}{}
}
}
// add all nodes other than composite nodes and their children
for _, ns := range sc.Nodes {
if _, ok := processedNodeKey[ns.Key]; !ok {
@@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
}
if ns.Type == entity.NodeTypeExit {
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
}
}
@@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
}
fanInConfigs := sc.fanInMergeConfigs()
fanInConfigs := sc.FanInMergeConfigs()
if len(fanInConfigs) > 0 {
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
}
@@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
carryOvers map[vo.NodeKey][]*compose.FieldMapping
}
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
_, err := w.addNodeInternal(ctx, ns, nil)
return err
}
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
inner, err := w.getInnerWorkflow(ctx, cNode)
if err != nil {
return err
@@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
return err
}
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
return w.addNodeInternal(ctx, cNode, nil)
}
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
key := ns.Key
var deps *dependencyInfo
@@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
innerWorkflow = inner.inner
}
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
if err != nil {
return nil, err
}
@@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
var opts []compose.GraphAddNodeOpt
opts = append(opts, compose.WithNodeName(string(ns.Key)))
preHandler := ns.StatePreHandler(w.streamRun)
preHandler := statePreHandler(ns, w.streamRun)
if preHandler != nil {
opts = append(opts, preHandler)
}
postHandler := ns.StatePostHandler(w.streamRun)
postHandler := statePostHandler(ns, w.streamRun)
if postHandler != nil {
opts = append(opts, postHandler)
}
@@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
w.entry = wNode
}
outputPortCount, hasExceptionPort := ns.OutputPortCount()
if outputPortCount > 1 || hasExceptionPort {
bMapping, err := w.resolveBranch(key, outputPortCount)
if err != nil {
return nil, err
}
b := w.schema.GetBranch(ns.Key)
if b != nil {
if b.OnlyException() {
_ = w.AddBranch(string(key), b.GetExceptionBranch())
} else {
bb, ok := ns.Configs.(schema.BranchBuilder)
if !ok {
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
}
branch, err := ns.GetBranch(bMapping)
if err != nil {
return nil, err
}
br, err := b.GetFullBranch(ctx, bb)
if err != nil {
return nil, err
}
_ = w.AddBranch(string(key), branch)
_ = w.AddBranch(string(key), br)
}
}
return deps.inputsForParent, nil
@@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
return w.workflow.Compile(ctx, opts...)
}
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*NodeSchema)
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
for _, n := range cNode.Children {
innerNodes[n.Key] = n
}
// trim the connections, only keep the connections that are related to the inner workflow
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
innerConnections := make([]*Connection, 0)
innerConnections := make([]*schema.Connection, 0)
for i := range w.schema.Connections {
conn := w.schema.Connections[i]
if _, ok := innerNodes[conn.FromNode]; ok {
@@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
// we will not try to drill, instead, just take value using a.b.c.
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
for nKey, fms := range d.inputs {
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
continue
@@ -638,55 +643,6 @@ type variableInfo struct {
toPath compose.FieldPath
}
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
m := make([]map[string]bool, portCount)
var exception map[string]bool
for _, conn := range w.connections {
if conn.FromNode != n {
continue
}
if conn.FromPort == nil {
continue
}
if *conn.FromPort == "default" { // default condition
if m[portCount-1] == nil {
m[portCount-1] = make(map[string]bool)
}
m[portCount-1][string(conn.ToNode)] = true
} else if *conn.FromPort == "branch_error" {
if exception == nil {
exception = make(map[string]bool)
}
exception[string(conn.ToNode)] = true
} else {
if !strings.HasPrefix(*conn.FromPort, "branch_") {
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
}
index := (*conn.FromPort)[7:]
i, err := strconv.Atoi(index)
if err != nil {
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
}
if i < 0 || i >= portCount {
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
}
if m[i] == nil {
m[i] = make(map[string]bool)
}
m[i][string(conn.ToNode)] = true
}
}
return &BranchMapping{
Normal: m,
Exception: exception,
}, nil
}
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
var (
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
@@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
@@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
if inputFull == nil {
@@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
}
}
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
firstNodesInInnerWorkflow := true
for _, conn := range connMap {
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
firstNodesInInnerWorkflow = false
break
@@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
fromNodeKey = compose.START
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
continue
}
@@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
variableInfos []*variableInfo
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
}
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
continue
}
@@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
swp.Source.Ref.FromPath, swp.Path)
}
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
} else { // indirect dependency