feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
342
backend/domain/workflow/internal/nodes/selector/clause.go
Normal file
342
backend/domain/workflow/internal/nodes/selector/clause.go
Normal file
@@ -0,0 +1,342 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Predicate interface {
|
||||
Resolve() (bool, error)
|
||||
}
|
||||
|
||||
type Clause struct {
|
||||
LeftOperant any
|
||||
Op Operator
|
||||
RightOperant any
|
||||
}
|
||||
|
||||
type MultiClause struct {
|
||||
Clauses []*Clause
|
||||
Relation ClauseRelation
|
||||
}
|
||||
|
||||
func (c *Clause) Resolve() (bool, error) {
|
||||
leftV := c.LeftOperant
|
||||
rightV := c.RightOperant
|
||||
|
||||
leftT := reflect.TypeOf(leftV)
|
||||
rightT := reflect.TypeOf(rightV)
|
||||
|
||||
if err := c.Op.WillAccept(leftT, rightT); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch c.Op {
|
||||
case OperatorEqual:
|
||||
if leftV == nil && rightV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if leftV == nil || rightV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
return leftV == rightV, nil
|
||||
case OperatorNotEqual:
|
||||
if leftV == nil && rightV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if leftV == nil || rightV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
return leftV != rightV, nil
|
||||
case OperatorEmpty:
|
||||
if leftV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if leftArray, ok := leftV.([]any); ok {
|
||||
return len(leftArray) == 0, nil
|
||||
}
|
||||
|
||||
if leftObj, ok := leftV.(map[string]any); ok {
|
||||
return len(leftObj) == 0, nil
|
||||
}
|
||||
|
||||
if leftStr, ok := leftV.(string); ok {
|
||||
return len(leftStr) == 0 || leftStr == "None", nil
|
||||
}
|
||||
|
||||
if leftInt, ok := leftV.(int64); ok {
|
||||
return leftInt == 0, nil
|
||||
}
|
||||
|
||||
if leftFloat, ok := leftV.(float64); ok {
|
||||
return leftFloat == 0, nil
|
||||
}
|
||||
|
||||
if leftBool, ok := leftV.(bool); ok {
|
||||
return !leftBool, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
case OperatorNotEmpty:
|
||||
empty, err := (&Clause{LeftOperant: leftV, Op: OperatorEmpty}).Resolve()
|
||||
return !empty, err
|
||||
case OperatorGreater:
|
||||
if leftV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if rightV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
|
||||
return leftV.(float64) > rightV.(float64), nil
|
||||
}
|
||||
return leftV.(int64) > rightV.(int64), nil
|
||||
case OperatorGreaterOrEqual:
|
||||
if leftV == nil {
|
||||
if rightV == nil {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if rightV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
|
||||
return leftV.(float64) >= rightV.(float64), nil
|
||||
}
|
||||
return leftV.(int64) >= rightV.(int64), nil
|
||||
case OperatorLesser:
|
||||
if leftV == nil {
|
||||
if rightV == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if rightV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
|
||||
return leftV.(float64) < rightV.(float64), nil
|
||||
}
|
||||
return leftV.(int64) < rightV.(int64), nil
|
||||
case OperatorLesserOrEqual:
|
||||
if leftV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if rightV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
|
||||
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
|
||||
return leftV.(float64) <= rightV.(float64), nil
|
||||
}
|
||||
return leftV.(int64) <= rightV.(int64), nil
|
||||
case OperatorIsTrue:
|
||||
if leftV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return leftV.(bool), nil
|
||||
case OperatorIsFalse:
|
||||
if leftV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return !leftV.(bool), nil
|
||||
case OperatorLengthGreater:
|
||||
if leftV == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return int64(reflect.ValueOf(leftV).Len()) > rightV.(int64), nil
|
||||
case OperatorLengthGreaterOrEqual:
|
||||
if leftV == nil {
|
||||
if rightV.(int64) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return int64(reflect.ValueOf(leftV).Len()) >= rightV.(int64), nil
|
||||
case OperatorLengthLesser:
|
||||
if leftV == nil {
|
||||
if rightV.(int64) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return int64(reflect.ValueOf(leftV).Len()) < rightV.(int64), nil
|
||||
case OperatorLengthLesserOrEqual:
|
||||
if leftV == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return int64(reflect.ValueOf(leftV).Len()) <= rightV.(int64), nil
|
||||
case OperatorContain:
|
||||
if leftV == nil { // treat it as empty slice
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if leftT.Kind() == reflect.String {
|
||||
return strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
|
||||
}
|
||||
|
||||
leftValue := reflect.ValueOf(leftV)
|
||||
for i := 0; i < leftValue.Len(); i++ {
|
||||
elem := leftValue.Index(i).Interface()
|
||||
if elem == rightV {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
case OperatorNotContain:
|
||||
if leftV == nil { // treat it as empty slice
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if leftT.Kind() == reflect.String {
|
||||
return !strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
|
||||
}
|
||||
|
||||
leftValue := reflect.ValueOf(leftV)
|
||||
for i := 0; i < leftValue.Len(); i++ {
|
||||
elem := leftValue.Index(i).Interface()
|
||||
if elem == rightV {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
case OperatorContainKey:
|
||||
if leftV == nil { // treat it as empty map
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if leftT.Kind() == reflect.Map {
|
||||
leftValue := reflect.ValueOf(leftV)
|
||||
for _, key := range leftValue.MapKeys() {
|
||||
if key.Interface() == rightV {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
} else { // struct, unreachable now
|
||||
for i := 0; i < leftT.NumField(); i++ {
|
||||
field := leftT.Field(i)
|
||||
if field.IsExported() {
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == rightV {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
case OperatorNotContainKey:
|
||||
if leftV == nil { // treat it as empty map
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if leftT.Kind() == reflect.Map {
|
||||
leftValue := reflect.ValueOf(leftV)
|
||||
for _, key := range leftValue.MapKeys() {
|
||||
if key.Interface() == rightV {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
} else { // struct, unreachable now
|
||||
for i := 0; i < leftT.NumField(); i++ {
|
||||
field := leftT.Field(i)
|
||||
if field.IsExported() {
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == rightV {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown operator: %v", c.Op)
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *MultiClause) Resolve() (bool, error) {
|
||||
if mc.Relation == ClauseRelationAND {
|
||||
for _, clause := range mc.Clauses {
|
||||
isTrue, err := clause.Resolve()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !isTrue {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
} else if mc.Relation == ClauseRelationOR {
|
||||
for _, clause := range mc.Clauses {
|
||||
isTrue, err := clause.Resolve()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if isTrue {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
return false, fmt.Errorf("unknown relation: %v", mc.Relation)
|
||||
}
|
||||
}
|
||||
|
||||
func alignNumberTypes(leftV, rightV any, leftT, rightT reflect.Type) (any, any) {
|
||||
if leftT == reflect.TypeOf(int64(0)) {
|
||||
if rightT == reflect.TypeOf(float64(0)) {
|
||||
leftV = float64(leftV.(int64))
|
||||
}
|
||||
} else if leftT == reflect.TypeOf(float64(0)) {
|
||||
if rightT == reflect.TypeOf(int64(0)) {
|
||||
rightV = float64(rightV.(int64))
|
||||
}
|
||||
}
|
||||
|
||||
return leftV, rightV
|
||||
}
|
||||
310
backend/domain/workflow/internal/nodes/selector/clause_test.go
Normal file
310
backend/domain/workflow/internal/nodes/selector/clause_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestClauseResolve tests the Resolve method of the Clause struct.
|
||||
func TestClauseResolve(t *testing.T) {
|
||||
// Test cases for different operators, considering acceptable operand types
|
||||
testCases := []struct {
|
||||
name string
|
||||
clause Clause
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
// OperatorEqual
|
||||
{
|
||||
name: "OperatorEqual_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorEqual,
|
||||
RightOperant: int64(10),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorEqual_IntMismatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorEqual,
|
||||
RightOperant: int64(20),
|
||||
},
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorEqual_FloatMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: 10.5,
|
||||
Op: OperatorEqual,
|
||||
RightOperant: 10.5,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorEqual_StringMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: "test",
|
||||
Op: OperatorEqual,
|
||||
RightOperant: "test",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorNotEqual
|
||||
{
|
||||
name: "OperatorNotEqual_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorNotEqual,
|
||||
RightOperant: int64(20),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotEqual_StringMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: "test",
|
||||
Op: OperatorNotEqual,
|
||||
RightOperant: "xyz",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorEmpty
|
||||
{
|
||||
name: "OperatorEmpty_NilValue",
|
||||
clause: Clause{
|
||||
LeftOperant: nil,
|
||||
Op: OperatorEmpty,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorNotEmpty
|
||||
{
|
||||
name: "OperatorNotEmpty_MapValue",
|
||||
clause: Clause{
|
||||
LeftOperant: map[string]any{"key1": "value1"},
|
||||
Op: OperatorNotEmpty,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorGreater
|
||||
{
|
||||
name: "OperatorGreater_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorGreater,
|
||||
RightOperant: int64(5),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorGreater_FloatMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: 10.5,
|
||||
Op: OperatorGreater,
|
||||
RightOperant: 5.0,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorGreaterOrEqual
|
||||
{
|
||||
name: "OperatorGreaterOrEqual_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorGreaterOrEqual,
|
||||
RightOperant: int64(10),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLesser
|
||||
{
|
||||
name: "OperatorLesser_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorLesser,
|
||||
RightOperant: int64(15),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLesserOrEqual
|
||||
{
|
||||
name: "OperatorLesserOrEqual_IntMatch",
|
||||
clause: Clause{
|
||||
LeftOperant: int64(10),
|
||||
Op: OperatorLesserOrEqual,
|
||||
RightOperant: int64(10),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorIsTrue
|
||||
{
|
||||
name: "OperatorIsTrue_BoolTrue",
|
||||
clause: Clause{
|
||||
LeftOperant: true,
|
||||
Op: OperatorIsTrue,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorIsFalse
|
||||
{
|
||||
name: "OperatorIsFalse_BoolFalse",
|
||||
clause: Clause{
|
||||
LeftOperant: true,
|
||||
Op: OperatorIsFalse,
|
||||
},
|
||||
want: false,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLengthGreater
|
||||
{
|
||||
name: "OperatorLengthGreater_Slice",
|
||||
clause: Clause{
|
||||
LeftOperant: []int{1, 2, 3},
|
||||
Op: OperatorLengthGreater,
|
||||
RightOperant: int64(2),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthGreater_String",
|
||||
clause: Clause{
|
||||
LeftOperant: "test",
|
||||
Op: OperatorLengthGreater,
|
||||
RightOperant: int64(2),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLengthGreaterOrEqual
|
||||
{
|
||||
name: "OperatorLengthGreaterOrEqual_Slice",
|
||||
clause: Clause{
|
||||
LeftOperant: []int{1, 2, 3},
|
||||
Op: OperatorLengthGreaterOrEqual,
|
||||
RightOperant: int64(3),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLengthLesser
|
||||
{
|
||||
name: "OperatorLengthLesser_Slice",
|
||||
clause: Clause{
|
||||
LeftOperant: []int{1, 2, 3},
|
||||
Op: OperatorLengthLesser,
|
||||
RightOperant: int64(4),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorLengthLesserOrEqual
|
||||
{
|
||||
name: "OperatorLengthLesserOrEqual_Slice",
|
||||
clause: Clause{
|
||||
LeftOperant: []int{1, 2, 3},
|
||||
Op: OperatorLengthLesserOrEqual,
|
||||
RightOperant: int64(3),
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorContain
|
||||
{
|
||||
name: "OperatorContain_String",
|
||||
clause: Clause{
|
||||
LeftOperant: "test",
|
||||
Op: OperatorContain,
|
||||
RightOperant: "es",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorContain_Slice",
|
||||
clause: Clause{
|
||||
LeftOperant: []int{1, 2, 3},
|
||||
Op: OperatorContain,
|
||||
RightOperant: 2,
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorNotContain
|
||||
{
|
||||
name: "OperatorNotContain_String",
|
||||
clause: Clause{
|
||||
LeftOperant: "test2",
|
||||
Op: OperatorNotContain,
|
||||
RightOperant: "xyz",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorContainKey
|
||||
{
|
||||
name: "OperatorContainKey_Map",
|
||||
clause: Clause{
|
||||
LeftOperant: map[string]any{"key1": "value1"},
|
||||
Op: OperatorContainKey,
|
||||
RightOperant: "key1",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// OperatorNotContainKey
|
||||
{
|
||||
name: "OperatorNotContainKey_Map",
|
||||
clause: Clause{
|
||||
LeftOperant: map[string]any{"key1": "value1"},
|
||||
Op: OperatorNotContainKey,
|
||||
RightOperant: "key2",
|
||||
},
|
||||
want: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := tc.clause.Resolve()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Clause.Resolve() error = %v, wantErr %v", err, tc.wantErr)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
182
backend/domain/workflow/internal/nodes/selector/operator.go
Normal file
182
backend/domain/workflow/internal/nodes/selector/operator.go
Normal file
@@ -0,0 +1,182 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type Operator string
|
||||
|
||||
const (
|
||||
OperatorEqual Operator = "="
|
||||
OperatorNotEqual Operator = "!="
|
||||
OperatorEmpty Operator = "empty"
|
||||
OperatorNotEmpty Operator = "not_empty"
|
||||
OperatorGreater Operator = ">"
|
||||
OperatorGreaterOrEqual Operator = ">="
|
||||
OperatorLesser Operator = "<"
|
||||
OperatorLesserOrEqual Operator = "<="
|
||||
OperatorIsTrue Operator = "true"
|
||||
OperatorIsFalse Operator = "false"
|
||||
OperatorLengthGreater Operator = "len >"
|
||||
OperatorLengthGreaterOrEqual Operator = "len >="
|
||||
OperatorLengthLesser Operator = "len <"
|
||||
OperatorLengthLesserOrEqual Operator = "len <="
|
||||
OperatorContain Operator = "contain"
|
||||
OperatorNotContain Operator = "not_contain"
|
||||
OperatorContainKey Operator = "contain_key"
|
||||
OperatorNotContainKey Operator = "not_contain_key"
|
||||
)
|
||||
|
||||
func (o *Operator) WillAccept(leftT, rightT reflect.Type) error {
|
||||
switch *o {
|
||||
case OperatorEqual, OperatorNotEqual:
|
||||
if leftT == nil || rightT == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) && leftT.Kind() != reflect.Bool && leftT.Kind() != reflect.String {
|
||||
return fmt.Errorf("operator %v only accepts int64, float64, bool or string, not %v", *o, leftT)
|
||||
}
|
||||
|
||||
if leftT.Kind() == reflect.Bool || leftT.Kind() != reflect.String {
|
||||
if leftT != rightT {
|
||||
return fmt.Errorf("operator %v left operant and right operant must be same type: %v, %v", *o, leftT, rightT)
|
||||
}
|
||||
}
|
||||
|
||||
if leftT == reflect.TypeOf(int64(0)) || leftT == reflect.TypeOf(float64(0)) {
|
||||
if rightT != reflect.TypeOf(int64(0)) && rightT != reflect.TypeOf(float64(0)) {
|
||||
return fmt.Errorf("operator %v right operant must be int64 or float64, not %v", *o, rightT)
|
||||
}
|
||||
}
|
||||
case OperatorEmpty, OperatorNotEmpty:
|
||||
if rightT != nil {
|
||||
return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
|
||||
}
|
||||
case OperatorGreater, OperatorGreaterOrEqual, OperatorLesser, OperatorLesserOrEqual:
|
||||
if leftT == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) {
|
||||
return fmt.Errorf("operator %v only accepts float64 or int64, not %v", *o, leftT)
|
||||
}
|
||||
case OperatorIsTrue, OperatorIsFalse:
|
||||
if leftT == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rightT != nil {
|
||||
return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
|
||||
}
|
||||
|
||||
if leftT.Kind() != reflect.Bool {
|
||||
return fmt.Errorf("operator %v only accepts boolean, not %v", *o, leftT)
|
||||
}
|
||||
case OperatorLengthGreater, OperatorLengthGreaterOrEqual, OperatorLengthLesser, OperatorLengthLesserOrEqual:
|
||||
if leftT == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if leftT.Kind() != reflect.String && leftT.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("operator %v left operant only accepts string or slice, not %v", *o, leftT)
|
||||
}
|
||||
if rightT != reflect.TypeOf(int64(0)) {
|
||||
return fmt.Errorf("operator %v right operant only accepts int64, not %v", *o, rightT)
|
||||
}
|
||||
case OperatorContain, OperatorNotContain:
|
||||
if leftT == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch leftT.Kind() {
|
||||
case reflect.String:
|
||||
if rightT.Kind() != reflect.String {
|
||||
return fmt.Errorf("operator %v whose left operant is string only accepts right operant of string, not %v", *o, rightT)
|
||||
}
|
||||
case reflect.Slice:
|
||||
elemType := leftT.Elem()
|
||||
if !rightT.AssignableTo(elemType) {
|
||||
return fmt.Errorf("operator %v whose left operant is slice only accepts right operant of corresponding element type %v, not %v", *o, elemType, rightT)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("operator %v only accepts left operant of string or slice, not %v", *o, leftT)
|
||||
}
|
||||
case OperatorContainKey, OperatorNotContainKey:
|
||||
if leftT == nil { // treat it as empty map
|
||||
return nil
|
||||
}
|
||||
if leftT.Kind() != reflect.Map {
|
||||
return fmt.Errorf("operator %v only accepts left operant of map, not %v", *o, leftT)
|
||||
}
|
||||
if rightT.Kind() != reflect.String {
|
||||
return fmt.Errorf("operator %v only accepts right operant of string, not %v", *o, rightT)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown operator: %d", o)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
|
||||
switch *o {
|
||||
case OperatorEqual:
|
||||
return vo.Equal
|
||||
case OperatorNotEqual:
|
||||
return vo.NotEqual
|
||||
case OperatorEmpty:
|
||||
return vo.Empty
|
||||
case OperatorNotEmpty:
|
||||
return vo.NotEmpty
|
||||
case OperatorGreater:
|
||||
return vo.GreaterThan
|
||||
case OperatorGreaterOrEqual:
|
||||
return vo.GreaterThanEqual
|
||||
case OperatorLesser:
|
||||
return vo.LessThan
|
||||
case OperatorLesserOrEqual:
|
||||
return vo.LessThanEqual
|
||||
case OperatorIsTrue:
|
||||
return vo.True
|
||||
case OperatorIsFalse:
|
||||
return vo.False
|
||||
case OperatorLengthGreater:
|
||||
return vo.LengthGreaterThan
|
||||
case OperatorLengthGreaterOrEqual:
|
||||
return vo.LengthGreaterThanEqual
|
||||
case OperatorLengthLesser:
|
||||
return vo.LengthLessThan
|
||||
case OperatorLengthLesserOrEqual:
|
||||
return vo.LengthLessThanEqual
|
||||
case OperatorContain:
|
||||
return vo.Contain
|
||||
case OperatorNotContain:
|
||||
return vo.NotContain
|
||||
case OperatorContainKey:
|
||||
return vo.Contain
|
||||
case OperatorNotContainKey:
|
||||
return vo.NotContain
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown operator: %+v", o))
|
||||
}
|
||||
}
|
||||
397
backend/domain/workflow/internal/nodes/selector/operator_test.go
Normal file
397
backend/domain/workflow/internal/nodes/selector/operator_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOperatorWillAccept tests the WillAccept method of the Operator struct.
|
||||
func TestOperatorWillAccept(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
operator Operator
|
||||
leftType reflect.Type
|
||||
rightType reflect.Type
|
||||
wantErr bool
|
||||
}{
|
||||
// OperatorEqual
|
||||
{
|
||||
name: "OperatorEqual_Int64Match",
|
||||
operator: OperatorEqual,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorEqual_InvalidType",
|
||||
operator: OperatorEqual,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorNotEqual
|
||||
{
|
||||
name: "OperatorNotEqual_Float64Match",
|
||||
operator: OperatorNotEqual,
|
||||
leftType: reflect.TypeOf(float64(0)),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotEqual_InvalidType",
|
||||
operator: OperatorNotEqual,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorEmpty
|
||||
{
|
||||
name: "OperatorEmpty_Struct",
|
||||
operator: OperatorEmpty,
|
||||
leftType: reflect.TypeOf(map[string]int{}),
|
||||
rightType: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorEmpty_NonNilRight",
|
||||
operator: OperatorEmpty,
|
||||
leftType: reflect.TypeOf(map[string]int{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorNotEmpty
|
||||
{
|
||||
name: "OperatorNotEmpty_Slice",
|
||||
operator: OperatorNotEmpty,
|
||||
leftType: reflect.TypeOf([]int{}),
|
||||
rightType: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotEmpty_NonNilRight",
|
||||
operator: OperatorNotEmpty,
|
||||
leftType: reflect.TypeOf([]int{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorGreater
|
||||
{
|
||||
name: "OperatorGreater_Int64Match",
|
||||
operator: OperatorGreater,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorGreater_InvalidType",
|
||||
operator: OperatorGreater,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorGreaterOrEqual
|
||||
{
|
||||
name: "OperatorGreaterOrEqual_Float64Match",
|
||||
operator: OperatorGreaterOrEqual,
|
||||
leftType: reflect.TypeOf(float64(0)),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorGreaterOrEqual_InvalidType",
|
||||
operator: OperatorGreaterOrEqual,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLesser
|
||||
{
|
||||
name: "OperatorLesser_Int64Match",
|
||||
operator: OperatorLesser,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLesser_InvalidType",
|
||||
operator: OperatorLesser,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLesserOrEqual
|
||||
{
|
||||
name: "OperatorLesserOrEqual_Float64Match",
|
||||
operator: OperatorLesserOrEqual,
|
||||
leftType: reflect.TypeOf(float64(0)),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLesserOrEqual_InvalidType",
|
||||
operator: OperatorLesserOrEqual,
|
||||
leftType: reflect.TypeOf(struct{}{}),
|
||||
rightType: reflect.TypeOf(struct{}{}),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorIsTrue
|
||||
{
|
||||
name: "OperatorIsTrue_Bool",
|
||||
operator: OperatorIsTrue,
|
||||
leftType: reflect.TypeOf(true),
|
||||
rightType: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorIsTrue_NonNilRight",
|
||||
operator: OperatorIsTrue,
|
||||
leftType: reflect.TypeOf(true),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorIsTrue_InvalidType",
|
||||
operator: OperatorIsTrue,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorIsFalse
|
||||
{
|
||||
name: "OperatorIsFalse_Bool",
|
||||
operator: OperatorIsFalse,
|
||||
leftType: reflect.TypeOf(false),
|
||||
rightType: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorIsFalse_NonNilRight",
|
||||
operator: OperatorIsFalse,
|
||||
leftType: reflect.TypeOf(false),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorIsFalse_InvalidType",
|
||||
operator: OperatorIsFalse,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLengthGreater
|
||||
{
|
||||
name: "OperatorLengthGreater_String",
|
||||
operator: OperatorLengthGreater,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthGreater_InvalidLeft",
|
||||
operator: OperatorLengthGreater,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthGreater_InvalidRight",
|
||||
operator: OperatorLengthGreater,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLengthGreaterOrEqual
|
||||
{
|
||||
name: "OperatorLengthGreaterOrEqual_Slice",
|
||||
operator: OperatorLengthGreaterOrEqual,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthGreaterOrEqual_InvalidLeft",
|
||||
operator: OperatorLengthGreaterOrEqual,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthGreaterOrEqual_InvalidRight",
|
||||
operator: OperatorLengthGreaterOrEqual,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLengthLesser
|
||||
{
|
||||
name: "OperatorLengthLesser_String",
|
||||
operator: OperatorLengthLesser,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthLesser_InvalidLeft",
|
||||
operator: OperatorLengthLesser,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthLesser_InvalidRight",
|
||||
operator: OperatorLengthLesser,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorLengthLesserOrEqual
|
||||
{
|
||||
name: "OperatorLengthLesserOrEqual_Slice",
|
||||
operator: OperatorLengthLesserOrEqual,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthLesserOrEqual_InvalidLeft",
|
||||
operator: OperatorLengthLesserOrEqual,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorLengthLesserOrEqual_InvalidRight",
|
||||
operator: OperatorLengthLesserOrEqual,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(float64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorContain
|
||||
{
|
||||
name: "OperatorContain_String",
|
||||
operator: OperatorContain,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorContain_Slice",
|
||||
operator: OperatorContain,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorContain_InvalidLeft",
|
||||
operator: OperatorContain,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(0),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorContain_InvalidRight",
|
||||
operator: OperatorContain,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorNotContain
|
||||
{
|
||||
name: "OperatorNotContain_String",
|
||||
operator: OperatorNotContain,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotContain_Slice",
|
||||
operator: OperatorNotContain,
|
||||
leftType: reflect.TypeOf([]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotContain_InvalidLeft",
|
||||
operator: OperatorNotContain,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotContain_InvalidRight",
|
||||
operator: OperatorNotContain,
|
||||
leftType: reflect.TypeOf(""),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorContainKey
|
||||
{
|
||||
name: "OperatorContainKey_Map",
|
||||
operator: OperatorContainKey,
|
||||
leftType: reflect.TypeOf(map[string]any{}),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorContainKey_InvalidLeft",
|
||||
operator: OperatorContainKey,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorContainKey_InvalidRight",
|
||||
operator: OperatorContainKey,
|
||||
leftType: reflect.TypeOf(map[string]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
// OperatorNotContainKey
|
||||
{
|
||||
name: "OperatorNotContainKey_Map",
|
||||
operator: OperatorNotContainKey,
|
||||
leftType: reflect.TypeOf(map[string]any{}),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotContainKey_InvalidLeft",
|
||||
operator: OperatorNotContainKey,
|
||||
leftType: reflect.TypeOf(int64(0)),
|
||||
rightType: reflect.TypeOf(""),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OperatorNotContainKey_InvalidRight",
|
||||
operator: OperatorNotContainKey,
|
||||
leftType: reflect.TypeOf(map[string]any{}),
|
||||
rightType: reflect.TypeOf(int64(0)),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.operator.WillAccept(tc.leftType, tc.rightType)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Operator.WillAccept() error = %v, wantErr %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
54
backend/domain/workflow/internal/nodes/selector/schema.go
Normal file
54
backend/domain/workflow/internal/nodes/selector/schema.go
Normal file
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type ClauseRelation string
|
||||
|
||||
const (
|
||||
ClauseRelationAND ClauseRelation = "and"
|
||||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Clauses []*OneClauseSchema `json:"clauses"`
|
||||
}
|
||||
|
||||
type OneClauseSchema struct {
|
||||
Single *Operator `json:"single,omitempty"`
|
||||
Multi *MultiClauseSchema `json:"multi,omitempty"`
|
||||
}
|
||||
|
||||
type MultiClauseSchema struct {
|
||||
Clauses []*Operator `json:"clauses"`
|
||||
Relation ClauseRelation `json:"relation"`
|
||||
}
|
||||
|
||||
func (c ClauseRelation) ToVOLogicType() vo.LogicType {
|
||||
if c == ClauseRelationAND {
|
||||
return vo.AND
|
||||
} else if c == ClauseRelationOR {
|
||||
return vo.OR
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unknown clause relation: %s", c))
|
||||
}
|
||||
209
backend/domain/workflow/internal/nodes/selector/selector.go
Normal file
209
backend/domain/workflow/internal/nodes/selector/selector.go
Normal file
@@ -0,0 +1,209 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type Selector struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
if len(config.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("config clauses are empty")
|
||||
}
|
||||
|
||||
for _, clause := range config.Clauses {
|
||||
if clause.Single == nil && clause.Multi == nil {
|
||||
return nil, fmt.Errorf("single clause and multi clause are both nil")
|
||||
}
|
||||
|
||||
if clause.Single != nil && clause.Multi != nil {
|
||||
return nil, fmt.Errorf("multi clause and single clause are both non-nil")
|
||||
}
|
||||
|
||||
if clause.Multi != nil {
|
||||
if len(clause.Multi.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("multi clause's single clauses are empty")
|
||||
}
|
||||
|
||||
if clause.Multi.Relation != ClauseRelationAND && clause.Multi.Relation != ClauseRelationOR {
|
||||
return nil, fmt.Errorf("multi clause and clauses are both non-AND-OR: %v", clause.Multi.Relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Selector{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Operants struct {
|
||||
Left any
|
||||
Right any
|
||||
Multi []*Operants
|
||||
}
|
||||
|
||||
const (
|
||||
LeftKey = "left"
|
||||
RightKey = "right"
|
||||
SelectKey = "selected"
|
||||
)
|
||||
|
||||
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.SelectorInputConverter(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
predicates := make([]Predicate, 0, len(s.config.Clauses))
|
||||
for i, oneConf := range s.config.Clauses {
|
||||
if oneConf.Single != nil {
|
||||
left := in[i].Left
|
||||
right := in[i].Right
|
||||
if right != nil {
|
||||
predicates = append(predicates, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *oneConf.Single,
|
||||
RightOperant: right,
|
||||
})
|
||||
} else {
|
||||
predicates = append(predicates, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *oneConf.Single,
|
||||
})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := &MultiClause{
|
||||
Relation: oneConf.Multi.Relation,
|
||||
}
|
||||
for j, singleConf := range oneConf.Multi.Clauses {
|
||||
left := in[i].Multi[j].Left
|
||||
right := in[i].Multi[j].Right
|
||||
if right != nil {
|
||||
multiClause.Clauses = append(multiClause.Clauses, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *singleConf,
|
||||
RightOperant: right,
|
||||
})
|
||||
} else {
|
||||
multiClause.Clauses = append(multiClause.Clauses, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *singleConf,
|
||||
})
|
||||
}
|
||||
}
|
||||
predicates = append(predicates, multiClause)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
for i, p := range predicates {
|
||||
isTrue, err := p.Resolve()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isTrue {
|
||||
return map[string]any{SelectKey: i}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{SelectKey: len(in)}, nil // default choice
|
||||
}
|
||||
|
||||
func (s *Selector) GetType() string {
|
||||
return "Selector"
|
||||
}
|
||||
|
||||
func (s *Selector) ConditionCount() int {
|
||||
return len(s.config.Clauses)
|
||||
}
|
||||
|
||||
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.config.Clauses
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), RightKey})
|
||||
if ok {
|
||||
out = append(out, Operants{Left: left, Right: right})
|
||||
} else {
|
||||
out = append(out, Operants{Left: left})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := make([]*Operants, 0)
|
||||
for j := range oneConf.Multi.Clauses {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
|
||||
}
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), RightKey})
|
||||
if ok {
|
||||
multiClause = append(multiClause, &Operants{Left: left, Right: right})
|
||||
} else {
|
||||
multiClause = append(multiClause, &Operants{Left: left})
|
||||
}
|
||||
}
|
||||
out = append(out, Operants{Multi: multiClause})
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
count := len(s.config.Clauses)
|
||||
out := output[SelectKey].(int)
|
||||
if out == count {
|
||||
cOutput := map[string]any{"result": "pass to else branch"}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: cOutput,
|
||||
RawOutput: cOutput,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if out >= 0 && out < count {
|
||||
cOutput := map[string]any{"result": fmt.Sprintf("pass to condition %d branch", out+1)}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: cOutput,
|
||||
RawOutput: cOutput,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("out of range: %d", out)
|
||||
}
|
||||
Reference in New Issue
Block a user