refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -24,8 +24,11 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@@ -123,7 +126,7 @@ func (cv *CanvasValidator) ValidateConnections(ctx context.Context) (issues []*I
|
||||
return issues, nil
|
||||
}
|
||||
|
||||
func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issue, err error) {
|
||||
func (cv *CanvasValidator) CheckRefVariable(_ context.Context) (issues []*Issue, err error) {
|
||||
issues = make([]*Issue, 0)
|
||||
var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error
|
||||
checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error {
|
||||
@@ -237,7 +240,7 @@ func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issu
|
||||
return issues, nil
|
||||
}
|
||||
|
||||
func (cv *CanvasValidator) ValidateNestedFlows(ctx context.Context) (issues []*Issue, err error) {
|
||||
func (cv *CanvasValidator) ValidateNestedFlows(_ context.Context) (issues []*Issue, err error) {
|
||||
issues = make([]*Issue, 0)
|
||||
for nodeID, node := range cv.reachability.reachableNodes {
|
||||
if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 {
|
||||
@@ -265,13 +268,13 @@ func (cv *CanvasValidator) CheckGlobalVariables(ctx context.Context) (issues []*
|
||||
|
||||
nVars := make([]*nodeVars, 0)
|
||||
for _, node := range cv.cfg.Canvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotComment {
|
||||
if node.Type == entity.NodeTypeComment.IDStr() {
|
||||
continue
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotAssignVariable {
|
||||
if node.Type == entity.NodeTypeVariableAssigner.IDStr() {
|
||||
v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)}
|
||||
for _, p := range node.Data.Inputs.InputParameters {
|
||||
v.vars[p.Name], err = adaptor.CanvasBlockInputToTypeInfo(p.Left)
|
||||
v.vars[p.Name], err = convert.CanvasBlockInputToTypeInfo(p.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -338,7 +341,7 @@ func (cv *CanvasValidator) CheckSubWorkFlowTerminatePlanType(ctx context.Context
|
||||
var collectSubWorkFlowNodes func(nodes []*vo.Node)
|
||||
collectSubWorkFlowNodes = func(nodes []*vo.Node) {
|
||||
for _, n := range nodes {
|
||||
if n.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
subWfMap = append(subWfMap, n)
|
||||
wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
@@ -465,62 +468,28 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
||||
selectorPorts := make(map[string]map[string]bool)
|
||||
|
||||
for nodeID, node := range nodeMap {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeCondition:
|
||||
branches := node.Data.Inputs.Branches
|
||||
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
|
||||
node.Data.Inputs.SettingOnError.ProcessType != nil &&
|
||||
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
selectorPorts[nodeID]["false"] = true
|
||||
for index := range branches {
|
||||
if index == 0 {
|
||||
selectorPorts[nodeID]["true"] = true
|
||||
} else {
|
||||
selectorPorts[nodeID][fmt.Sprintf("true_%v", index)] = true
|
||||
}
|
||||
}
|
||||
case vo.BlockTypeBotIntent:
|
||||
intents := node.Data.Inputs.Intents
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
for index := range intents {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
|
||||
}
|
||||
selectorPorts[nodeID]["default"] = true
|
||||
if node.Data.Inputs.SettingOnError != nil && node.Data.Inputs.SettingOnError.ProcessType != nil &&
|
||||
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
selectorPorts[nodeID]["branch_error"] = true
|
||||
}
|
||||
case vo.BlockTypeQuestion:
|
||||
if node.Data.Inputs.QA.AnswerType == vo.QAAnswerTypeOption {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
|
||||
for index := range node.Data.Inputs.QA.Options {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
|
||||
}
|
||||
}
|
||||
|
||||
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", 0)] = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
|
||||
node.Data.Inputs.SettingOnError.ProcessType != nil &&
|
||||
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
selectorPorts[nodeID]["branch_error"] = true
|
||||
selectorPorts[nodeID]["default"] = true
|
||||
} else {
|
||||
outDegree[node.ID] = 0
|
||||
}
|
||||
selectorPorts[nodeID][schema.PortBranchError] = true
|
||||
selectorPorts[nodeID][schema.PortDefault] = true
|
||||
}
|
||||
|
||||
ba, ok := nodes.GetBranchAdaptor(entity.IDStrToNodeType(node.Type))
|
||||
if ok {
|
||||
expects := ba.ExpectPorts(ctx, node)
|
||||
if len(expects) > 0 {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
for _, e := range expects {
|
||||
selectorPorts[nodeID][e] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, edge := range c.Edges {
|
||||
@@ -544,8 +513,8 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
||||
for nodeID, node := range nodeMap {
|
||||
nodeName := node.Data.Meta.Title
|
||||
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotStart:
|
||||
switch et := entity.IDStrToNodeType(node.Type); et {
|
||||
case entity.NodeTypeEntry:
|
||||
if outDegree[nodeID] == 0 {
|
||||
issues = append(issues, &Issue{
|
||||
NodeErr: &NodeErr{
|
||||
@@ -555,13 +524,9 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
||||
Message: `node "start" not connected`,
|
||||
})
|
||||
}
|
||||
case vo.BlockTypeBotEnd:
|
||||
case entity.NodeTypeExit:
|
||||
default:
|
||||
if ports, isSelector := selectorPorts[nodeID]; isSelector {
|
||||
selectorIssues := &Issue{NodeErr: &NodeErr{
|
||||
NodeID: node.ID,
|
||||
NodeName: nodeName,
|
||||
}}
|
||||
message := ""
|
||||
for port := range ports {
|
||||
if portOutDegree[nodeID][port] == 0 {
|
||||
@@ -569,12 +534,15 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
||||
}
|
||||
}
|
||||
if len(message) > 0 {
|
||||
selectorIssues.Message = message
|
||||
selectorIssues := &Issue{NodeErr: &NodeErr{
|
||||
NodeID: node.ID,
|
||||
NodeName: nodeName,
|
||||
}, Message: message}
|
||||
issues = append(issues, selectorIssues)
|
||||
}
|
||||
} else {
|
||||
// Break, continue without checking out degrees
|
||||
if node.Type == vo.BlockTypeBotBreak || node.Type == vo.BlockTypeBotContinue {
|
||||
if et == entity.NodeTypeBreak || et == entity.NodeTypeContinue {
|
||||
continue
|
||||
}
|
||||
if outDegree[nodeID] == 0 {
|
||||
@@ -585,7 +553,6 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
||||
},
|
||||
Message: fmt.Sprintf(`node "%v" not connected`, nodeName),
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -602,7 +569,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startNode, endNode, err := findStartAndEndNodes(c.Nodes)
|
||||
startNode, _, err := findStartAndEndNodes(c.Nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -612,7 +579,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
|
||||
edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID)
|
||||
}
|
||||
|
||||
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode, endNode)
|
||||
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -635,12 +602,12 @@ func processNestedReachability(c *vo.Canvas, r *reachability) error {
|
||||
Nodes: append([]*vo.Node{
|
||||
{
|
||||
ID: node.ID,
|
||||
Type: vo.BlockTypeBotStart,
|
||||
Type: entity.NodeTypeEntry.IDStr(),
|
||||
Data: node.Data,
|
||||
},
|
||||
{
|
||||
ID: node.ID,
|
||||
Type: vo.BlockTypeBotEnd,
|
||||
Type: entity.NodeTypeExit.IDStr(),
|
||||
},
|
||||
}, node.Blocks...),
|
||||
Edges: node.Edges,
|
||||
@@ -663,9 +630,9 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
|
||||
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotStart:
|
||||
case entity.NodeTypeEntry.IDStr():
|
||||
startNode = node
|
||||
case vo.BlockTypeBotEnd:
|
||||
case entity.NodeTypeExit.IDStr():
|
||||
endNode = node
|
||||
}
|
||||
}
|
||||
@@ -680,7 +647,7 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
|
||||
return startNode, endNode, nil
|
||||
}
|
||||
|
||||
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node, endNode *vo.Node) (map[string]*vo.Node, error) {
|
||||
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node) (map[string]*vo.Node, error) {
|
||||
result := make(map[string]*vo.Node)
|
||||
result[startNode.ID] = startNode
|
||||
|
||||
|
||||
Reference in New Issue
Block a user