197 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			197 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package schema
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 
 | |
| 	"github.com/cloudwego/eino/compose"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
 | |
| )
 | |
| 
 | |
// Source-port identifiers that appear in a connection's FromPort. BuildBranches
// compares each port against these values to decide which BranchSchema mapping
// the connection's target belongs to.
const (
	// PortBranchFormat is the fmt pattern for numbered branch ports
	// ("branch_0", "branch_1", ...); targets land in BranchSchema.Mappings[n].
	PortBranchFormat = "branch_%d"
	// PortBranchError labels connections collected into BranchSchema.ExceptionMapping.
	PortBranchError = "branch_error"
	// PortDefault labels connections collected into BranchSchema.DefaultMapping.
	PortDefault = "default"
)
 | |
| 
 | |
// BranchSchema defines the schema for workflow branches.
//
// It describes how the outgoing connections of a single source node are
// grouped by the source port they leave from. Each mapping is a set of target
// node keys (stored as strings; BuildBranches only ever stores true values).
type BranchSchema struct {
	// From is the key of the source node whose outgoing connections are described.
	From             vo.NodeKey                `json:"from_node"`
	// DefaultMapping holds the targets of connections leaving via PortDefault.
	DefaultMapping   map[string]bool           `json:"default_mapping,omitempty"`
	// ExceptionMapping holds the targets of connections leaving via PortBranchError.
	ExceptionMapping map[string]bool           `json:"exception_mapping,omitempty"`
	// Mappings holds the targets of numbered ports ("branch_<n>"), keyed by n.
	Mappings         map[int64]map[string]bool `json:"mappings,omitempty"`
}
 | |
| 
 | |
| // BuildBranches builds branch schemas from connections.
 | |
| func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
 | |
| 	var branchMap map[vo.NodeKey]*BranchSchema
 | |
| 
 | |
| 	for _, conn := range connections {
 | |
| 		if conn.FromPort == nil || len(*conn.FromPort) == 0 {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		port := *conn.FromPort
 | |
| 		sourceNodeKey := conn.FromNode
 | |
| 
 | |
| 		if branchMap == nil {
 | |
| 			branchMap = map[vo.NodeKey]*BranchSchema{}
 | |
| 		}
 | |
| 
 | |
| 		// Get or create branch schema for source node
 | |
| 		branch, exists := branchMap[sourceNodeKey]
 | |
| 		if !exists {
 | |
| 			branch = &BranchSchema{
 | |
| 				From: sourceNodeKey,
 | |
| 			}
 | |
| 			branchMap[sourceNodeKey] = branch
 | |
| 		}
 | |
| 
 | |
| 		// Classify port type and add to appropriate mapping
 | |
| 		switch {
 | |
| 		case port == PortDefault:
 | |
| 			if branch.DefaultMapping == nil {
 | |
| 				branch.DefaultMapping = map[string]bool{}
 | |
| 			}
 | |
| 			branch.DefaultMapping[string(conn.ToNode)] = true
 | |
| 		case port == PortBranchError:
 | |
| 			if branch.ExceptionMapping == nil {
 | |
| 				branch.ExceptionMapping = map[string]bool{}
 | |
| 			}
 | |
| 			branch.ExceptionMapping[string(conn.ToNode)] = true
 | |
| 		default:
 | |
| 			var branchNum int64
 | |
| 			_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
 | |
| 			if err != nil || branchNum < 0 {
 | |
| 				return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
 | |
| 			}
 | |
| 			if branch.Mappings == nil {
 | |
| 				branch.Mappings = map[int64]map[string]bool{}
 | |
| 			}
 | |
| 			if _, exists := branch.Mappings[branchNum]; !exists {
 | |
| 				branch.Mappings[branchNum] = make(map[string]bool)
 | |
| 			}
 | |
| 			branch.Mappings[branchNum][string(conn.ToNode)] = true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return branchMap, nil
 | |
| }
 | |
| 
 | |
| func (bs *BranchSchema) OnlyException() bool {
 | |
| 	return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0
 | |
| }
 | |
| 
 | |
| func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
 | |
| 	condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
 | |
| 		isSuccess, ok := in["isSuccess"]
 | |
| 		if ok && isSuccess != nil && !isSuccess.(bool) {
 | |
| 			return bs.ExceptionMapping, nil
 | |
| 		}
 | |
| 
 | |
| 		return bs.DefaultMapping, nil
 | |
| 	}
 | |
| 
 | |
| 	// Combine ExceptionMapping and DefaultMapping into a new map
 | |
| 	endNodes := make(map[string]bool)
 | |
| 	for node := range bs.ExceptionMapping {
 | |
| 		endNodes[node] = true
 | |
| 	}
 | |
| 	for node := range bs.DefaultMapping {
 | |
| 		endNodes[node] = true
 | |
| 	}
 | |
| 
 | |
| 	return compose.NewGraphMultiBranch(condition, endNodes)
 | |
| }
 | |
| 
 | |
| func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
 | |
| 	extractor, hasBranch := bb.BuildBranch(ctx)
 | |
| 	if !hasBranch {
 | |
| 		return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
 | |
| 	}
 | |
| 
 | |
| 	if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch
 | |
| 		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
 | |
| 			index, isDefault, err := extractor(ctx, in)
 | |
| 			if err != nil {
 | |
| 				return nil, err
 | |
| 			}
 | |
| 
 | |
| 			if isDefault {
 | |
| 				return bs.DefaultMapping, nil
 | |
| 			}
 | |
| 
 | |
| 			if _, ok := bs.Mappings[index]; !ok {
 | |
| 				return nil, fmt.Errorf("chosen index= %d, out of range", index)
 | |
| 			}
 | |
| 
 | |
| 			return bs.Mappings[index], nil
 | |
| 		}
 | |
| 
 | |
| 		// Combine DefaultMapping and normal mappings into a new map
 | |
| 		endNodes := make(map[string]bool)
 | |
| 		for node := range bs.DefaultMapping {
 | |
| 			endNodes[node] = true
 | |
| 		}
 | |
| 		for _, ms := range bs.Mappings {
 | |
| 			for node := range ms {
 | |
| 				endNodes[node] = true
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		return compose.NewGraphMultiBranch(condition, endNodes), nil
 | |
| 	}
 | |
| 
 | |
| 	condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
 | |
| 		isSuccess, ok := in["isSuccess"]
 | |
| 		if ok && isSuccess != nil && !isSuccess.(bool) {
 | |
| 			return bs.ExceptionMapping, nil
 | |
| 		}
 | |
| 
 | |
| 		index, isDefault, err := extractor(ctx, in)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		if isDefault {
 | |
| 			return bs.DefaultMapping, nil
 | |
| 		}
 | |
| 
 | |
| 		return bs.Mappings[index], nil
 | |
| 	}
 | |
| 
 | |
| 	// Combine ALL mappings into a new map
 | |
| 	endNodes := make(map[string]bool)
 | |
| 	for node := range bs.ExceptionMapping {
 | |
| 		endNodes[node] = true
 | |
| 	}
 | |
| 	for node := range bs.DefaultMapping {
 | |
| 		endNodes[node] = true
 | |
| 	}
 | |
| 	for _, ms := range bs.Mappings {
 | |
| 		for node := range ms {
 | |
| 			endNodes[node] = true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return compose.NewGraphMultiBranch(condition, endNodes), nil
 | |
| }
 |