coze-studio/backend/domain/workflow/internal/schema/branch_schema.go

197 lines
5.3 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// Port names recognised in a connection's FromPort when building branches.
const (
	// PortDefault is the port whose targets are recorded in
	// BranchSchema.DefaultMapping.
	PortDefault = "default"

	// PortBranchError is the port whose targets are recorded in
	// BranchSchema.ExceptionMapping.
	PortBranchError = "branch_error"

	// PortBranchFormat is the fmt pattern of a numbered branch port; the
	// number is the key into BranchSchema.Mappings.
	PortBranchFormat = "branch_%d"
)
// BranchSchema describes the outgoing branches of one workflow node: for each
// kind of port, the set of downstream node keys that port leads to.
type BranchSchema struct {
	// From is the key of the node the branches originate from.
	From vo.NodeKey `json:"from_node"`

	// DefaultMapping is the set of target node keys connected to the
	// "default" port.
	DefaultMapping map[string]bool `json:"default_mapping,omitempty"`

	// ExceptionMapping is the set of target node keys connected to the
	// "branch_error" port.
	ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"`

	// Mappings holds, per branch number taken from a "branch_<n>" port, the
	// set of target node keys connected to that port.
	Mappings map[int64]map[string]bool `json:"mappings,omitempty"`
}
// BuildBranches groups the port-bearing connections by source node and
// returns one BranchSchema per source node.
//
// Connections whose FromPort is nil or empty are ignored. For every other
// connection the port name decides where the target node is recorded:
//   - "default" is recorded in DefaultMapping
//   - "branch_error" is recorded in ExceptionMapping
//   - "branch_<n>" is recorded in Mappings[n], where n is a non-negative
//     decimal integer written in canonical form (no sign, no leading zeros,
//     no surrounding characters)
//
// Any other port name yields an error and a nil map. The returned map is also
// nil when no connection carries a port.
func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
	var branchMap map[vo.NodeKey]*BranchSchema

	for _, conn := range connections {
		if conn.FromPort == nil || len(*conn.FromPort) == 0 {
			continue
		}
		port := *conn.FromPort

		// The result map is allocated lazily so that it stays nil when no
		// connection has a port.
		if branchMap == nil {
			branchMap = map[vo.NodeKey]*BranchSchema{}
		}

		// Get or create the branch schema of the source node.
		branch, exists := branchMap[conn.FromNode]
		if !exists {
			branch = &BranchSchema{From: conn.FromNode}
			branchMap[conn.FromNode] = branch
		}

		target := string(conn.ToNode)

		switch port {
		case PortDefault:
			if branch.DefaultMapping == nil {
				branch.DefaultMapping = map[string]bool{}
			}
			branch.DefaultMapping[target] = true

		case PortBranchError:
			if branch.ExceptionMapping == nil {
				branch.ExceptionMapping = map[string]bool{}
			}
			branch.ExceptionMapping[target] = true

		default:
			var branchNum int64
			_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
			// Sscanf alone is too lenient: it skips whitespace before the
			// number, accepts an explicit sign and ignores trailing input, so
			// "branch_1junk" would parse as branch 1. Requiring the port to
			// equal its canonical rendering rejects all of those.
			if err != nil || branchNum < 0 || fmt.Sprintf(PortBranchFormat, branchNum) != port {
				return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
			}

			if branch.Mappings == nil {
				branch.Mappings = map[int64]map[string]bool{}
			}
			targets := branch.Mappings[branchNum]
			if targets == nil {
				targets = map[string]bool{}
				branch.Mappings[branchNum] = targets
			}
			targets[target] = true
		}
	}

	return branchMap, nil
}
// OnlyException reports whether the schema has no numbered branch mappings
// while both its exception mapping and its default mapping are non-empty.
func (bs *BranchSchema) OnlyException() bool {
	if len(bs.Mappings) != 0 {
		return false
	}
	hasException := len(bs.ExceptionMapping) > 0
	hasDefault := len(bs.DefaultMapping) > 0
	return hasException && hasDefault
}
// GetExceptionBranch builds a multi-branch that routes on the "isSuccess"
// entry of the input map.
//
// When "isSuccess" is present, non-nil and false, the branch selects
// ExceptionMapping; when it is absent, nil or true, the branch selects
// DefaultMapping. A non-nil "isSuccess" that is not a bool makes the
// condition return an error instead of panicking.
//
// The branch's end nodes are the union of ExceptionMapping and DefaultMapping.
func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
	condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
		if raw, ok := in["isSuccess"]; ok && raw != nil {
			succeeded, isBool := raw.(bool)
			if !isBool {
				return nil, fmt.Errorf("isSuccess must be a bool, got %T", raw)
			}
			if !succeeded {
				return bs.ExceptionMapping, nil
			}
		}
		return bs.DefaultMapping, nil
	}

	// Combine ExceptionMapping and DefaultMapping into a new map so the
	// schema's own maps are not handed out as the end-node set.
	endNodes := make(map[string]bool, len(bs.ExceptionMapping)+len(bs.DefaultMapping))
	for node := range bs.ExceptionMapping {
		endNodes[node] = true
	}
	for node := range bs.DefaultMapping {
		endNodes[node] = true
	}

	return compose.NewGraphMultiBranch(condition, endNodes)
}
// GetFullBranch builds a multi-branch for a node whose selected port is
// computed by the extractor obtained from bb.BuildBranch.
//
// It returns an error when bb reports that it has no branch.
//
// The resulting condition works as follows:
//   - If the schema has a non-empty ExceptionMapping and the input's
//     "isSuccess" entry is present, non-nil and false, ExceptionMapping is
//     selected. A non-nil "isSuccess" that is not a bool yields an error.
//     When ExceptionMapping is empty, "isSuccess" is not inspected at all.
//   - Otherwise the extractor is invoked. Its error is returned as is; if it
//     reports the default choice, DefaultMapping is selected; otherwise the
//     mapping registered for the returned index is selected, and an index
//     without a registered mapping yields an error.
//
// The branch's end nodes are the union of ExceptionMapping, DefaultMapping
// and every entry of Mappings.
func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
	extractor, hasBranch := bb.BuildBranch(ctx)
	if !hasBranch {
		return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
	}

	// Whether exception routing applies is decided once, when the branch is
	// built, not on every evaluation of the condition.
	hasException := len(bs.ExceptionMapping) > 0

	condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
		if hasException {
			if raw, ok := in["isSuccess"]; ok && raw != nil {
				succeeded, isBool := raw.(bool)
				if !isBool {
					return nil, fmt.Errorf("isSuccess must be a bool, got %T", raw)
				}
				if !succeeded {
					return bs.ExceptionMapping, nil
				}
			}
		}

		index, isDefault, err := extractor(ctx, in)
		if err != nil {
			return nil, err
		}
		if isDefault {
			return bs.DefaultMapping, nil
		}

		// An index with no registered mapping is reported as an error rather
		// than silently selecting no successor, with or without exception
		// routing.
		targets, ok := bs.Mappings[index]
		if !ok {
			return nil, fmt.Errorf("chosen index= %d, out of range", index)
		}
		return targets, nil
	}

	// Combine all mappings into a new map. Ranging over ExceptionMapping adds
	// nothing when the schema has no exception routing.
	endNodes := make(map[string]bool)
	for node := range bs.ExceptionMapping {
		endNodes[node] = true
	}
	for node := range bs.DefaultMapping {
		endNodes[node] = true
	}
	for _, targets := range bs.Mappings {
		for node := range targets {
			endNodes[node] = true
		}
	}

	return compose.NewGraphMultiBranch(condition, endNodes), nil
}