197 lines
5.3 KiB
Go
197 lines
5.3 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package schema
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/cloudwego/eino/compose"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
|
)
|
|
|
|
// Port type constants. They name the outgoing port of a node that a
// connection leaves from, and drive how BuildBranches classifies connections.
const (
	// PortDefault is the port name whose targets BuildBranches records in
	// BranchSchema.DefaultMapping.
	PortDefault = "default"
	// PortBranchError is the port name whose targets BuildBranches records in
	// BranchSchema.ExceptionMapping.
	PortBranchError = "branch_error"
	// PortBranchFormat is the fmt scan format BuildBranches applies to every
	// other port name to extract its branch number (for example "branch_2").
	PortBranchFormat = "branch_%d"
)
|
|
|
|
// BranchSchema defines the schema for workflow branches.
//
// It collects, for a single source node, the sets of target node keys reachable
// through each kind of outgoing port. Every mapping is used as a set: the keys
// are target node keys converted to string, and BuildBranches only ever stores
// the value true.
type BranchSchema struct {
	// From is the key of the node the branch leaves from; BuildBranches sets
	// it to the FromNode of the connections grouped into this schema.
	From vo.NodeKey `json:"from_node"`
	// DefaultMapping holds the targets of connections leaving through the
	// PortDefault port. It stays nil when there is no such connection.
	DefaultMapping map[string]bool `json:"default_mapping,omitempty"`
	// ExceptionMapping holds the targets of connections leaving through the
	// PortBranchError port. It stays nil when there is no such connection.
	ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"`
	// Mappings holds, per branch number parsed from the port name with
	// PortBranchFormat, the targets of connections leaving through that port.
	Mappings map[int64]map[string]bool `json:"mappings,omitempty"`
}
|
|
|
|
// BuildBranches builds branch schemas from connections.
|
|
func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
|
|
var branchMap map[vo.NodeKey]*BranchSchema
|
|
|
|
for _, conn := range connections {
|
|
if conn.FromPort == nil || len(*conn.FromPort) == 0 {
|
|
continue
|
|
}
|
|
|
|
port := *conn.FromPort
|
|
sourceNodeKey := conn.FromNode
|
|
|
|
if branchMap == nil {
|
|
branchMap = map[vo.NodeKey]*BranchSchema{}
|
|
}
|
|
|
|
// Get or create branch schema for source node
|
|
branch, exists := branchMap[sourceNodeKey]
|
|
if !exists {
|
|
branch = &BranchSchema{
|
|
From: sourceNodeKey,
|
|
}
|
|
branchMap[sourceNodeKey] = branch
|
|
}
|
|
|
|
// Classify port type and add to appropriate mapping
|
|
switch {
|
|
case port == PortDefault:
|
|
if branch.DefaultMapping == nil {
|
|
branch.DefaultMapping = map[string]bool{}
|
|
}
|
|
branch.DefaultMapping[string(conn.ToNode)] = true
|
|
case port == PortBranchError:
|
|
if branch.ExceptionMapping == nil {
|
|
branch.ExceptionMapping = map[string]bool{}
|
|
}
|
|
branch.ExceptionMapping[string(conn.ToNode)] = true
|
|
default:
|
|
var branchNum int64
|
|
_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
|
|
if err != nil || branchNum < 0 {
|
|
return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
|
|
}
|
|
if branch.Mappings == nil {
|
|
branch.Mappings = map[int64]map[string]bool{}
|
|
}
|
|
if _, exists := branch.Mappings[branchNum]; !exists {
|
|
branch.Mappings[branchNum] = make(map[string]bool)
|
|
}
|
|
branch.Mappings[branchNum][string(conn.ToNode)] = true
|
|
}
|
|
}
|
|
|
|
return branchMap, nil
|
|
}
|
|
|
|
func (bs *BranchSchema) OnlyException() bool {
|
|
return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0
|
|
}
|
|
|
|
func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
|
|
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
|
isSuccess, ok := in["isSuccess"]
|
|
if ok && isSuccess != nil && !isSuccess.(bool) {
|
|
return bs.ExceptionMapping, nil
|
|
}
|
|
|
|
return bs.DefaultMapping, nil
|
|
}
|
|
|
|
// Combine ExceptionMapping and DefaultMapping into a new map
|
|
endNodes := make(map[string]bool)
|
|
for node := range bs.ExceptionMapping {
|
|
endNodes[node] = true
|
|
}
|
|
for node := range bs.DefaultMapping {
|
|
endNodes[node] = true
|
|
}
|
|
|
|
return compose.NewGraphMultiBranch(condition, endNodes)
|
|
}
|
|
|
|
func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
|
|
extractor, hasBranch := bb.BuildBranch(ctx)
|
|
if !hasBranch {
|
|
return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
|
|
}
|
|
|
|
if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch
|
|
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
|
index, isDefault, err := extractor(ctx, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if isDefault {
|
|
return bs.DefaultMapping, nil
|
|
}
|
|
|
|
if _, ok := bs.Mappings[index]; !ok {
|
|
return nil, fmt.Errorf("chosen index= %d, out of range", index)
|
|
}
|
|
|
|
return bs.Mappings[index], nil
|
|
}
|
|
|
|
// Combine DefaultMapping and normal mappings into a new map
|
|
endNodes := make(map[string]bool)
|
|
for node := range bs.DefaultMapping {
|
|
endNodes[node] = true
|
|
}
|
|
for _, ms := range bs.Mappings {
|
|
for node := range ms {
|
|
endNodes[node] = true
|
|
}
|
|
}
|
|
|
|
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
|
}
|
|
|
|
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
|
isSuccess, ok := in["isSuccess"]
|
|
if ok && isSuccess != nil && !isSuccess.(bool) {
|
|
return bs.ExceptionMapping, nil
|
|
}
|
|
|
|
index, isDefault, err := extractor(ctx, in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if isDefault {
|
|
return bs.DefaultMapping, nil
|
|
}
|
|
|
|
return bs.Mappings[index], nil
|
|
}
|
|
|
|
// Combine ALL mappings into a new map
|
|
endNodes := make(map[string]bool)
|
|
for node := range bs.ExceptionMapping {
|
|
endNodes[node] = true
|
|
}
|
|
for node := range bs.DefaultMapping {
|
|
endNodes[node] = true
|
|
}
|
|
for _, ms := range bs.Mappings {
|
|
for node := range ms {
|
|
endNodes[node] = true
|
|
}
|
|
}
|
|
|
|
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
|
}
|