refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -29,6 +29,8 @@ import (
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
@@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
|
||||
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
|
||||
*workflow
|
||||
hierarchy map[vo.NodeKey]vo.NodeKey
|
||||
connections []*Connection
|
||||
connections []*schema.Connection
|
||||
requireCheckpoint bool
|
||||
entry *compose.WorkflowNode
|
||||
inner bool
|
||||
@@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
|
||||
input map[string]*vo.TypeInfo
|
||||
output map[string]*vo.TypeInfo
|
||||
terminatePlan vo.TerminatePlan
|
||||
schema *WorkflowSchema
|
||||
schema *schema.WorkflowSchema
|
||||
}
|
||||
|
||||
type workflowOptions struct {
|
||||
@@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
sc.Init()
|
||||
|
||||
wf := &Workflow{
|
||||
@@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
schema: sc,
|
||||
}
|
||||
|
||||
wf.streamRun = sc.requireStreaming
|
||||
wf.requireCheckpoint = sc.requireCheckPoint
|
||||
wf.streamRun = sc.RequireStreaming()
|
||||
wf.requireCheckpoint = sc.RequireCheckpoint()
|
||||
|
||||
wfOpts := &workflowOptions{}
|
||||
for _, opt := range opts {
|
||||
@@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
processedNodeKey[child.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// add all nodes other than composite nodes and their children
|
||||
for _, ns := range sc.Nodes {
|
||||
if _, ok := processedNodeKey[ns.Key]; !ok {
|
||||
@@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
}
|
||||
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
|
||||
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
||||
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
|
||||
}
|
||||
|
||||
fanInConfigs := sc.fanInMergeConfigs()
|
||||
fanInConfigs := sc.FanInMergeConfigs()
|
||||
if len(fanInConfigs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
|
||||
}
|
||||
@@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
|
||||
carryOvers map[vo.NodeKey][]*compose.FieldMapping
|
||||
}
|
||||
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
|
||||
_, err := w.addNodeInternal(ctx, ns, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
|
||||
inner, err := w.getInnerWorkflow(ctx, cNode)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
return w.addNodeInternal(ctx, cNode, nil)
|
||||
}
|
||||
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
key := ns.Key
|
||||
var deps *dependencyInfo
|
||||
|
||||
@@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
innerWorkflow = inner.inner
|
||||
}
|
||||
|
||||
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
|
||||
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
var opts []compose.GraphAddNodeOpt
|
||||
opts = append(opts, compose.WithNodeName(string(ns.Key)))
|
||||
|
||||
preHandler := ns.StatePreHandler(w.streamRun)
|
||||
preHandler := statePreHandler(ns, w.streamRun)
|
||||
if preHandler != nil {
|
||||
opts = append(opts, preHandler)
|
||||
}
|
||||
|
||||
postHandler := ns.StatePostHandler(w.streamRun)
|
||||
postHandler := statePostHandler(ns, w.streamRun)
|
||||
if postHandler != nil {
|
||||
opts = append(opts, postHandler)
|
||||
}
|
||||
@@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
||||
w.entry = wNode
|
||||
}
|
||||
|
||||
outputPortCount, hasExceptionPort := ns.OutputPortCount()
|
||||
if outputPortCount > 1 || hasExceptionPort {
|
||||
bMapping, err := w.resolveBranch(key, outputPortCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := w.schema.GetBranch(ns.Key)
|
||||
if b != nil {
|
||||
if b.OnlyException() {
|
||||
_ = w.AddBranch(string(key), b.GetExceptionBranch())
|
||||
} else {
|
||||
bb, ok := ns.Configs.(schema.BranchBuilder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
|
||||
}
|
||||
|
||||
branch, err := ns.GetBranch(bMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
br, err := b.GetFullBranch(ctx, bb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = w.AddBranch(string(key), branch)
|
||||
_ = w.AddBranch(string(key), br)
|
||||
}
|
||||
}
|
||||
|
||||
return deps.inputsForParent, nil
|
||||
@@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
|
||||
return w.workflow.Compile(ctx, opts...)
|
||||
}
|
||||
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*NodeSchema)
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
|
||||
for _, n := range cNode.Children {
|
||||
innerNodes[n.Key] = n
|
||||
}
|
||||
|
||||
// trim the connections, only keep the connections that are related to the inner workflow
|
||||
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
|
||||
innerConnections := make([]*Connection, 0)
|
||||
innerConnections := make([]*schema.Connection, 0)
|
||||
for i := range w.schema.Connections {
|
||||
conn := w.schema.Connections[i]
|
||||
if _, ok := innerNodes[conn.FromNode]; ok {
|
||||
@@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
|
||||
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
|
||||
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
|
||||
// we will not try to drill, instead, just take value using a.b.c.
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
|
||||
for nKey, fms := range d.inputs {
|
||||
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
|
||||
continue
|
||||
@@ -638,55 +643,6 @@ type variableInfo struct {
|
||||
toPath compose.FieldPath
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
|
||||
m := make([]map[string]bool, portCount)
|
||||
var exception map[string]bool
|
||||
|
||||
for _, conn := range w.connections {
|
||||
if conn.FromNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if conn.FromPort == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if *conn.FromPort == "default" { // default condition
|
||||
if m[portCount-1] == nil {
|
||||
m[portCount-1] = make(map[string]bool)
|
||||
}
|
||||
m[portCount-1][string(conn.ToNode)] = true
|
||||
} else if *conn.FromPort == "branch_error" {
|
||||
if exception == nil {
|
||||
exception = make(map[string]bool)
|
||||
}
|
||||
exception[string(conn.ToNode)] = true
|
||||
} else {
|
||||
if !strings.HasPrefix(*conn.FromPort, "branch_") {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
|
||||
}
|
||||
|
||||
index := (*conn.FromPort)[7:]
|
||||
i, err := strconv.Atoi(index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
|
||||
}
|
||||
if i < 0 || i >= portCount {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
|
||||
}
|
||||
if m[i] == nil {
|
||||
m[i] = make(map[string]bool)
|
||||
}
|
||||
m[i][string(conn.ToNode)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return &BranchMapping{
|
||||
Normal: m,
|
||||
Exception: exception,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
|
||||
var (
|
||||
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
@@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
@@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
continue
|
||||
}
|
||||
|
||||
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
|
||||
if inputFull == nil {
|
||||
@@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
|
||||
}
|
||||
}
|
||||
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
firstNodesInInnerWorkflow := true
|
||||
for _, conn := range connMap {
|
||||
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
|
||||
firstNodesInInnerWorkflow = false
|
||||
break
|
||||
@@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
||||
continue
|
||||
}
|
||||
|
||||
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
fromNodeKey = compose.START
|
||||
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
||||
variableInfos []*variableInfo
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
||||
swp.Source.Ref.FromPath, swp.Path)
|
||||
}
|
||||
|
||||
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
} else { // indirect dependency
|
||||
|
||||
Reference in New Issue
Block a user