feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
447
backend/domain/workflow/internal/nodes/batch/batch.go
Normal file
447
backend/domain/workflow/internal/nodes/batch/batch.go
Normal file
@@ -0,0 +1,447 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package batch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
// Batch runs an inner workflow once per element of one or more input arrays,
// fanning the iterations out with bounded concurrency and gathering each
// iteration's result into per-output arrays.
type Batch struct {
	config *Config
	// outputs maps an output variable name (the single path segment of each
	// configured output, see NewBatch) to the field source it is copied from.
	outputs map[string]*vo.FieldSource
}
|
||||
|
||||
// Config is the static configuration of a Batch node.
type Config struct {
	// BatchNodeKey identifies this batch node inside the workflow.
	BatchNodeKey vo.NodeKey `json:"batch_node_key"`
	// InnerWorkflow is the runnable invoked once for every batch element.
	InnerWorkflow compose.Runnable[map[string]any, map[string]any]

	// InputArrays lists input keys whose values must be slices; the batch
	// iterates up to the length of the shortest one.
	InputArrays []string `json:"input_arrays"`
	// Outputs declares the batch's output variables and where each value is
	// sourced from within the inner workflow's output.
	Outputs []*vo.FieldInfo `json:"outputs"`
}
|
||||
|
||||
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if len(config.InputArrays) == 0 {
|
||||
return nil, errors.New("need to have at least one incoming array for batch")
|
||||
}
|
||||
|
||||
if len(config.Outputs) == 0 {
|
||||
return nil, errors.New("need to have at least one output variable for batch")
|
||||
}
|
||||
|
||||
b := &Batch{
|
||||
config: config,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
}
|
||||
|
||||
for i := range config.Outputs {
|
||||
source := config.Outputs[i]
|
||||
path := source.Path
|
||||
if len(path) != 1 {
|
||||
return nil, fmt.Errorf("invalid path %q", path)
|
||||
}
|
||||
|
||||
b.outputs[path[0]] = &source.Source
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Input-map keys carrying the batch's runtime tuning values; Execute reads
// both as int64.
const (
	// MaxBatchSizeKey caps the number of iterations (0 falls back to 100).
	MaxBatchSizeKey = "batchSize"
	// ConcurrentSizeKey caps concurrent inner-workflow runs (0 falls back to 10).
	ConcurrentSizeKey = "concurrentSize"
)
|
||||
|
||||
func (b *Batch) initOutput(length int) map[string]any {
|
||||
out := make(map[string]any, len(b.outputs))
|
||||
for key := range b.outputs {
|
||||
sliceType := reflect.TypeOf([]any{})
|
||||
slice := reflect.New(sliceType).Elem()
|
||||
slice.Set(reflect.MakeSlice(sliceType, length, length))
|
||||
out[key] = slice.Interface()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Execute runs the inner workflow once per batch element.
//
// The iteration count is the length of the shortest configured input array,
// capped by the "batchSize" input (0 means 100), and iterations run with at
// most "concurrentSize" concurrent workers (0 means 10). Each iteration's
// input carries over non-control keys from `in`, plus "<nodeKey>#index" and
// one "<nodeKey>#<arrayKey>" entry per input array; map elements are expanded
// recursively into "#"-joined keys. Results are written into per-output
// []any slices at the iteration's index.
//
// Execute also participates in the checkpoint/interrupt protocol: indexes
// already done in a previously saved state are skipped; indexes previously
// interrupted are re-run only when a resume is requested for them; any new
// interrupt persists state and returns compose.InterruptAndRerun.
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
	out map[string]any, err error) {
	arrays := make(map[string]any, len(b.config.InputArrays))
	// NOTE(review): math.MaxInt64 overflows int on 32-bit platforms;
	// math.MaxInt would be safer — confirm target platforms.
	minLen := math.MaxInt64
	for _, arrayKey := range b.config.InputArrays {
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
		if !ok {
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
		}

		if reflect.TypeOf(a).Kind() != reflect.Slice {
			return nil, fmt.Errorf("incoming array not a slice: %s. Actual type: %v",
				arrayKey, reflect.TypeOf(a))
		}

		arrays[arrayKey] = a

		// Track the shortest array: it bounds the iteration count.
		oneLen := reflect.ValueOf(a).Len()
		if oneLen < minLen {
			minLen = oneLen
		}
	}

	var maxIter, concurrency int64

	maxIterAny, ok := nodes.TakeMapValue(in, compose.FieldPath{MaxBatchSizeKey})
	if !ok {
		return nil, fmt.Errorf("incoming max iteration not present in input: %s", in)
	}

	// NOTE(review): unchecked type assertion panics if the value is not
	// int64 — presumably guaranteed by upstream conversion; confirm.
	maxIter = maxIterAny.(int64)
	if maxIter == 0 {
		maxIter = 100
	}

	concurrencyAny, ok := nodes.TakeMapValue(in, compose.FieldPath{ConcurrentSizeKey})
	if !ok {
		return nil, fmt.Errorf("incoming concurrency not present in input: %s", in)
	}

	concurrency = concurrencyAny.(int64)
	if concurrency == 0 {
		concurrency = 10
	}

	if minLen > int(maxIter) {
		minLen = int(maxIter)
	}

	output := b.initOutput(minLen)
	if minLen == 0 {
		return output, nil
	}

	// getIthInput builds the inner workflow's input for iteration i, and the
	// `items` map (raw per-iteration array elements) used for batch context.
	getIthInput := func(i int) (map[string]any, map[string]any, error) {
		input := make(map[string]any)

		for k, v := range in { // carry over other values
			if k != MaxBatchSizeKey && k != ConcurrentSizeKey {
				input[k] = v
			}
		}

		input[string(b.config.BatchNodeKey)+"#index"] = int64(i)

		items := make(map[string]any)
		for arrayKey, array := range arrays {
			ele := reflect.ValueOf(array).Index(i).Interface()
			items[arrayKey] = []any{ele}
			currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey

			// Recursively expand map[string]any elements
			if m, ok := ele.(map[string]any); ok {
				var expand func(prefix string, val interface{})
				expand = func(prefix string, val interface{}) {
					if nestedMap, ok := val.(map[string]any); ok {
						for k, v := range nestedMap {
							expand(prefix+"#"+k, v)
						}
					} else {
						input[prefix] = val
					}
				}
				expand(currentKey, m)
			} else {
				input[currentKey] = ele
			}
		}

		return input, items, nil
	}

	// setIthOutput copies iteration i's values from the inner workflow's
	// output into the pre-allocated output slices, per configured source ref.
	setIthOutput := func(i int, taskOutput map[string]any) error {
		for k, source := range b.outputs {
			fromValue, _ := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)},
				source.Ref.FromPath...))

			toArray, ok := nodes.TakeMapValue(output, compose.FieldPath{k})
			if !ok {
				return fmt.Errorf("key not present in outer workflow's output: %s", k)
			}

			toArray.([]any)[i] = fromValue
		}

		return nil
	}

	options := &nodes.NestedWorkflowOptions{}
	for _, opt := range opts {
		opt(options)
	}

	// Load any previously checkpointed state for this node (nil on first run).
	var existingCState *nodes.NestedWorkflowState
	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
		var e error
		existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
		if e != nil {
			return e
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	if existingCState != nil {
		output = existingCState.FullOutput
	}

	ctx, cancelFn := context.WithCancelCause(ctx)
	var (
		wg sync.WaitGroup
		mu sync.Mutex
		// Per-run bookkeeping, guarded by mu when accessed from workers.
		index2Done          = map[int]bool{}
		index2InterruptInfo = map[int]*compose.InterruptInfo{}
		resumed             = map[int]bool{}
	)

	// ithTask runs iteration i. It must call wg.Done exactly once, and
	// signals fatal errors via cancelFn instead of returning them.
	ithTask := func(i int) {
		defer wg.Done()

		if existingCState != nil {
			if existingCState.Index2Done[i] == true {
				return
			}

			if existingCState.Index2InterruptInfo[i] != nil {
				if len(options.GetResumeIndexes()) > 0 {
					if _, ok := options.GetResumeIndexes()[i]; !ok {
						// previously interrupted, but not resumed this time, skip
						return
					}
				}
			}

			mu.Lock()
			resumed[i] = true
			mu.Unlock()
		}

		select {
		case <-ctx.Done():
			return // canceled by normal error, abort
		default:
		}

		mu.Lock()
		if len(index2InterruptInfo) > 0 { // already has interrupted index, abort
			mu.Unlock()
			return
		}
		mu.Unlock()

		input, items, err := getIthInput(i)
		if err != nil {
			cancelFn(err)
			return
		}

		subCtx, subCheckpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)

		ithOpts := slices.Clone(options.GetOptsForNested())
		mu.Lock()
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
		mu.Unlock()
		if subCheckpointID != "" {
			logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
				i, b.config.BatchNodeKey, subCheckpointID)
			ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
		}

		mu.Lock()
		if len(options.GetResumeIndexes()) > 0 {
			stateModifier, ok := options.GetResumeIndexes()[i]
			mu.Unlock()
			if ok {
				// NOTE(review): debug print; consider logs.CtxInfof instead.
				fmt.Println("has state modifier for ith run: ", i, ", checkpointID: ", subCheckpointID)
				ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
			}
		} else {
			mu.Unlock()
		}

		// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
		// the output then needs to be concatenated.
		taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
		if err != nil {
			info, ok := compose.ExtractInterruptInfo(err)
			if !ok {
				cancelFn(err)
				return
			}

			mu.Lock()
			index2InterruptInfo[i] = info
			mu.Unlock()
			return
		}

		if err = setIthOutput(i, taskOutput); err != nil {
			cancelFn(err)
			return
		}

		mu.Lock()
		index2Done[i] = true
		mu.Unlock()
	}

	wg.Add(minLen)
	if minLen < int(concurrency) {
		// Fewer tasks than workers: run one goroutine per task, keeping
		// task 0 on the current goroutine.
		for i := 1; i < minLen; i++ {
			go ithTask(i)
		}
		ithTask(0)
	} else {
		// Worker pool of `concurrency` goroutines fed through a channel.
		taskChan := make(chan int, concurrency)
		for i := 0; i < int(concurrency); i++ {
			safego.Go(ctx, func() {
				for i := range taskChan {
					ithTask(i)
				}
			})
		}
		for i := 0; i < minLen; i++ {
			taskChan <- i
		}
		close(taskChan)
	}

	wg.Wait()

	if context.Cause(ctx) != nil {
		if errors.Is(context.Cause(ctx), context.Canceled) {
			return nil, context.Canceled // canceled by Eino workflow engine
		}
		return nil, context.Cause(ctx) // normal error, just throw it out
	}

	// delete the interruptions that have been resumed
	// (resumed is only populated when existingCState != nil, so this
	// dereference is safe)
	for index := range resumed {
		delete(existingCState.Index2InterruptInfo, index)
	}

	// Merge this run's bookkeeping into the checkpointed state (or create it).
	compState := existingCState
	if compState == nil {
		compState = &nodes.NestedWorkflowState{
			Index2Done:          index2Done,
			Index2InterruptInfo: index2InterruptInfo,
			FullOutput:          output,
		}
	} else {
		for i := range index2Done {
			compState.Index2Done[i] = index2Done[i]
		}
		for i := range index2InterruptInfo {
			compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
		}
		compState.FullOutput = output
	}

	if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
		iEvent := &entity.InterruptEvent{
			NodeKey:             b.config.BatchNodeKey,
			NodeType:            entity.NodeTypeBatch,
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
		}

		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
				return e
			}

			return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
		})
		if err != nil {
			return nil, err
		}

		// NOTE(review): debug prints; consider logs.CtxInfof instead.
		fmt.Println("save interruptEvent in state within batch: ", iEvent)
		fmt.Println("save composite info in state within batch: ", compState)

		return nil, compose.InterruptAndRerun
	} else {
		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
				return e
			}

			if existingCState == nil {
				return nil
			}

			// although this invocation does not have new interruptions,
			// this batch node previously have interrupts yet to be resumed.
			// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
			return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
				NodeKey:             b.config.BatchNodeKey,
				NodeType:            entity.NodeTypeBatch,
				NestedInterruptInfo: existingCState.Index2InterruptInfo,
			})
		})
		if err != nil {
			return nil, err
		}

		// NOTE(review): debug print; consider logs.CtxInfof instead.
		fmt.Println("save composite info in state within batch: ", compState)
	}

	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
		logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
			"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
		return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
	}

	return output, nil
}
|
||||
|
||||
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(b.config.InputArrays))
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
}
|
||||
return trimmed, nil
|
||||
}
|
||||
26
backend/domain/workflow/internal/nodes/callbacks.go
Normal file
26
backend/domain/workflow/internal/nodes/callbacks.go
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
|
||||
// StructuredCallbackOutput bundles a node's converted output with its raw
// (pre-conversion) output, optional extra metadata, and any warning/error to
// surface through callbacks.
type StructuredCallbackOutput struct {
	Output    map[string]any // output after type conversion
	RawOutput map[string]any // output as originally produced by the node
	Extra     map[string]any // node specific extra info, will go into node execution's extra.ResponseExtra
	Error     vo.WorkflowError
}
|
||||
272
backend/domain/workflow/internal/nodes/code/code.go
Normal file
272
backend/domain/workflow/internal/nodes/code/code.go
Normal file
@@ -0,0 +1,272 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package code
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// Keys used to stash per-run data in the request-scoped ctxcache, so that
// ToCallbackOutput can retrieve what RunCode produced.
const (
	// coderRunnerRawOutputCtxKey stores the runner's raw result map.
	coderRunnerRawOutputCtxKey = "ctx_raw_output"
	// coderRunnerWarnErrorLevelCtxKey stores output-conversion warnings.
	coderRunnerWarnErrorLevelCtxKey = "ctx_warn_error_level"
)
|
||||
|
||||
// Line-oriented regexes for recognizing Python import statements:
// importRegex matches `import a, b.c ...`; fromImportRegex matches
// `from a.b import ...` and captures the module path.
var (
	importRegex     = regexp.MustCompile(`^\s*import\s+([a-zA-Z0-9_.,\s]+)`)
	fromImportRegex = regexp.MustCompile(`^\s*from\s+([a-zA-Z0-9_.]+)\s+import`)
)
|
||||
|
||||
// pythonBuiltinModules is the list of python built-in modules,
// see: https://docs.python.org/3.9/library/
// Used as a set: membership means "part of the Python 3.9 standard library".
var pythonBuiltinModules = map[string]struct{}{
	"abc": {}, "aifc": {}, "antigravity": {}, "argparse": {}, "ast": {}, "asynchat": {}, "asyncio": {}, "asyncore": {}, "array": {},
	"atexit": {}, "base64": {}, "bdb": {}, "binhex": {}, "bisect": {}, "builtins": {}, "bz2": {}, "cProfile": {}, "binascii": {},
	"calendar": {}, "cgi": {}, "cgitb": {}, "chunk": {}, "cmd": {}, "code": {}, "codecs": {}, "codeop": {}, "cmath": {}, "audioop": {},
	"collections": {}, "colorsys": {}, "compileall": {}, "concurrent": {}, "configparser": {}, "contextlib": {}, "contextvars": {}, "copy": {},
	"copyreg": {}, "crypt": {}, "csv": {}, "ctypes": {}, "curses": {}, "dataclasses": {}, "datetime": {}, "dbm": {}, "fcntl": {},
	"decimal": {}, "difflib": {}, "dis": {}, "distutils": {}, "doctest": {}, "email": {}, "encodings": {}, "ensurepip": {}, "ossaudiodev": {},
	"enum": {}, "errno": {}, "faulthandler": {}, "filecmp": {}, "fileinput": {}, "fnmatch": {}, "formatter": {}, "fractions": {},
	"ftplib": {}, "functools": {}, "gc": {}, "genericpath": {}, "getopt": {}, "getpass": {}, "gettext": {}, "glob": {}, "grp": {},
	"graphlib": {}, "gzip": {}, "hashlib": {}, "heapq": {}, "hmac": {}, "html": {}, "http": {}, "imaplib": {}, "msvcrt": {},
	"imghdr": {}, "imp": {}, "importlib": {}, "inspect": {}, "io": {}, "ipaddress": {}, "itertools": {}, "json": {}, "mmap": {},
	"keyword": {}, "lib2to3": {}, "linecache": {}, "locale": {}, "logging": {}, "lzma": {}, "mailbox": {}, "mailcap": {}, "msilib": {},
	"marshal": {}, "math": {}, "mimetypes": {}, "modulefinder": {}, "multiprocessing": {}, "netrc": {}, "nntplib": {}, "ntpath": {},
	"nturl2path": {}, "numbers": {}, "opcode": {}, "operator": {}, "optparse": {}, "os": {}, "pathlib": {}, "pdb": {}, "readline": {},
	"pickle": {}, "pickletools": {}, "pipes": {}, "pkgutil": {}, "platform": {}, "plistlib": {}, "poplib": {}, "posix": {}, "parser": {},
	"posixpath": {}, "pprint": {}, "profile": {}, "pstats": {}, "pty": {}, "pwd": {}, "py_compile": {}, "pyclbr": {}, "spwd": {},
	"pydoc": {}, "pydoc_data": {}, "queue": {}, "quopri": {}, "random": {}, "re": {}, "reprlib": {}, "rlcompleter": {}, "resource": {},
	"runpy": {}, "sched": {}, "secrets": {}, "selectors": {}, "shelve": {}, "shlex": {}, "shutil": {}, "signal": {}, "select": {},
	"site": {}, "smtpd": {}, "smtplib": {}, "sndhdr": {}, "socket": {}, "socketserver": {}, "sqlite3": {}, "sre_compile": {},
	"sre_constants": {}, "sre_parse": {}, "ssl": {}, "stat": {}, "statistics": {}, "string": {}, "stringprep": {}, "struct": {},
	"subprocess": {}, "sunau": {}, "symbol": {}, "symtable": {}, "sys": {}, "sysconfig": {}, "tabnanny": {}, "tarfile": {}, "nis": {},
	"telnetlib": {}, "tempfile": {}, "textwrap": {}, "this": {}, "threading": {}, "time": {}, "timeit": {}, "tkinter": {}, "test": {},
	"token": {}, "tokenize": {}, "trace": {}, "traceback": {}, "tracemalloc": {}, "tty": {}, "turtle": {}, "turtledemo": {},
	"types": {}, "typing": {}, "unittest": {}, "urllib": {}, "uu": {}, "uuid": {}, "venv": {}, "warnings": {}, "termios": {},
	"wave": {}, "weakref": {}, "webbrowser": {}, "wsgiref": {}, "xdrlib": {}, "xml": {}, "xmlrpc": {}, "xxsubtype": {}, "zlib": {},
	"zipapp": {}, "zipfile": {}, "zipimport": {}, "zoneinfo": {}, "winreg": {}, "syslog": {}, "winsound": {}, "unicodedata": {},
}
|
||||
|
||||
// pythonBuiltinBlacklist is the blacklist of python built-in modules,
// see: https://www.coze.cn/open/docs/guides/code_node#7f41f073
// Importing any of these standard-library modules is rejected by
// validatePythonImports.
var pythonBuiltinBlacklist = map[string]struct{}{
	"curses":          {},
	"dbm":             {},
	"ensurepip":       {},
	"fcntl":           {},
	"grp":             {},
	"idlelib":         {},
	"lib2to3":         {},
	"msvcrt":          {},
	"pwd":             {},
	"resource":        {},
	"syslog":          {},
	"termios":         {},
	"tkinter":         {},
	"turtle":          {},
	"turtledemo":      {},
	"venv":            {},
	"winreg":          {},
	"winsound":        {},
	"multiprocessing": {},
	"threading":       {},
	"socket":          {},
	"pty":             {},
	"tty":             {},
}
|
||||
|
||||
// pythonThirdPartyWhitelist is the whitelist of python third-party modules,
// see: https://www.coze.cn/open/docs/guides/code_node#7f41f073
// If you want to use other third-party libraries, you can add them to this whitelist.
// And you also need to install them in `/scripts/setup/python.sh` and `/backend/Dockerfile` via `pip install`.
var pythonThirdPartyWhitelist = map[string]struct{}{
	"requests_async": {},
	"numpy":          {},
}
|
||||
|
||||
// Config describes a code node: the source to run, its language, the expected
// output schema, and the runner backend that executes it.
type Config struct {
	Code         string                  // source code of the user script
	Language     code.Language           // only code.Python is accepted (see NewCodeRunner)
	OutputConfig map[string]*vo.TypeInfo // expected output fields and their types
	Runner       code.Runner             // backend that actually executes the code
}
|
||||
|
||||
// CodeRunner wraps a validated code-node configuration. Import validation is
// performed once at construction; a failure is kept in importError and
// surfaced on every RunCode call instead of at build time.
type CodeRunner struct {
	config *Config
	// importError is non-nil when the code imports blacklisted or
	// non-whitelisted modules (see validatePythonImports).
	importError error
}
|
||||
|
||||
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
}
|
||||
|
||||
if cfg.Language == "" {
|
||||
return nil, errors.New("language is required")
|
||||
}
|
||||
|
||||
if cfg.Code == "" {
|
||||
return nil, errors.New("code is required")
|
||||
}
|
||||
|
||||
if cfg.Language != code.Python {
|
||||
return nil, errors.New("only support python language")
|
||||
}
|
||||
|
||||
if len(cfg.OutputConfig) == 0 {
|
||||
return nil, errors.New("output config is required")
|
||||
}
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("run coder is required")
|
||||
}
|
||||
|
||||
importErr := validatePythonImports(cfg.Code)
|
||||
|
||||
return &CodeRunner{
|
||||
config: cfg,
|
||||
importError: importErr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validatePythonImports(code string) error {
|
||||
imports := parsePythonImports(code)
|
||||
importErrors := make([]string, 0)
|
||||
|
||||
var blacklistedModules []string
|
||||
var nonWhitelistedModules []string
|
||||
for _, imp := range imports {
|
||||
if _, ok := pythonBuiltinModules[imp]; ok {
|
||||
if _, blacklisted := pythonBuiltinBlacklist[imp]; blacklisted {
|
||||
blacklistedModules = append(blacklistedModules, imp)
|
||||
}
|
||||
} else {
|
||||
if _, whitelisted := pythonThirdPartyWhitelist[imp]; !whitelisted {
|
||||
nonWhitelistedModules = append(nonWhitelistedModules, imp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(blacklistedModules) > 0 {
|
||||
moduleNames := fmt.Sprintf("'%s'", strings.Join(blacklistedModules, "', '"))
|
||||
importErrors = append(importErrors, fmt.Sprintf("ModuleNotFoundError: The module(s) %s are removed from the Python standard library for security reasons\n", moduleNames))
|
||||
}
|
||||
if len(nonWhitelistedModules) > 0 {
|
||||
moduleNames := fmt.Sprintf("'%s'", strings.Join(nonWhitelistedModules, "', '"))
|
||||
importErrors = append(importErrors, fmt.Sprintf("ModuleNotFoundError: No module named %s\n", moduleNames))
|
||||
}
|
||||
|
||||
if len(importErrors) > 0 {
|
||||
return errors.New(strings.Join(importErrors, ","))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
|
||||
if c.importError != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
|
||||
}
|
||||
response, err := c.config.Runner.Run(ctx, &code.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
|
||||
result := response.Result
|
||||
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
|
||||
|
||||
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
|
||||
if ws != nil && len(*ws) > 0 {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
ctxcache.Store(ctx, coderRunnerWarnErrorLevelCtxKey, *ws)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
|
||||
if !ok {
|
||||
return nil, errors.New("raw output config is required")
|
||||
}
|
||||
|
||||
var wfe vo.WorkflowError
|
||||
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, coderRunnerWarnErrorLevelCtxKey); ok {
|
||||
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
|
||||
}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: rawOutput,
|
||||
Error: wfe,
|
||||
},
|
||||
nil
|
||||
|
||||
}
|
||||
|
||||
func parsePythonImports(code string) []string {
|
||||
modules := make(map[string]struct{})
|
||||
lines := strings.Split(code, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
if matches := importRegex.FindStringSubmatch(line); len(matches) > 1 {
|
||||
importedItemsStr := matches[1]
|
||||
importedItems := strings.Split(importedItemsStr, ",")
|
||||
for _, item := range importedItems {
|
||||
item = strings.TrimSpace(item)
|
||||
parts := strings.Split(item, " ")
|
||||
if len(parts) > 0 {
|
||||
moduleName := parts[0]
|
||||
topLevelModule := strings.Split(moduleName, ".")[0]
|
||||
modules[topLevelModule] = struct{}{}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if matches := fromImportRegex.FindStringSubmatch(line); len(matches) > 1 {
|
||||
fullModuleName := matches[1]
|
||||
parts := strings.Split(fullModuleName, ".")
|
||||
if len(parts) > 0 {
|
||||
modules[parts[0]] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Keys(modules)
|
||||
}
|
||||
262
backend/domain/workflow/internal/nodes/code/code_test.go
Normal file
262
backend/domain/workflow/internal/nodes/code/code_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package code
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
)
|
||||
|
||||
// codeTpl holds the Python source under test; subtests assign it before use —
// presumably file-scoped for reuse across subtests, TODO confirm intent.
var codeTpl string
|
||||
func TestCode_RunCode(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockRunner := mockcode.NewMockRunner(ctrl)
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
var codeTpl = `
|
||||
async def main(args:Args)->Output:
|
||||
params = args.params
|
||||
ret: Output = {
|
||||
"key0": params['input'] + params['input'],
|
||||
"key1": ["hello", "world"],
|
||||
"key2": [123, "345"],
|
||||
"key3": {
|
||||
"key31": "hi",
|
||||
"key32": "hello",
|
||||
"key33": ["123","456"],
|
||||
"key34": {
|
||||
"key341":"123",
|
||||
"key342":456,
|
||||
}
|
||||
},
|
||||
}
|
||||
return ret
|
||||
`
|
||||
ret := map[string]any{
|
||||
"key0": int64(11231123),
|
||||
"key1": []any{"hello", "world"},
|
||||
"key2": []interface{}{int64(123), "345"},
|
||||
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key33": []any{"123", "456"}, "key34": map[string]interface{}{"key341": "123", "key342": int64(456)}},
|
||||
"key4": []any{
|
||||
map[string]any{"key41": "41"},
|
||||
map[string]any{"key42": "42"},
|
||||
},
|
||||
}
|
||||
|
||||
response := &code.RunResponse{
|
||||
Result: ret,
|
||||
}
|
||||
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Language: code.Python,
|
||||
Code: codeTpl,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
}},
|
||||
},
|
||||
},
|
||||
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
bs, _ := json.Marshal(ret)
|
||||
fmt.Println(string(bs))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(11231123), ret["key0"])
|
||||
assert.Equal(t, []any{"hello", "world"}, ret["key1"])
|
||||
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
|
||||
assert.Equal(t, []any{float64(123), float64(456)}, ret["key3"].(map[string]any)["key33"])
|
||||
assert.Equal(t, map[string]any{"key41": "41"}, ret["key4"].([]any)[0].(map[string]any))
|
||||
|
||||
})
|
||||
t.Run("field not in return", func(t *testing.T) {
|
||||
codeTpl = `
|
||||
async def main(args:Args)->Output:
|
||||
params = args.params
|
||||
ret: Output = {
|
||||
"key0": params['input'] + params['input'],
|
||||
"key1": ["hello", "world"],
|
||||
"key2": [123, "345"],
|
||||
"key3": {
|
||||
"key31": "hi",
|
||||
"key32": "hello",
|
||||
"key34": {
|
||||
"key341":"123"
|
||||
}
|
||||
},
|
||||
}
|
||||
return ret
|
||||
`
|
||||
|
||||
ret := map[string]any{
|
||||
"key0": int64(11231123),
|
||||
"key1": []any{"hello", "world"},
|
||||
"key2": []interface{}{int64(123), "345"},
|
||||
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key34": map[string]interface{}{"key341": "123"}},
|
||||
}
|
||||
|
||||
response := &code.RunResponse{
|
||||
Result: ret,
|
||||
}
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: code.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
}},
|
||||
}},
|
||||
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(11231123), ret["key0"])
|
||||
assert.Equal(t, []any{"hello", "world"}, ret["key1"])
|
||||
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
|
||||
assert.Equal(t, nil, ret["key4"])
|
||||
assert.Equal(t, nil, ret["key3"].(map[string]any)["key33"])
|
||||
})
|
||||
t.Run("field convert failed", func(t *testing.T) {
|
||||
codeTpl = `
|
||||
async def main(args:Args)->Output:
|
||||
params = args.params
|
||||
ret: Output = {
|
||||
"key0": params['input'] + params['input'],
|
||||
"key1": ["hello", "world"],
|
||||
"key2": [123, "345"],
|
||||
"key3": {
|
||||
"key31": "hi",
|
||||
"key32": "hello",
|
||||
"key34": {
|
||||
"key341":"123",
|
||||
"key343": ["hello", "world"],
|
||||
}
|
||||
},
|
||||
}
|
||||
return ret
|
||||
`
|
||||
ctx := t.Context()
|
||||
ctx = ctxcache.Init(ctx)
|
||||
ret := map[string]any{
|
||||
"key0": int64(11231123),
|
||||
"key1": []any{"hello", "world"},
|
||||
"key2": []interface{}{int64(123), "345"},
|
||||
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key34": map[string]interface{}{"key341": "123", "key343": []any{"hello", "world"}}},
|
||||
}
|
||||
response := &code.RunResponse{
|
||||
Result: ret,
|
||||
}
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: code.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(11231123), ret["key0"])
|
||||
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
|
||||
|
||||
warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, coderRunnerWarnErrorLevelCtxKey)
|
||||
assert.True(t, ok)
|
||||
s := warnings.Error()
|
||||
assert.Contains(t, s, "field key3.key34.key343.0 is not number")
|
||||
assert.Contains(t, s, "field key3.key34.key343.1 is not number")
|
||||
assert.Contains(t, s, "field key1.0 is not number")
|
||||
assert.Contains(t, s, "field key1.1 is not number")
|
||||
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// ClearMessageConfig carries the dependencies for the clear-message node.
type ClearMessageConfig struct {
	// Clearer performs the actual message clearing against the
	// conversation domain service.
	Clearer conversation.ConversationManager
}

// MessageClear is a workflow node that clears all messages of a
// conversation identified by name at execution time.
type MessageClear struct {
	config *ClearMessageConfig
}
|
||||
|
||||
func NewClearMessage(ctx context.Context, cfg *ClearMessageConfig) (*MessageClear, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.Clearer == nil {
|
||||
return nil, errors.New("clearer is required")
|
||||
}
|
||||
|
||||
return &MessageClear{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MessageClear) Clear(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("input map should contains 'ConversationName' key ")
|
||||
}
|
||||
response, err := c.config.Clearer.ClearMessage(ctx, &conversation.ClearMessageRequest{
|
||||
Name: name.(string),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"isSuccess": response.IsSuccess,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// CreateConversationConfig carries the dependencies for the
// create-conversation node.
type CreateConversationConfig struct {
	// Creator performs the actual conversation creation against the
	// conversation domain service.
	Creator conversation.ConversationManager
}

// CreateConversation is a workflow node that creates a conversation
// identified by name at execution time.
type CreateConversation struct {
	config *CreateConversationConfig
}
|
||||
|
||||
func NewCreateConversation(ctx context.Context, cfg *CreateConversationConfig) (*CreateConversation, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.Creator == nil {
|
||||
return nil, errors.New("creator is required")
|
||||
}
|
||||
return &CreateConversation{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *CreateConversation) Create(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("input map should contains 'ConversationName' key ")
|
||||
}
|
||||
response, err := c.config.Creator.CreateConversation(ctx, &conversation.CreateConversationRequest{
|
||||
Name: name.(string),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response.Result, nil
|
||||
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// MessageListConfig carries the dependencies for the list-messages node.
type MessageListConfig struct {
	// Lister performs the actual message listing against the
	// conversation domain service.
	Lister conversation.ConversationManager
}

// MessageList is a workflow node that lists messages of a conversation
// with optional cursor-based pagination.
type MessageList struct {
	config *MessageListConfig
}

// Param holds the parsed listing arguments extracted from the node's
// input map.
type Param struct {
	// ConversationName identifies the conversation to list; required.
	ConversationName string
	// Limit caps the number of returned messages; nil means the
	// service default.
	Limit *int
	// BeforeID and AfterID are optional pagination cursors.
	BeforeID *string
	AfterID  *string
}
|
||||
|
||||
func NewMessageList(ctx context.Context, cfg *MessageListConfig) (*MessageList, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if cfg.Lister == nil {
|
||||
return nil, errors.New("lister is required")
|
||||
}
|
||||
|
||||
return &MessageList{
|
||||
config: cfg,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (m *MessageList) List(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
param := &Param{}
|
||||
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
|
||||
if !ok {
|
||||
return nil, errors.New("ConversationName is required")
|
||||
}
|
||||
param.ConversationName = name.(string)
|
||||
limit, ok := nodes.TakeMapValue(input, compose.FieldPath{"Limit"})
|
||||
if ok {
|
||||
limit := limit.(int)
|
||||
param.Limit = &limit
|
||||
}
|
||||
beforeID, ok := nodes.TakeMapValue(input, compose.FieldPath{"BeforeID"})
|
||||
if ok {
|
||||
beforeID := beforeID.(string)
|
||||
param.BeforeID = &beforeID
|
||||
}
|
||||
afterID, ok := nodes.TakeMapValue(input, compose.FieldPath{"AfterID"})
|
||||
if ok {
|
||||
afterID := afterID.(string)
|
||||
param.BeforeID = &afterID
|
||||
}
|
||||
r, err := m.config.Lister.MessageList(ctx, &conversation.ListMessageRequest{
|
||||
ConversationName: param.ConversationName,
|
||||
Limit: param.Limit,
|
||||
BeforeID: param.BeforeID,
|
||||
AfterID: param.AfterID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
objects := make([]any, 0, len(r.Messages))
|
||||
for _, msg := range r.Messages {
|
||||
object := make(map[string]any)
|
||||
bs, _ := json.Marshal(msg)
|
||||
err := json.Unmarshal(bs, &object)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, object)
|
||||
}
|
||||
|
||||
result["messageList"] = objects
|
||||
result["firstId"] = r.FirstID
|
||||
result["hasMore"] = r.HasMore
|
||||
return result, nil
|
||||
|
||||
}
|
||||
405
backend/domain/workflow/internal/nodes/convert.go
Normal file
405
backend/domain/workflow/internal/nodes/convert.go
Normal file
@@ -0,0 +1,405 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// ConversionWarning records a single non-fatal type-conversion problem:
// the value at Path could not be converted to Type.
type ConversionWarning struct {
	// Path is the dotted location of the offending field,
	// e.g. "key3.key34.key343.0" for an array element inside nested objects.
	Path string
	// Type is the target type the value failed to convert to.
	Type vo.DataType
	// Err is the underlying cause; it is kept for debugging but is NOT
	// included in the Error() message.
	Err error
}

// Error renders the warning as "field <path> is not <type>".
func (e *ConversionWarning) Error() string {
	return fmt.Sprintf("field %s is not %s", e.Path, e.Type)
}

// ConversionWarnings aggregates the warnings collected while converting a
// whole input structure.
type ConversionWarnings []*ConversionWarning

// Merge appends e1's warnings to e and returns the combined slice.
// Note: it may mutate e's backing array, as append does.
func (e ConversionWarnings) Merge(e1 ConversionWarnings) ConversionWarnings {
	return append(e, e1...)
}

// Error joins all individual warning messages with ", ". It returns the
// empty string when there are no warnings.
func (e ConversionWarnings) Error() string {
	if len(e) == 0 {
		return ""
	}
	var errs []string
	for _, err := range e {
		errs = append(errs, err.Error())
	}
	return strings.Join(errs, ", ")
}

// newWarnings builds a single-element warning list for the given path,
// target type and cause. Returned as a pointer because the conversion
// functions use nil to signal "no warnings".
func newWarnings(path string, t vo.DataType, err error) *ConversionWarnings {
	return ptr.Of(ConversionWarnings{
		{
			Path: path,
			Type: t,
			Err:  err,
		},
	})
}
|
||||
|
||||
// ConvertInputs converts every entry of in according to the declared type
// information in tInfo. It returns the converted map, any non-fatal
// conversion warnings, and a fatal error (for example a missing required
// parameter or a fail-fast conversion failure).
//
// Behavior:
//   - Keys present in `in` but absent from tInfo are passed through
//     unchanged (or dropped when SkipUnknownFields is set).
//   - Unless SkipRequireCheck is set, any tInfo field marked Required that
//     is absent from the output yields ErrMissingRequiredParam.
//   - Per-field conversion problems are collected as warnings rather than
//     aborting, unless FailFast is set (enforced inside Convert).
func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo.TypeInfo, opts ...ConvertOption) (
	map[string]any, *ConversionWarnings, error) {
	options := &convertOptions{}
	for _, opt := range opts {
		opt(options)
	}

	// Empty input: nothing to convert, but required params must still be
	// checked before returning the input untouched.
	if len(in) == 0 {
		if !options.skipRequireCheck {
			for n, t := range tInfo {
				if t.Required {
					return nil, nil, vo.NewError(errno.ErrMissingRequiredParam, errorx.KV("param", n))
				}
			}
		}
		return in, nil, nil
	}

	out := make(map[string]any)
	var warnings ConversionWarnings
	for k, v := range in {
		t, ok := tInfo[k]
		if !ok {
			// for input fields not explicitly defined, just pass the string value through
			logs.CtxWarnf(ctx, "input %s not found in type info", k)
			if !options.skipUnknownFields {
				out[k] = in[k]
			}
			continue
		}

		converted, ws, err := Convert(ctx, v, k, t, opts...)
		if err != nil {
			return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
		}

		if ws != nil {
			warnings = append(warnings, *ws...)
		}
		out[k] = converted
	}

	// The required-field check runs against the converted output so that
	// fields dropped during conversion are caught as well.
	if !options.skipRequireCheck {
		for k, t := range tInfo {
			if _, ok := out[k]; !ok {
				if t.Required {
					return nil, nil, vo.NewError(errno.ErrMissingRequiredParam, errorx.KV("param", k))
				}
			}
		}
	}

	if len(warnings) > 0 {
		return out, &warnings, nil
	}

	return out, nil, nil
}
|
||||
|
||||
// convertOptions holds the behavior flags toggled by ConvertOption values.
type convertOptions struct {
	// skipUnknownFields drops fields absent from the type info instead of
	// passing them through unchanged.
	skipUnknownFields bool
	// failFast turns per-field conversion warnings into fatal errors.
	failFast bool
	// skipRequireCheck disables the Required-field validation pass.
	skipRequireCheck bool
}

// ConvertOption customizes the behavior of Convert / ConvertInputs.
type ConvertOption func(*convertOptions)

// SkipUnknownFields makes conversion drop input fields that have no entry
// in the type info, instead of passing them through.
func SkipUnknownFields() ConvertOption {
	return func(o *convertOptions) {
		o.skipUnknownFields = true
	}
}

// FailFast makes any per-field conversion problem abort with an error
// instead of being collected as a warning.
func FailFast() ConvertOption {
	return func(o *convertOptions) {
		o.failFast = true
	}
}

// SkipRequireCheck disables validation of Required fields during
// conversion.
func SkipRequireCheck() ConvertOption {
	return func(o *convertOptions) {
		o.skipRequireCheck = true
	}
}
|
||||
|
||||
func Convert(ctx context.Context, in any, path string, t *vo.TypeInfo, opts ...ConvertOption) (
|
||||
any, *ConversionWarnings, error) {
|
||||
options := &convertOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return convert(ctx, in, path, t, options)
|
||||
}
|
||||
|
||||
// convert dispatches a single value to the per-type conversion routine
// selected by t.Type. path is the dotted location of the value, used in
// warning messages. A nil input is accepted for every type and returned
// as-is.
func convert(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
	any, *ConversionWarnings, error) {
	if in == nil { // nil is valid for ALL types
		return nil, nil, nil
	}

	switch t.Type {
	case vo.DataTypeString, vo.DataTypeFile, vo.DataTypeTime:
		// Files and times travel as strings in workflow payloads.
		return convertToString(ctx, in, path, options)
	case vo.DataTypeInteger:
		return convertToInt64(ctx, in, path, options)
	case vo.DataTypeNumber:
		return convertToFloat64(ctx, in, path, options)
	case vo.DataTypeBoolean:
		return convertToBool(ctx, in, path, options)
	case vo.DataTypeObject:
		return convertToObject(ctx, in, path, t, options)
	case vo.DataTypeArray:
		return convertToArray(ctx, in, path, t, options)
	default:
		// Unknown declared type: fatal under failFast, otherwise pass the
		// value through with a warning.
		if options.failFast {
			return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unknown input type %s for path %s", t.Type, path))
		}
		logs.CtxErrorf(ctx, "unknown input type %s for path %s", t.Type, path)
		return in, newWarnings(path, t.Type, errors.New("unknown input type")), nil
	}
}
|
||||
|
||||
func convertToString(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
|
||||
switch in.(type) {
|
||||
case string:
|
||||
return in.(string), nil, nil
|
||||
case int64:
|
||||
return strconv.FormatInt(in.(int64), 10), nil, nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(in.(float64), 'f', -1, 64), nil, nil
|
||||
case bool:
|
||||
return strconv.FormatBool(in.(bool)), nil, nil
|
||||
case []any, map[string]any:
|
||||
s, err := sonic.MarshalString(in)
|
||||
if err != nil {
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeString, err), nil
|
||||
}
|
||||
return s, nil, nil
|
||||
default:
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to string: %T", in))
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeString, fmt.Errorf("unsupported type to convert to string: %T", in)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func convertToInt64(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
|
||||
switch in.(type) {
|
||||
case int64:
|
||||
return in.(int64), nil, nil
|
||||
case float64:
|
||||
return int64(in.(float64)), nil, nil
|
||||
case string:
|
||||
i, err := strconv.ParseInt(in.(string), 10, 64)
|
||||
if err != nil {
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeInteger, err), nil
|
||||
}
|
||||
return i, nil, nil
|
||||
default:
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to int64: %T", in))
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeInteger, fmt.Errorf("unsupported type to convert to int64: %T", in)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFloat64(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
|
||||
switch in.(type) {
|
||||
case int64:
|
||||
return float64(in.(int64)), nil, nil
|
||||
case float64:
|
||||
return in.(float64), nil, nil
|
||||
case string:
|
||||
f, err := strconv.ParseFloat(in.(string), 64)
|
||||
if err != nil {
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeNumber, err), nil
|
||||
}
|
||||
return f, nil, nil
|
||||
default:
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to float64: %T", in))
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeNumber, fmt.Errorf("unsupported type to convert to float64: %T", in)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func convertToBool(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
|
||||
switch in.(type) {
|
||||
case bool:
|
||||
return in.(bool), nil, nil
|
||||
case string:
|
||||
b, err := strconv.ParseBool(in.(string))
|
||||
if err != nil {
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeBoolean, err), nil
|
||||
}
|
||||
return b, nil, nil
|
||||
default:
|
||||
if options.failFast {
|
||||
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to bool: %T", in))
|
||||
}
|
||||
return nil, newWarnings(path, vo.DataTypeBoolean, fmt.Errorf("unsupported type to convert to bool: %T", in)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertToObject coerces in to a map[string]any and recursively converts
// each property according to t.Properties.
//
// Accepted inputs are map[string]any (used directly) and string (parsed
// as JSON). Properties without type info are passed through unchanged
// (or dropped under skipUnknownFields); Required properties missing from
// the result yield ErrMissingRequiredParam unless skipRequireCheck is set.
func convertToObject(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
	map[string]any, *ConversionWarnings, error) {
	var m map[string]any
	switch in.(type) {
	case map[string]any:
		m = in.(map[string]any)
	case string:
		err := sonic.UnmarshalString(in.(string), &m)
		if err != nil {
			if options.failFast {
				return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
			}
			return nil, newWarnings(path, vo.DataTypeObject, err), nil
		}
	default:
		if options.failFast {
			return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to object: %T", in))
		}
		return nil, newWarnings(path, vo.DataTypeObject, fmt.Errorf("unsupported type to convert to object: %T", in)), nil
	}

	// Empty object: nothing to convert, but required sub-fields must
	// still be present unless the check is skipped.
	if len(m) == 0 {
		if !options.skipRequireCheck {
			for pn, pro := range t.Properties {
				if pro.Required {
					return nil, nil, vo.NewError(errno.ErrMissingRequiredParam,
						errorx.KV("param", fmt.Sprintf("%s.%s", path, pn)))
				}
			}
		}
		return m, nil, nil
	}

	out := make(map[string]any, len(m))
	var warnings ConversionWarnings
	for k, v := range m {
		propType, ok := t.Properties[k]
		if !ok {
			// for input fields not explicitly defined, just pass the value through
			logs.CtxWarnf(ctx, "input %s.%s not found in type info", path, k)
			if !options.skipUnknownFields {
				out[k] = v
			}
			continue
		}

		propPath := fmt.Sprintf("%s.%s", path, k)
		newV, ws, err := convert(ctx, v, propPath, propType, options)
		if err != nil {
			return nil, nil, err
		} else if ws != nil {
			warnings = append(warnings, *ws...)
		}
		// NOTE: a property that produced a warning is still stored (as the
		// nil/zero value returned by convert); compare with convertToArray,
		// which drops warned elements instead.
		out[k] = newV
	}

	// The required check runs against the converted output, so properties
	// lost to conversion problems are caught here too.
	if !options.skipRequireCheck {
		for k, t := range t.Properties {
			if _, ok := out[k]; !ok {
				if t.Required {
					return nil, nil, vo.NewError(errno.ErrMissingRequiredParam,
						errorx.KV("param", fmt.Sprintf("%s.%s", path, k)))
				}
			}
		}
	}

	if len(warnings) > 0 {
		return out, ptr.Of(warnings), nil
	}
	return out, nil, nil
}
|
||||
|
||||
// convertToArray coerces in to a []any and converts each element to
// t.ElemTypeInfo.
//
// Accepted inputs are []any (used directly) and string (parsed as JSON).
// Elements that fail conversion are DROPPED from the result and reported
// as warnings (paths are "<path>.<index>"); under failFast the first
// failure aborts with an error instead.
func convertToArray(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
	[]any, *ConversionWarnings, error) {
	var a []any
	switch v := in.(type) {
	case []any:
		a = v
	case string:
		err := sonic.UnmarshalString(v, &a)
		if err != nil {
			if options.failFast {
				return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
			}
			return []any{}, newWarnings(path, vo.DataTypeArray, err), nil
		}
	default:
		if options.failFast {
			return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to array: %T", in))
		}
		return []any{}, newWarnings(path, vo.DataTypeArray, fmt.Errorf("unsupported type to convert to array: %T", in)), nil
	}

	if len(a) == 0 {
		return a, nil, nil
	}

	out := make([]any, 0, len(a))
	var warnings ConversionWarnings
	elemType := t.ElemTypeInfo
	for i, v := range a {
		elemPath := fmt.Sprintf("%s.%d", path, i)
		newV, ws, err := convert(ctx, v, elemPath, elemType, options)
		if err != nil {
			return nil, nil, err
		} else if ws != nil {
			warnings = append(warnings, *ws...)
		} else { // only correctly converted elements go into the final array
			out = append(out, newV)
		}
	}

	if len(warnings) > 0 {
		return out, ptr.Of(warnings), nil
	}

	return out, nil, nil
}
|
||||
435
backend/domain/workflow/internal/nodes/database/common.go
Normal file
435
backend/domain/workflow/internal/nodes/database/common.go
Normal file
@@ -0,0 +1,435 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// Output field names shared by the database nodes.
const rowNum = "rowNum"
const outputList = "outputList"

// TimeFormat is the layout used to render time.Time values read from the
// database into strings.
const TimeFormat = "2006-01-02 15:04:05 -0700 MST"
|
||||
|
||||
func toString(in any) (any, error) {
|
||||
switch in := in.(type) {
|
||||
case []byte:
|
||||
return string(in), nil
|
||||
case string:
|
||||
return in, nil
|
||||
case int64:
|
||||
return strconv.FormatInt(in, 10), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(in, 'f', -1, 64), nil
|
||||
case time.Time:
|
||||
return in.Format(TimeFormat), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(in), nil
|
||||
case map[string]any, []any:
|
||||
return sonic.MarshalString(in)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown type: %T", in)
|
||||
}
|
||||
}
|
||||
|
||||
// toInteger coerces a database cell value to an int64. Byte slices and
// strings are parsed base-10, floats are truncated; time and bool are
// rejected explicitly, everything else is an unknown-type error.
func toInteger(in any) (any, error) {
	switch v := in.(type) {
	case []byte:
		return strconv.ParseInt(string(v), 10, 64)
	case string:
		return strconv.ParseInt(v, 10, 64)
	case int64:
		return v, nil
	case float64:
		return int64(v), nil
	case time.Time, bool:
		return nil, fmt.Errorf(`type '%T' can't convert to int64'`, v)
	default:
		return nil, fmt.Errorf("unknown type: %T", in)
	}
}
|
||||
|
||||
// toNumber coerces a database cell value to a float64. Byte slices and
// strings are parsed, integers are widened; time and bool are rejected
// explicitly, everything else is an unknown-type error.
func toNumber(in any) (any, error) {
	switch v := in.(type) {
	case []byte:
		return strconv.ParseFloat(string(v), 64)
	case string:
		return strconv.ParseFloat(v, 64)
	case int64:
		return float64(v), nil
	case float64:
		return v, nil
	case time.Time, bool:
		return nil, fmt.Errorf(`type '%T' can't convert to float64'`, v)
	default:
		return nil, fmt.Errorf("unknown type: %T", in)
	}
}
|
||||
|
||||
func toTime(in any) (any, error) {
|
||||
switch in := in.(type) {
|
||||
case []byte:
|
||||
return string(in), nil
|
||||
case string:
|
||||
return in, nil
|
||||
case int64:
|
||||
return strconv.FormatInt(in, 10), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(in, 'f', -1, 64), nil
|
||||
case time.Time:
|
||||
return in.Format(TimeFormat), nil
|
||||
case bool:
|
||||
if in {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown type: %T", in)
|
||||
}
|
||||
}
|
||||
|
||||
func toBool(in any) (any, error) {
|
||||
switch in := in.(type) {
|
||||
case []byte:
|
||||
return strconv.ParseBool(string(in))
|
||||
case string:
|
||||
return strconv.ParseBool(in)
|
||||
case int64:
|
||||
return strconv.ParseBool(strconv.FormatInt(in, 10))
|
||||
case float64:
|
||||
return strconv.ParseBool(strconv.FormatFloat(in, 'f', -1, 64))
|
||||
case time.Time:
|
||||
return strconv.ParseBool(in.Format(TimeFormat))
|
||||
case bool:
|
||||
return in, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown type: %T", in)
|
||||
}
|
||||
}
|
||||
|
||||
// formatted convert the interface type according to the datatype type.
|
||||
// notice: object is currently not supported by database, and ignore it.
|
||||
func formatted(in any, ty *vo.TypeInfo) any {
|
||||
switch ty.Type {
|
||||
case vo.DataTypeString:
|
||||
r, err := toString(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted string error: %v", err)
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
case vo.DataTypeNumber:
|
||||
r, err := toNumber(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted number error: %v", err)
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
case vo.DataTypeInteger:
|
||||
r, err := toInteger(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted integer error: %v", err)
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
case vo.DataTypeBoolean:
|
||||
r, err := toBool(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted boolean error: %v", err)
|
||||
}
|
||||
return r
|
||||
case vo.DataTypeTime:
|
||||
r, err := toTime(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted time error: %v", err)
|
||||
return nil
|
||||
}
|
||||
return r
|
||||
case vo.DataTypeArray:
|
||||
arrayIn := make([]any, 0)
|
||||
inStr, err := toString(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted array error: %v", err)
|
||||
return []any{}
|
||||
}
|
||||
|
||||
err = sonic.UnmarshalString(inStr.(string), &arrayIn)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted array unmarshal error: %v", err)
|
||||
return []any{}
|
||||
}
|
||||
result := make([]any, 0)
|
||||
switch ty.ElemTypeInfo.Type {
|
||||
case vo.DataTypeTime:
|
||||
for _, in := range arrayIn {
|
||||
r, err := toTime(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted time: %v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
case vo.DataTypeString:
|
||||
for _, in := range arrayIn {
|
||||
r, err := toString(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted string failed: %v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
case vo.DataTypeInteger:
|
||||
for _, in := range arrayIn {
|
||||
r, err := toInteger(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted interger failed: %v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
case vo.DataTypeBoolean:
|
||||
for _, in := range arrayIn {
|
||||
r, err := toBool(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted bool failed: %v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
case vo.DataTypeNumber:
|
||||
for _, in := range arrayIn {
|
||||
r, err := toNumber(in)
|
||||
if err != nil {
|
||||
logs.Warnf("formatted number failed: %v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
case vo.DataTypeObject:
|
||||
properties := ty.ElemTypeInfo.Properties
|
||||
if len(properties) == 0 {
|
||||
for idx := range arrayIn {
|
||||
in := arrayIn[idx]
|
||||
if _, ok := in.(database.Object); ok {
|
||||
result = append(result, in)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
for idx := range arrayIn {
|
||||
in := arrayIn[idx]
|
||||
object, ok := in.(database.Object)
|
||||
if !ok {
|
||||
object = make(database.Object)
|
||||
for key := range properties {
|
||||
object[key] = nil
|
||||
}
|
||||
result = append(result, object)
|
||||
} else {
|
||||
result = append(result, objectFormatted(ty.ElemTypeInfo.Properties, object))
|
||||
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// objectFormatted converts one database row (object) into a plain map keyed
// by the configured property names. Keys absent from the row (or nil) are
// emitted as nil so the output schema stays stable for downstream consumers.
func objectFormatted(props map[string]*vo.TypeInfo, object database.Object) map[string]any {
	ret := make(map[string]any)

	// if config is nil, it agrees to convert to string type as the default value
	if len(props) == 0 {
		for k, v := range object {
			val, err := toString(v)
			if err != nil {
				// Unconvertible columns are skipped (logged), not failed.
				logs.Warnf("formatted string error: %v", err)
				continue
			}
			ret[k] = val
		}
		return ret
	}

	// Iterate the configured schema, not the row, so extra columns are
	// dropped and configured-but-missing columns still appear (as nil).
	for k, v := range props {
		if r, ok := object[k]; ok && r != nil {
			formattedValue := formatted(r, v)
			ret[k] = formattedValue
		} else {
			// if key not existed, assign nil
			ret[k] = nil
		}
	}

	return ret
}
|
||||
|
||||
// responseFormatted convert the object list returned by "response" into the field mapping of the "config output" configuration,
|
||||
// If the conversion fail, set the output list to null. If there are missing fields, set the missing fields to null.
|
||||
func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.Response) (map[string]any, error) {
|
||||
ret := make(map[string]any)
|
||||
list := make([]any, 0, len(configOutput))
|
||||
|
||||
outputListTypeInfo, ok := configOutput["outputList"]
|
||||
if !ok {
|
||||
return ret, fmt.Errorf("outputList key is required")
|
||||
}
|
||||
if outputListTypeInfo.Type != vo.DataTypeArray {
|
||||
return nil, fmt.Errorf("output list type info must array,but got %v", outputListTypeInfo.Type)
|
||||
}
|
||||
if outputListTypeInfo.ElemTypeInfo == nil {
|
||||
return nil, fmt.Errorf("output list must be an array and the array must contain element type info")
|
||||
}
|
||||
if outputListTypeInfo.ElemTypeInfo.Type != vo.DataTypeObject {
|
||||
return nil, fmt.Errorf("output list must be an array and element must object, but got %v", outputListTypeInfo.ElemTypeInfo.Type)
|
||||
}
|
||||
|
||||
props := outputListTypeInfo.ElemTypeInfo.Properties
|
||||
|
||||
for _, object := range response.Objects {
|
||||
list = append(list, objectFormatted(props, object))
|
||||
}
|
||||
|
||||
ret[outputList] = list
|
||||
if response.RowNumber != nil {
|
||||
ret[rowNum] = *response.RowNumber
|
||||
} else {
|
||||
ret[rowNum] = nil
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
|
||||
var (
|
||||
rightValue any
|
||||
ok bool
|
||||
)
|
||||
|
||||
conditionGroup := &database.ConditionGroup{
|
||||
Conditions: make([]*database.Condition, 0),
|
||||
Relation: database.ClauseRelationAND,
|
||||
}
|
||||
|
||||
if clauseGroup.Single != nil {
|
||||
clause := clauseGroup.Single
|
||||
if !notNeedTakeMapValue(clause.Operator) {
|
||||
rightValue, ok = nodes.TakeMapValue(input, compose.FieldPath{"__condition_right_0"})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot take single clause from input")
|
||||
}
|
||||
}
|
||||
|
||||
conditionGroup.Conditions = append(conditionGroup.Conditions, &database.Condition{
|
||||
Left: clause.Left,
|
||||
Operator: clause.Operator,
|
||||
Right: rightValue,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
if clauseGroup.Multi != nil {
|
||||
conditionGroup.Relation = clauseGroup.Multi.Relation
|
||||
|
||||
conditionGroup.Conditions = make([]*database.Condition, len(clauseGroup.Multi.Clauses))
|
||||
multiSelect := clauseGroup.Multi
|
||||
for idx, clause := range multiSelect.Clauses {
|
||||
if !notNeedTakeMapValue(clause.Operator) {
|
||||
rightValue, ok = nodes.TakeMapValue(input, compose.FieldPath{fmt.Sprintf("__condition_right_%d", idx)})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot take multi clause from input")
|
||||
}
|
||||
}
|
||||
conditionGroup.Conditions[idx] = &database.Condition{
|
||||
Left: clause.Left,
|
||||
Operator: clause.Operator,
|
||||
Right: rightValue,
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return conditionGroup, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fields := parseToInput(input)
|
||||
inventory := &UpdateInventory{
|
||||
ConditionGroup: conditionGroup,
|
||||
Fields: fields,
|
||||
}
|
||||
return inventory, nil
|
||||
}
|
||||
|
||||
// isDebugExecute reports whether the current workflow run is a debug or
// node-debug execution, based on the execution context carried in ctx.
// It panics when no execution context is attached — calling this outside a
// workflow execution is a programmer error.
func isDebugExecute(ctx context.Context) bool {
	execCtx := execute.GetExeCtx(ctx)
	if execCtx == nil {
		panic(fmt.Errorf("unable to get exe context"))
	}
	return execCtx.RootCtx.ExeCfg.Mode == vo.ExecuteModeDebug || execCtx.RootCtx.ExeCfg.Mode == vo.ExecuteModeNodeDebug
}
|
||||
|
||||
// getExecUserID returns the operator (user) id of the current workflow
// execution, taken from the execution context carried in ctx.
// Like isDebugExecute, it panics when no execution context is attached.
func getExecUserID(ctx context.Context) int64 {
	execCtx := execute.GetExeCtx(ctx)
	if execCtx == nil {
		panic(fmt.Errorf("unable to get exe context"))
	}
	return execCtx.RootCtx.ExeCfg.Operator
}
|
||||
|
||||
// parseToInput extracts the user-settable field values from a node input map.
// Only keys carrying the "__setting_field_" prefix are kept, with the prefix
// stripped; every other entry is dropped.
func parseToInput(input map[string]any) map[string]any {
	const settingPrefix = "__setting_field_"
	result := make(map[string]any, len(input))
	for key, value := range input {
		if name, ok := strings.CutPrefix(key, settingPrefix); ok {
			result[name] = value
		}
	}
	return result
}
|
||||
127
backend/domain/workflow/internal/nodes/database/customsql.go
Normal file
127
backend/domain/workflow/internal/nodes/database/customsql.go
Normal file
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// CustomSQLConfig carries everything a CustomSQL node needs: the target
// database, the raw SQL template with {{variable}} placeholders, the declared
// output schema, and the operator that actually runs the statement.
type CustomSQLConfig struct {
	DatabaseInfoID    int64                   // target database id; must be > 0
	SQLTemplate       string                  // SQL text containing {{variable}} placeholders
	OutputConfig      map[string]*vo.TypeInfo // expected output schema (outputList / rowNum)
	CustomSQLExecutor database.DatabaseOperator
}
|
||||
|
||||
// NewCustomSQL validates cfg and constructs a CustomSQL node executor.
// Required: non-nil cfg, a positive DatabaseInfoID, a non-empty SQLTemplate,
// and a CustomSQLExecutor implementation.
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
	if cfg == nil {
		return nil, errors.New("config is required")
	}
	if cfg.DatabaseInfoID == 0 {
		return nil, errors.New("database info id is required and greater than 0")
	}
	if cfg.SQLTemplate == "" {
		return nil, errors.New("sql template is required")
	}
	if cfg.CustomSQLExecutor == nil {
		return nil, errors.New("custom sqler is required")
	}
	return &CustomSQL{
		config: cfg,
	}, nil
}
|
||||
|
||||
// CustomSQL executes a user-authored, template-parameterized SQL statement
// against a workflow database and formats the result per its output config.
type CustomSQL struct {
	config *CustomSQLConfig
}
|
||||
|
||||
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
||||
req := &database.CustomSQLRequest{
|
||||
DatabaseInfoID: c.config.DatabaseInfoID,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
inputBytes, err := sonic.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
templateSQL := ""
|
||||
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
|
||||
sqlParams := make([]database.SQLParam, 0, len(templateParts))
|
||||
var nilError = errors.New("field is nil")
|
||||
for _, templatePart := range templateParts {
|
||||
if !templatePart.IsVariable {
|
||||
templateSQL += templatePart.Value
|
||||
continue
|
||||
}
|
||||
|
||||
templateSQL += "?"
|
||||
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
|
||||
return "", nilError
|
||||
}),
|
||||
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
|
||||
b := val.(bool)
|
||||
if b {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, nilError) {
|
||||
return nil, err
|
||||
}
|
||||
sqlParams = append(sqlParams, database.SQLParam{
|
||||
IsNull: true,
|
||||
})
|
||||
} else {
|
||||
sqlParams = append(sqlParams, database.SQLParam{
|
||||
Value: val,
|
||||
IsNull: false,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// replace sql template '?' to ?
|
||||
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
|
||||
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
|
||||
req.SQL = templateSQL
|
||||
req.Params = sqlParams
|
||||
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(c.config.OutputConfig, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
)
|
||||
|
||||
// mockCustomSQLer fakes the database operator's Execute: it lets a test
// inspect the outgoing request via validate, then answers with a canned row.
type mockCustomSQLer struct {
	validate func(req *database.CustomSQLRequest)
}

// Execute returns a stub matching DatabaseOperator.Execute's signature. It
// runs the validation hook on the request and returns one fixed object.
func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
	return func(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
		m.validate(request)
		r := &database.Response{
			Objects: []database.Object{
				database.Object{
					"v1": "v1_ret",
					"v2": "v2_ret",
				},
			},
		}

		return r, nil
	}
}
|
||||
|
||||
// TestCustomSQL_Execute checks that the SQL template is rendered with ?
// placeholders (quoted '{{v}}' / `{{v}}` forms unwrapped), that parameters
// are bound in template order, and that the response is mapped through the
// configured output schema.
func TestCustomSQL_Execute(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockSQLer := mockCustomSQLer{
		validate: func(req *database.CustomSQLRequest) {
			assert.Equal(t, int64(111), req.DatabaseInfoID)
			ps := []database.SQLParam{
				database.SQLParam{Value: "v1_value"},
				database.SQLParam{Value: "v2_value"},
				database.SQLParam{Value: "v3_value"},
			}
			assert.Equal(t, ps, req.Params)
			assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
		},
	}

	// Stub the execution context so isDebugExecute/getExecUserID succeed.
	defer mockey.Mock(execute.GetExeCtx).Return(&execute.Context{
		RootCtx: execute.RootCtx{
			ExeCfg: vo.ExecuteConfig{
				Mode:     vo.ExecuteModeDebug,
				Operator: 123,
				BizType:  vo.BizTypeWorkflow,
			},
		},
	}).Build().UnPatch()

	mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
	mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()

	cfg := &CustomSQLConfig{
		DatabaseInfoID:    111,
		SQLTemplate:       "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
		CustomSQLExecutor: mockDatabaseOperator,
		OutputConfig: map[string]*vo.TypeInfo{
			"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
				"v1": {Type: vo.DataTypeString},
				"v2": {Type: vo.DataTypeString},
			}}},
			"rowNum": {Type: vo.DataTypeInteger},
		},
	}
	cl := &CustomSQL{
		config: cfg,
	}

	ret, err := cl.Execute(t.Context(), map[string]any{
		"v1": "v1_value",
		"v2": "v2_value",
		"v3": "v3_value",
	})

	assert.Nil(t, err)

	assert.Equal(t, "v1_ret", ret[outputList].([]any)[0].(database.Object)["v1"])
	assert.Equal(t, "v2_ret", ret[outputList].([]any)[0].(database.Object)["v2"])

}
|
||||
111
backend/domain/workflow/internal/nodes/database/delete.go
Normal file
111
backend/domain/workflow/internal/nodes/database/delete.go
Normal file
@@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// DeleteConfig carries everything a Delete node needs: the target database,
// the clause group describing which rows to delete, the declared output
// schema, and the operator that performs the delete.
type DeleteConfig struct {
	DatabaseInfoID int64                   // target database id; must be > 0
	ClauseGroup    *database.ClauseGroup   // WHERE clauses; resolved against node input at run time
	OutputConfig   map[string]*vo.TypeInfo // expected output schema (rowNum etc.)

	Deleter database.DatabaseOperator
}

// Delete executes a conditional row deletion against a workflow database.
type Delete struct {
	config *DeleteConfig
}
|
||||
|
||||
// NewDelete validates cfg and constructs a Delete node executor.
// Required: non-nil cfg, a positive DatabaseInfoID, a ClauseGroup, and a
// Deleter implementation.
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
	if cfg == nil {
		return nil, errors.New("config is required")
	}
	if cfg.DatabaseInfoID == 0 {
		return nil, errors.New("database info id is required and greater than 0")
	}

	if cfg.ClauseGroup == nil {
		return nil, errors.New("clauseGroup is required")
	}
	if cfg.Deleter == nil {
		return nil, errors.New("deleter is required")
	}

	return &Delete{
		config: cfg,
	}, nil

}
|
||||
|
||||
// Delete resolves the configured clause group against the node input into a
// concrete condition group, issues the delete, and maps the backend response
// through the configured output schema.
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
	conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
	if err != nil {
		return nil, err
	}
	request := &database.DeleteRequest{
		DatabaseInfoID: d.config.DatabaseInfoID,
		ConditionGroup: conditionGroup,
		IsDebugRun:     isDebugExecute(ctx),
		UserID:         getExecUserID(ctx),
	}

	response, err := d.config.Deleter.Delete(ctx, request)
	if err != nil {
		return nil, err
	}

	ret, err := responseFormatted(d.config.OutputConfig, response)
	if err != nil {
		return nil, err
	}
	return ret, nil
}
|
||||
|
||||
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.toDatabaseDeleteCallbackInput(conditionGroup)
|
||||
}
|
||||
|
||||
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
databaseID := d.config.DatabaseInfoID
|
||||
result := make(map[string]any)
|
||||
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["deleteParam"] = map[string]any{}
|
||||
|
||||
condition, err := convertToCondition(conditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type Field struct {
|
||||
FieldID string `json:"fieldId"`
|
||||
IsDistinct bool `json:"isDistinct"`
|
||||
}
|
||||
result["deleteParam"] = map[string]any{
|
||||
"condition": condition}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
102
backend/domain/workflow/internal/nodes/database/insert.go
Normal file
102
backend/domain/workflow/internal/nodes/database/insert.go
Normal file
@@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// InsertConfig carries everything an Insert node needs: the target database,
// the declared output schema, and the operator that performs the insert.
// The inserted field values come from the node input at run time.
type InsertConfig struct {
	DatabaseInfoID int64                   // target database id; must be > 0
	OutputConfig   map[string]*vo.TypeInfo // expected output schema (rowNum etc.)
	Inserter       database.DatabaseOperator
}

// Insert executes a single-row insertion against a workflow database.
type Insert struct {
	config *InsertConfig
}
|
||||
|
||||
// NewInsert validates cfg and constructs an Insert node executor.
// Required: non-nil cfg, a positive DatabaseInfoID, and an Inserter
// implementation.
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
	if cfg == nil {
		return nil, errors.New("config is required")
	}
	if cfg.DatabaseInfoID == 0 {
		return nil, errors.New("database info id is required and greater than 0")
	}

	if cfg.Inserter == nil {
		return nil, errors.New("inserter is required")
	}
	return &Insert{
		config: cfg,
	}, nil

}
|
||||
|
||||
// Insert extracts the "__setting_field_"-prefixed values from input as the
// row to write, issues the insert, and maps the backend response through the
// configured output schema.
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {

	fields := parseToInput(input)
	req := &database.InsertRequest{
		DatabaseInfoID: is.config.DatabaseInfoID,
		Fields:         fields,
		IsDebugRun:     isDebugExecute(ctx),
		UserID:         getExecUserID(ctx),
	}

	response, err := is.config.Inserter.Insert(ctx, req)
	if err != nil {
		return nil, err
	}

	ret, err := responseFormatted(is.config.OutputConfig, response)
	if err != nil {
		return nil, err
	}

	return ret, nil
}
|
||||
|
||||
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
databaseID := is.config.DatabaseInfoID
|
||||
fs := parseToInput(input)
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
|
||||
type FieldInfo struct {
|
||||
FieldID string `json:"fieldId"`
|
||||
FieldValue any `json:"fieldValue"`
|
||||
}
|
||||
|
||||
fieldInfo := make([]*FieldInfo, 0)
|
||||
for k, v := range fs {
|
||||
fieldInfo = append(fieldInfo, &FieldInfo{
|
||||
FieldID: k,
|
||||
FieldValue: v,
|
||||
})
|
||||
}
|
||||
result["insertParam"] = map[string]any{
|
||||
"fieldInfo": fieldInfo,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
}
|
||||
221
backend/domain/workflow/internal/nodes/database/query.go
Normal file
221
backend/domain/workflow/internal/nodes/database/query.go
Normal file
@@ -0,0 +1,221 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// QueryConfig carries everything a Query node needs: the target database,
// the columns to select, ordering, the WHERE clause group (resolved against
// node input at run time), a row limit, the declared output schema, and the
// operator that performs the query.
type QueryConfig struct {
	DatabaseInfoID int64                   // target database id; must be > 0
	QueryFields    []string                // columns to select
	OrderClauses   []*database.OrderClause // ORDER BY specification
	OutputConfig   map[string]*vo.TypeInfo // expected output schema (outputList / rowNum)
	ClauseGroup    *database.ClauseGroup   // WHERE clauses
	Limit          int64                   // max rows; must be > 0
	Op             database.DatabaseOperator
}

// Query executes a conditional row selection against a workflow database.
type Query struct {
	config *QueryConfig
}
|
||||
|
||||
// NewQuery validates cfg and constructs a Query node executor.
// Required: non-nil cfg, a positive DatabaseInfoID, a positive Limit, and an
// Op implementation. ClauseGroup is not validated here.
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
	if cfg == nil {
		return nil, errors.New("config is required")
	}
	if cfg.DatabaseInfoID == 0 {
		return nil, errors.New("database info id is required and greater than 0")
	}

	if cfg.Limit == 0 {
		return nil, errors.New("limit is required and greater than 0")
	}

	if cfg.Op == nil {
		return nil, errors.New("op is required")
	}

	return &Query{config: cfg}, nil

}
|
||||
|
||||
// Query resolves the configured clause group against the node input, issues
// the select with the configured fields/ordering/limit, and maps the backend
// response through the configured output schema.
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
	conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
	if err != nil {
		return nil, err
	}

	req := &database.QueryRequest{
		DatabaseInfoID: ds.config.DatabaseInfoID,
		OrderClauses:   ds.config.OrderClauses,
		SelectFields:   ds.config.QueryFields,
		Limit:          ds.config.Limit,
		IsDebugRun:     isDebugExecute(ctx),
		UserID:         getExecUserID(ctx),
	}

	req.ConditionGroup = conditionGroup

	response, err := ds.config.Op.Query(ctx, req)
	if err != nil {
		return nil, err
	}

	ret, err := responseFormatted(ds.config.OutputConfig, response)
	if err != nil {
		return nil, err
	}
	return ret, nil
}
|
||||
|
||||
func notNeedTakeMapValue(op database.Operator) bool {
|
||||
return op == database.OperatorIsNull || op == database.OperatorIsNotNull
|
||||
}
|
||||
|
||||
// ToCallbackInput renders the pending query as a callback payload, resolving
// the clause group's right-hand values from the node input first.
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
	conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
	if err != nil {
		return nil, err
	}

	return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
}
|
||||
|
||||
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
databaseID := config.DatabaseInfoID
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["selectParam"] = map[string]any{}
|
||||
|
||||
condition, err := convertToCondition(conditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type Field struct {
|
||||
FieldID string `json:"fieldId"`
|
||||
IsDistinct bool `json:"isDistinct"`
|
||||
}
|
||||
fieldList := make([]Field, 0, len(config.QueryFields))
|
||||
for _, f := range config.QueryFields {
|
||||
fieldList = append(fieldList, Field{FieldID: f})
|
||||
}
|
||||
type Order struct {
|
||||
FieldID string `json:"fieldId"`
|
||||
IsAsc bool `json:"isAsc"`
|
||||
}
|
||||
|
||||
OrderList := make([]Order, 0)
|
||||
for _, c := range config.OrderClauses {
|
||||
OrderList = append(OrderList, Order{
|
||||
FieldID: c.FieldID,
|
||||
IsAsc: c.IsAsc,
|
||||
})
|
||||
}
|
||||
result["selectParam"] = map[string]any{
|
||||
"condition": condition,
|
||||
"fieldList": fieldList,
|
||||
"limit": config.Limit,
|
||||
"orderByList": OrderList,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
||||
// ConditionItem is one JSON-serializable WHERE clause in a callback payload:
// left column, operation keyword (see convertToOperation), and right value.
type ConditionItem struct {
	Left      string `json:"left"`
	Operation string `json:"operation"`
	Right     any    `json:"right"`
}

// Condition is the JSON-serializable form of a condition group: its clauses
// and the logic keyword ("AND"/"OR") joining them.
type Condition struct {
	ConditionList []ConditionItem `json:"conditionList"`
	Logic         string          `json:"logic"`
}
|
||||
|
||||
// convertToCondition maps an internal condition group onto the
// JSON-serializable Condition shape used in callback payloads.
func convertToCondition(conditionGroup *database.ConditionGroup) (*Condition, error) {
	logic, err := convertToLogic(conditionGroup.Relation)
	if err != nil {
		return nil, err
	}
	condition := &Condition{
		ConditionList: make([]ConditionItem, 0),
		Logic:         logic,
	}
	for _, c := range conditionGroup.Conditions {
		op, err := convertToOperation(c.Operator)
		if err != nil {
			// The underlying error is replaced with the offending operator.
			return nil, fmt.Errorf("invalid operator: %s", c.Operator)
		}
		condition.ConditionList = append(condition.ConditionList, ConditionItem{
			Left:      c.Left,
			Operation: op,
			Right:     c.Right,
		})
	}
	return condition, nil
}
|
||||
|
||||
// convertToOperation maps an internal comparison operator to the keyword used
// in callback payloads.
//
// NOTE(review): "NOT LIKE" uses a space while the other multi-word tokens use
// underscores ("NOT_IN", "IS_NOT_NULL") — confirm the consumer expects this
// exact token before normalizing it.
func convertToOperation(Op database.Operator) (string, error) {
	switch Op {
	case database.OperatorEqual:
		return "EQUAL", nil
	case database.OperatorNotEqual:
		return "NOT_EQUAL", nil
	case database.OperatorGreater:
		return "GREATER_THAN", nil
	case database.OperatorLesser:
		return "LESS_THAN", nil
	case database.OperatorGreaterOrEqual:
		return "GREATER_EQUAL", nil
	case database.OperatorLesserOrEqual:
		return "LESS_EQUAL", nil
	case database.OperatorIn:
		return "IN", nil
	case database.OperatorNotIn:
		return "NOT_IN", nil
	case database.OperatorIsNull:
		return "IS_NULL", nil
	case database.OperatorIsNotNull:
		return "IS_NOT_NULL", nil
	case database.OperatorLike:
		return "LIKE", nil
	case database.OperatorNotLike:
		return "NOT LIKE", nil
	}
	return "", fmt.Errorf("not a valid database Operator")

}
|
||||
|
||||
func convertToLogic(rel database.ClauseRelation) (string, error) {
|
||||
switch rel {
|
||||
case database.ClauseRelationOR:
|
||||
return "OR", nil
|
||||
case database.ClauseRelationAND:
|
||||
return "AND", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown clause relation %v", rel)
|
||||
|
||||
}
|
||||
}
|
||||
456
backend/domain/workflow/internal/nodes/database/query_test.go
Normal file
456
backend/domain/workflow/internal/nodes/database/query_test.go
Normal file
@@ -0,0 +1,456 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
)
|
||||
|
||||
// mockDsSelect is a hand-rolled stub backing the gomock DatabaseOperator in
// the tests below: it captures the canned rows to return and a per-test hook
// that inspects the outgoing query request.
type mockDsSelect struct {
	t        *testing.T // NOTE(review): not referenced by the visible methods — possibly vestigial; confirm before removing
	objects  []database.Object
	validate func(request *database.QueryRequest) // invoked with every QueryRequest so sub-tests can assert on it
}
|
||||
|
||||
func (m *mockDsSelect) Query() func(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
|
||||
return func(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
|
||||
n := int64(1)
|
||||
|
||||
m.validate(request)
|
||||
|
||||
return &database.Response{
|
||||
RowNumber: &n,
|
||||
Objects: m.objects,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_Query(t *testing.T) {
|
||||
defer mockey.Mock(execute.GetExeCtx).Return(&execute.Context{
|
||||
RootCtx: execute.RootCtx{
|
||||
ExeCfg: vo.ExecuteConfig{
|
||||
Mode: vo.ExecuteModeDebug,
|
||||
Operator: 123,
|
||||
BizType: vo.BizTypeWorkflow,
|
||||
},
|
||||
},
|
||||
}).Build().UnPatch()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
t.Run("string case", func(t *testing.T) {
|
||||
t.Run("single", func(t *testing.T) {
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": "1",
|
||||
"v2": int64(2),
|
||||
})
|
||||
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Single: &database.Clause{
|
||||
Left: "v1",
|
||||
Operator: database.OperatorLike,
|
||||
},
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]interface{}{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
|
||||
})
|
||||
|
||||
t.Run("multi", func(t *testing.T) {
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Multi: &database.MultiClause{
|
||||
Relation: database.ClauseRelationOR,
|
||||
Clauses: []*database.Clause{
|
||||
{Left: "v1", Operator: database.OperatorLike},
|
||||
{Left: "v2", Operator: database.OperatorLike},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": "1",
|
||||
"v2": int64(2),
|
||||
})
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Right, 1)
|
||||
assert.Equal(t, cGroup.Conditions[1].Right, 2)
|
||||
assert.Equal(t, cGroup.Relation, cfg.ClauseGroup.Multi.Relation)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
"__condition_right_1": 2,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
|
||||
})
|
||||
|
||||
t.Run("formated error", func(t *testing.T) {
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Single: &database.Clause{
|
||||
Left: "v1",
|
||||
Operator: database.OperatorLike,
|
||||
},
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": "abc",
|
||||
"v2": int64(2),
|
||||
})
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, map[string]any{
|
||||
"v1": nil,
|
||||
"v2": int64(2),
|
||||
}, result["outputList"].([]any)[0])
|
||||
|
||||
})
|
||||
|
||||
t.Run("redundancy return field", func(t *testing.T) {
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Single: &database.Clause{
|
||||
Left: "v1",
|
||||
Operator: database.OperatorLike,
|
||||
},
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
"v3": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": "1",
|
||||
"v2": int64(2),
|
||||
})
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]any{"__condition_right_0": 1}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
assert.Equal(t, int64(2), result["outputList"].([]any)[0].(database.Object)["v2"])
|
||||
assert.Equal(t, nil, result["outputList"].([]any)[0].(database.Object)["v3"])
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("other case", func(t *testing.T) {
|
||||
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Single: &database.Clause{
|
||||
Left: "v1",
|
||||
Operator: database.OperatorLike,
|
||||
},
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeNumber},
|
||||
"v3": {Type: vo.DataTypeBoolean},
|
||||
"v4": {Type: vo.DataTypeBoolean},
|
||||
"v5": {Type: vo.DataTypeTime},
|
||||
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
|
||||
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
|
||||
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": "1",
|
||||
"v2": "2.1",
|
||||
"v3": int64(0),
|
||||
"v4": "true",
|
||||
"v5": "2020-02-20T10:10:10",
|
||||
"v6": `["1","2","3"]`,
|
||||
"v7": `[false,true,"true"]`,
|
||||
"v8": `["1.2",2.1, 3.9]`,
|
||||
})
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
object := result["outputList"].([]any)[0].(database.Object)
|
||||
|
||||
assert.Equal(t, int64(1), object["v1"])
|
||||
assert.Equal(t, 2.1, object["v2"])
|
||||
assert.Equal(t, false, object["v3"])
|
||||
assert.Equal(t, true, object["v4"])
|
||||
assert.Equal(t, "2020-02-20T10:10:10", object["v5"])
|
||||
assert.Equal(t, []any{int64(1), int64(2), int64(3)}, object["v6"])
|
||||
assert.Equal(t, []any{false, true, true}, object["v7"])
|
||||
assert.Equal(t, []any{1.2, 2.1, 3.9}, object["v8"])
|
||||
|
||||
})
|
||||
|
||||
t.Run("config output list is nil", func(t *testing.T) {
|
||||
|
||||
cfg := &QueryConfig{
|
||||
DatabaseInfoID: 111,
|
||||
ClauseGroup: &database.ClauseGroup{
|
||||
Single: &database.Clause{
|
||||
Left: "v1",
|
||||
Operator: database.OperatorLike,
|
||||
},
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
"v1": int64(1),
|
||||
"v2": "2.1",
|
||||
"v3": int64(0),
|
||||
"v4": "true",
|
||||
"v5": "2020-02-20T10:10:10",
|
||||
"v6": `["1","2","3"]`,
|
||||
"v7": `[false,true,"true"]`,
|
||||
"v8": `["1.2",2.1, 3.9]`,
|
||||
})
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
if request.DatabaseInfoID != cfg.DatabaseInfoID {
|
||||
t.Fatal("database id should be equal")
|
||||
}
|
||||
cGroup := request.ConditionGroup
|
||||
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
|
||||
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
|
||||
|
||||
}}
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
|
||||
"v1": "1",
|
||||
"v2": "2.1",
|
||||
"v3": "0",
|
||||
"v4": "true",
|
||||
"v5": "2020-02-20T10:10:10",
|
||||
"v6": `["1","2","3"]`,
|
||||
"v7": `[false,true,"true"]`,
|
||||
"v8": `["1.2",2.1, 3.9]`,
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
}
|
||||
133
backend/domain/workflow/internal/nodes/database/update.go
Normal file
133
backend/domain/workflow/internal/nodes/database/update.go
Normal file
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// UpdateConfig carries everything an Update node needs: the target database,
// the clause group describing which rows to update, the output type schema
// used to format the response, and the operator that performs the update.
type UpdateConfig struct {
	DatabaseInfoID int64                    // ID of the database to update; must be non-zero
	ClauseGroup    *database.ClauseGroup    // WHERE-style condition source; required
	OutputConfig   map[string]*vo.TypeInfo  // schema applied to the operator's response
	Updater        database.DatabaseOperator // executes the actual update; required
}

// Update is the workflow node that runs a database update.
type Update struct {
	config *UpdateConfig
}

// UpdateInventory is the resolved form of a ClauseGroup plus input values:
// the condition group to match rows and the field values to write.
type UpdateInventory struct {
	ConditionGroup *database.ConditionGroup
	Fields         map[string]any
}
|
||||
|
||||
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.ClauseGroup == nil {
|
||||
return nil, errors.New("clause group is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Updater == nil {
|
||||
return nil, errors.New("updater is required")
|
||||
}
|
||||
|
||||
return &Update{config: cfg}, nil
|
||||
}
|
||||
|
||||
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields := make(map[string]any)
|
||||
|
||||
for key, value := range inventory.Fields {
|
||||
fields[key] = value
|
||||
}
|
||||
|
||||
req := &database.UpdateRequest{
|
||||
DatabaseInfoID: u.config.DatabaseInfoID,
|
||||
ConditionGroup: inventory.ConditionGroup,
|
||||
Fields: fields,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := u.config.Updater.Update(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(u.config.OutputConfig, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return u.toDatabaseUpdateCallbackInput(inventory)
|
||||
}
|
||||
|
||||
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
|
||||
databaseID := u.config.DatabaseInfoID
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["updateParam"] = map[string]any{}
|
||||
|
||||
condition, err := convertToCondition(inventory.ConditionGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type FieldInfo struct {
|
||||
fieldID string
|
||||
fieldValue any
|
||||
}
|
||||
|
||||
fieldInfo := make([]FieldInfo, 0)
|
||||
for k, v := range inventory.Fields {
|
||||
fieldInfo = append(fieldInfo, FieldInfo{
|
||||
fieldID: k,
|
||||
fieldValue: v,
|
||||
})
|
||||
}
|
||||
|
||||
result["updateParam"] = map[string]any{
|
||||
"condition": condition,
|
||||
"fieldInfo": fieldInfo,
|
||||
}
|
||||
return result, nil
|
||||
|
||||
}
|
||||
555
backend/domain/workflow/internal/nodes/emitter/emitter.go
Normal file
555
backend/domain/workflow/internal/nodes/emitter/emitter.go
Normal file
@@ -0,0 +1,555 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package emitter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
// OutputEmitter renders a template against streaming node outputs and emits
// the rendered text incrementally.
type OutputEmitter struct {
	cfg *Config
}

// Config configures an OutputEmitter.
type Config struct {
	Template    string                       // template text with {{...}} variable parts
	FullSources map[string]*nodes.SourceInfo // source tree describing where each template variable comes from
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
return &OutputEmitter{
|
||||
cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// cachedVal holds the accumulated value for one template source.
type cachedVal struct {
	val       any         // accumulated value: string for streams, map for intermediate containers
	finished  bool        // true once the value is complete and safe to render
	subCaches *cacheStore // per-sub-source caches; non-nil only for intermediate containers
}

// cacheStore accumulates streamed chunks per source key, mirroring the
// SourceInfo tree so nested (intermediate) sources cache recursively.
type cacheStore struct {
	store map[string]*cachedVal        // accumulated values keyed by source name
	infos map[string]*nodes.SourceInfo // static source descriptions for the same keys
}
|
||||
|
||||
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
|
||||
return &cacheStore{
|
||||
store: make(map[string]*cachedVal),
|
||||
infos: infos,
|
||||
}
|
||||
}
|
||||
|
||||
// put accumulates one chunk of value v for source k and returns the delta
// that was newly added (the raw value for non-streams, the trimmed chunk for
// streams, the chunk map for intermediate containers).
//
// Behavior by source kind:
//   - leaf, non-stream: cached once, finished immediately; a second put for
//     the same key is an error.
//   - leaf, stream: v must be a string chunk; chunks are concatenated, and a
//     nodes.KeyIsFinished suffix marks the stream complete.
//   - intermediate: v must be a map; sub-values are recursively put into the
//     per-key sub-caches and the container's finished state is re-evaluated.
func (c *cacheStore) put(k string, v any) (any, error) {
	sInfo, ok := c.infos[k]
	if !ok {
		return nil, fmt.Errorf("no such key found from SourceInfos: %s", k)
	}

	if !sInfo.IsIntermediate { // this is not an intermediate object container
		isStream := sInfo.FieldType == nodes.FieldIsStream
		if !isStream {
			_, ok := c.store[k]
			if !ok {
				// First (and only allowed) occurrence: cache and mark finished.
				out := &cachedVal{
					val:      v,
					finished: true,
				}
				c.store[k] = out
				return v, nil
			} else {
				return nil, fmt.Errorf("source %s not intermediate, not stream, appears multiple times", k)
			}
		} else { // this is an actual stream
			vStr, ok := v.(string) // stream value should be string
			if !ok {
				return nil, fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v)
			}

			// The finished sentinel is carried as a suffix on the chunk itself.
			isFinished := strings.HasSuffix(vStr, nodes.KeyIsFinished)
			if isFinished {
				vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
			}

			var existingStr string
			existing, ok := c.store[k]
			if ok {
				existingStr = existing.val.(string)
			}

			// Append this chunk to whatever has streamed in so far.
			existingStr = existingStr + vStr
			out := &cachedVal{
				val:      existingStr,
				finished: isFinished,
			}
			c.store[k] = out

			// Return only the newly-arrived chunk, not the accumulated string.
			return vStr, nil
		}
	}

	if len(sInfo.SubSources) == 0 {
		// k is intermediate container, needs to check its sub sources
		return nil, fmt.Errorf("source %s is intermediate, but does not have sub sources", k)
	}

	vMap, ok := v.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("source %s is intermediate, but value type not map: %T", k, v)
	}

	currentCache, existed := c.store[k]
	if !existed {
		currentCache = &cachedVal{
			val:       v,
			subCaches: newCacheStore(sInfo.SubSources),
		}
		c.store[k] = currentCache
	} else {
		// already cached k before, merge cached value with new value
		currentCache.val = merge(currentCache.val, v)
	}

	// Recursively feed each sub-source its slice of this chunk.
	subCacheStore := currentCache.subCaches
	for subK := range subCacheStore.infos {
		subV, ok := vMap[subK]
		if !ok { // subK not present in this chunk
			continue
		}
		_, err := subCacheStore.put(subK, subV)
		if err != nil {
			return nil, err
		}
	}

	// Re-evaluate (and memoize) this container's finished flag as a side effect.
	_ = c.finished(k)

	return vMap, nil
}
|
||||
|
||||
func (c *cacheStore) finished(k string) bool {
|
||||
cached, ok := c.store[k]
|
||||
if !ok {
|
||||
return c.infos[k].FieldType == nodes.FieldSkipped
|
||||
}
|
||||
|
||||
if cached.finished {
|
||||
return true
|
||||
}
|
||||
|
||||
sInfo := c.infos[k]
|
||||
if !sInfo.IsIntermediate {
|
||||
return cached.finished
|
||||
}
|
||||
|
||||
for subK := range sInfo.SubSources {
|
||||
subFinished := cached.subCaches.finished(subK)
|
||||
if !subFinished {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
cached.finished = true
|
||||
return true
|
||||
}
|
||||
|
||||
// find walks the cached tree along part's sub-paths and returns:
//   - root: the cached value of part's root source (nil if the root was never put)
//   - subCache: the nearest cached entry on the path (nil if the next hop isn't cached yet)
//   - sourceInfo: the SourceInfo of the deepest node reached
//   - actualPath: the sub-path segments actually matched against the source tree
//
// The walk stops early at a leaf source, at a path segment that is not a real
// field source (a user-defined template variable), at an uncached segment, or
// at the first unfinished cached segment.
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
	actualPath []string,
) {
	rootCached, ok := c.store[part.Root]
	if !ok {
		return nil, nil, nil, nil
	}

	// now try to find the nearest match within the cached tree
	subPaths := part.SubPathsBeforeSlice
	currentCache := rootCached
	currentSource := c.infos[part.Root]
	for i := range subPaths {
		if !currentSource.IsIntermediate {
			// currentSource is already the leaf, no need to look further
			break
		}
		subPath := subPaths[i]
		subInfo, ok := currentSource.SubSources[subPath]
		if !ok {
			// this sub path is not in the source info tree
			// it's just a user defined variable field in the template
			break
		}

		actualPath = append(actualPath, subPath)

		subCache, ok = currentCache.subCaches.store[subPath]
		if !ok {
			// subPath corresponds to a real Field Source,
			// if it's not cached, then it hasn't appeared in the stream yet
			return rootCached.val, nil, subInfo, actualPath
		}
		if !subCache.finished {
			// Deepest unfinished entry found: stop here.
			return rootCached.val, subCache, subInfo, actualPath
		}

		currentCache = subCache
		currentSource = subInfo
	}

	return rootCached.val, currentCache, currentSource, actualPath
}
|
||||
|
||||
func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWriter[map[string]any]) (
|
||||
hasErr bool, partFinished bool) {
|
||||
cachedRoot, subCache, sourceInfo, _ := c.find(part)
|
||||
if cachedRoot != nil && subCache != nil {
|
||||
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
||||
if hasErr {
|
||||
return true, false
|
||||
}
|
||||
if subCache.finished { // move on to next part in template
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
// fillZero inserts zero values for every source originating from nodeKey
// that has not produced any data yet (called when that node's stream ends).
// It returns a map describing which fields were filled: leaves map to true,
// intermediate containers map to the nested filled-map of their sub-sources.
// Intermediate entries are created (empty) if absent so recursion can proceed.
func (c *cacheStore) fillZero(nodeKey vo.NodeKey) map[string]any {
	filled := make(map[string]any)
	for field, sInfo := range c.infos {
		if !sInfo.FromNode(nodeKey) {
			continue
		}

		cacheV, ok := c.store[field]
		if !sInfo.IsIntermediate {
			if !ok {
				// Leaf never arrived: substitute its type's zero value, done.
				c.store[field] = &cachedVal{
					val:       sInfo.TypeInfo.Zero(),
					finished:  true,
					subCaches: nil,
				}

				filled[field] = true
			}

			continue
		}

		if !ok {
			// Intermediate container never arrived: create an empty one so the
			// recursive fill below can populate its sub-sources.
			cacheV = &cachedVal{
				val:       make(map[string]any),
				subCaches: newCacheStore(sInfo.SubSources),
			}
			c.store[field] = cacheV
		}

		subFilled := cacheV.subCaches.fillZero(nodeKey)
		if len(subFilled) > 0 {
			filled[field] = subFilled
		}
	}

	return filled
}
|
||||
|
||||
func merge(a, b any) any {
|
||||
aStr, ok1 := a.(string)
|
||||
bStr, ok2 := b.(string)
|
||||
if ok1 && ok2 {
|
||||
if strings.HasSuffix(bStr, nodes.KeyIsFinished) {
|
||||
bStr = strings.TrimSuffix(bStr, nodes.KeyIsFinished)
|
||||
}
|
||||
return aStr + bStr
|
||||
}
|
||||
|
||||
aMap, ok1 := a.(map[string]interface{})
|
||||
bMap, ok2 := b.(map[string]interface{})
|
||||
if ok1 && ok2 {
|
||||
merged := make(map[string]any)
|
||||
for k, v := range aMap {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range bMap {
|
||||
if _, ok := merged[k]; !ok { // only bMap has this field, just set it
|
||||
merged[k] = v
|
||||
continue
|
||||
}
|
||||
merged[k] = merge(merged[k], v)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("can only merge two maps or two strings, a type: %T, b type: %T", a, b))
|
||||
}
|
||||
|
||||
// outputKey is the single map field under which EmitStream publishes rendered
// template chunks.
const outputKey = "output"
|
||||
|
||||
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, sw := schema.Pipe[map[string]any](0)
|
||||
parts := nodes.ParseTemplate(e.cfg.Template)
|
||||
safego.Go(ctx, func() {
|
||||
hasErr := false
|
||||
defer func() {
|
||||
if !hasErr {
|
||||
sw.Send(map[string]any{outputKey: nodes.KeyIsFinished}, nil)
|
||||
}
|
||||
sw.Close()
|
||||
in.Close()
|
||||
}()
|
||||
|
||||
caches := newCacheStore(resolvedSources)
|
||||
|
||||
partsLoop:
|
||||
for _, part := range parts {
|
||||
select {
|
||||
case <-ctx.Done(): // canceled by Eino workflow engine
|
||||
sw.Send(nil, ctx.Err())
|
||||
hasErr = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if !part.IsVariable { // literal string within template, just emit it
|
||||
sw.Send(map[string]any{outputKey: part.Value}, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
// now this 'part' is a variable, first check if the source(s) for it are skipped (the nodes are not selected)
|
||||
// if skipped, just move on to the next 'part'
|
||||
skipped, invalid := part.Skipped(resolvedSources)
|
||||
if skipped {
|
||||
continue
|
||||
}
|
||||
if invalid {
|
||||
sw.Send(map[string]any{outputKey: "{{" + part.Value + "}}"}, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
// now this 'part' definitely should have a match, look for a hit within cache store
|
||||
// if found in cache store, emit the root only if the match is finished or the match is stream
|
||||
// the rule for a hit: the nearest match within the source tree
|
||||
// if hit, and the cachedVal is also finished, continue to next template part
|
||||
var partFinished bool
|
||||
hasErr, partFinished = caches.readyForPart(part, sw)
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
if partFinished {
|
||||
continue
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done(): // canceled by Eino workflow engine or timeout
|
||||
sw.Send(nil, ctx.Err())
|
||||
hasErr = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
shouldChangePart := false
|
||||
|
||||
chunk, err := in.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// current part is not fulfilled, emit the literal part content and move on to next part
|
||||
sw.Send(map[string]any{outputKey: part.Value}, nil)
|
||||
break
|
||||
}
|
||||
|
||||
if sn, ok := schema.GetSourceName(err); ok {
|
||||
// received end signal for a particular predecessor nodeID, do the following:
|
||||
// - obtain the field sources mapped from this predecessor node
|
||||
// - check which fields are still missing in the cache store
|
||||
// - fill zero value for these missing fields
|
||||
// - check if the current template part should be rendered and sent immediately
|
||||
// - check if we should move on to next part in template
|
||||
filled := caches.fillZero(vo.NodeKey(sn))
|
||||
if _, okk := filled[part.Root]; okk {
|
||||
// current part is influenced by the 'fill zero' operation
|
||||
hasErr, shouldChangePart = caches.readyForPart(part, sw)
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
if shouldChangePart {
|
||||
continue partsLoop
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
hasErr = true
|
||||
sw.Send(nil, err) // real error
|
||||
return
|
||||
}
|
||||
|
||||
chunkLoop:
|
||||
for k := range chunk {
|
||||
v := chunk[k]
|
||||
v, err = caches.put(k, v) // always update the cache
|
||||
if err != nil {
|
||||
hasErr = true
|
||||
sw.Send(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
// needs to check if this 'k' is the current part's root
|
||||
// if it is, do the case analysis:
|
||||
// - the source is a leaf (not intermediate):
|
||||
// - the source is stream, emit the formatted stream content immediately
|
||||
// - the source is not stream, emit the formatted one-time content immediately
|
||||
// - the source is intermediate:
|
||||
// - the source is not finished, do not emit it
|
||||
// - the source is finished, emit the full content (cached + new) immediately
|
||||
if k == part.Root {
|
||||
cachedRoot, subCache, sourceInfo, actualPath := caches.find(part)
|
||||
if sourceInfo == nil {
|
||||
panic("impossible, k is part.root, but sourceInfo is nil")
|
||||
}
|
||||
|
||||
if subCache != nil {
|
||||
if sourceInfo.IsIntermediate {
|
||||
if subCache.finished {
|
||||
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
shouldChangePart = true
|
||||
}
|
||||
} else {
|
||||
if sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
currentV := v
|
||||
for i := 0; i < len(actualPath)-1; i++ {
|
||||
currentM, ok := currentV.(map[string]any)
|
||||
if !ok {
|
||||
panic("emit item not map[string]any")
|
||||
}
|
||||
currentV, ok = currentM[actualPath[i]]
|
||||
if !ok {
|
||||
continue chunkLoop
|
||||
}
|
||||
}
|
||||
|
||||
if len(actualPath) > 0 {
|
||||
finalV, ok := currentV.(map[string]any)[actualPath[len(actualPath)-1]]
|
||||
if !ok {
|
||||
continue chunkLoop
|
||||
}
|
||||
currentV = finalV
|
||||
}
|
||||
vStr, ok := currentV.(string)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v))
|
||||
}
|
||||
|
||||
if strings.HasSuffix(vStr, nodes.KeyIsFinished) {
|
||||
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
|
||||
}
|
||||
|
||||
var delta any
|
||||
delta = vStr
|
||||
for j := len(actualPath) - 1; j >= 0; j-- {
|
||||
delta = map[string]any{
|
||||
actualPath[j]: delta,
|
||||
}
|
||||
}
|
||||
|
||||
hasErr = renderAndSend(part, part.Root, delta, sw)
|
||||
} else {
|
||||
hasErr = renderAndSend(part, part.Root, v, sw)
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
return
|
||||
}
|
||||
if subCache.finished {
|
||||
shouldChangePart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if shouldChangePart {
|
||||
continue partsLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output = map[string]any{
|
||||
outputKey: s,
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func renderAndSend(tp nodes.TemplatePart, k string, v any, sw *schema.StreamWriter[map[string]any]) bool /*hasError*/ {
|
||||
m, err := sonic.Marshal(map[string]any{k: v})
|
||||
if err != nil {
|
||||
sw.Send(nil, err)
|
||||
return true
|
||||
}
|
||||
|
||||
r, err := tp.Render(m)
|
||||
if err != nil {
|
||||
sw.Send(nil, err)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(r) == 0 { // won't send if formatted result is empty string
|
||||
return false
|
||||
}
|
||||
|
||||
logs.Infof("send: %v", r)
|
||||
|
||||
sw.Send(map[string]any{outputKey: r}, nil)
|
||||
return false
|
||||
}
|
||||
62
backend/domain/workflow/internal/nodes/entry/entry.go
Normal file
62
backend/domain/workflow/internal/nodes/entry/entry.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package entry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// Config declares the Entry node's setup: raw default values for input
// fields, plus the declared type of each output field (used both to convert
// the defaults and to decide what counts as "empty" at invocation time).
type Config struct {
	// DefaultValues holds raw defaults keyed by field name; they are
	// converted against OutputTypes in NewEntry before use.
	DefaultValues map[string]any
	// OutputTypes maps each field name to its declared type info.
	OutputTypes map[string]*vo.TypeInfo
}
|
||||
|
||||
// Entry is the workflow's entry node: it fills missing or empty input
// fields with pre-converted default values before the workflow proper runs.
type Entry struct {
	cfg *Config
	// defaultValues are cfg.DefaultValues already converted to their
	// declared output types (see NewEntry).
	defaultValues map[string]any
}
|
||||
|
||||
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
}
|
||||
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
cfg: cfg,
|
||||
defaultValues: defaultValues,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
||||
for k, v := range e.defaultValues {
|
||||
if val, ok := in[k]; ok {
|
||||
tInfo := e.cfg.OutputTypes[k]
|
||||
switch tInfo.Type {
|
||||
case vo.DataTypeString:
|
||||
if len(val.(string)) == 0 {
|
||||
in[k] = v
|
||||
}
|
||||
case vo.DataTypeArray:
|
||||
if len(val.([]any)) == 0 {
|
||||
in[k] = v
|
||||
}
|
||||
case vo.DataTypeObject:
|
||||
if len(val.(map[string]any)) == 0 {
|
||||
in[k] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
in[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return in, err
|
||||
}
|
||||
@@ -0,0 +1,684 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package httprequester
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"mime/multipart"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode/utf16"

	"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
	"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
|
||||
|
||||
// defaultGetFileTimeout bounds file downloads used to build request bodies
// (binary and multipart file fields), in seconds.
const defaultGetFileTimeout = 20 // second
// maxSize caps the cumulative bytes of files embedded into a multipart body.
const maxSize int64 = 20 * 1024 * 1024 // 20MB

// HTTP header names and value prefixes used when injecting authentication
// and content type into the outgoing request.
const (
	HeaderAuthorization = "Authorization"
	HeaderBearerPrefix  = "Bearer "
	HeaderContentType   = "Content-Type"
)

// AuthType selects how credentials are attached to the outgoing request.
type AuthType uint

const (
	// BearToken sends "Authorization: Bearer <token>".
	BearToken AuthType = 1
	// Custom sends a user-defined key/value pair in a header or query param.
	Custom AuthType = 2
)

// MIME content types this node can emit for request bodies.
const (
	ContentTypeJSON           = "application/json"
	ContentTypePlainText      = "text/plain"
	ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
	ContentTypeBinary         = "application/octet-stream"
)

// Location says where a Custom credential is placed on the request.
type Location uint8

const (
	Header     Location = 1
	QueryParam Location = 2
)

// BodyType enumerates the supported request body encodings.
type BodyType string

const (
	BodyTypeNone           BodyType = "EMPTY"
	BodyTypeJSON           BodyType = "JSON"
	BodyTypeRawText        BodyType = "RAW_TEXT"
	BodyTypeFormData       BodyType = "FORM_DATA"
	BodyTypeFormURLEncoded BodyType = "FORM_URLENCODED"
	BodyTypeBinary         BodyType = "BINARY"
)
|
||||
|
||||
// URLConfig holds the URL template; {{var}} placeholders are rendered per
// request from the parsed URL variables.
type URLConfig struct {
	Tpl string `json:"tpl"`
}

// IgnoreExceptionSetting controls error swallowing: when IgnoreException is
// true, DefaultOutput is returned instead of the failure.
type IgnoreExceptionSetting struct {
	IgnoreException bool           `json:"ignore_exception"`
	DefaultOutput   map[string]any `json:"default_output,omitempty"`
}

// BodyConfig selects the body encoding and carries its type-specific
// settings; only the config matching BodyType is consulted.
type BodyConfig struct {
	BodyType        BodyType         `json:"body_type"`
	FormDataConfig  *FormDataConfig  `json:"form_data_config,omitempty"`
	TextPlainConfig *TextPlainConfig `json:"text_plain_config,omitempty"`
	TextJsonConfig  *TextJsonConfig  `json:"text_json_config,omitempty"`
}

// FormDataConfig marks which multipart fields hold file URLs (true) rather
// than plain values.
type FormDataConfig struct {
	FileTypeMapping map[string]bool `json:"file_type_mapping"`
}

// TextPlainConfig holds the template for a text/plain body.
type TextPlainConfig struct {
	Tpl string `json:"tpl"`
}

// TextJsonConfig holds the template for a JSON body.
type TextJsonConfig struct {
	Tpl string
}

// AuthenticationConfig declares the auth scheme and, for Custom auth, where
// the credential is placed (header or query parameter).
type AuthenticationConfig struct {
	Type     AuthType `json:"type"`
	Location Location `json:"location"`
}

// Authentication carries runtime credential values: Key/Value for Custom
// auth, Token for bearer auth.
type Authentication struct {
	Key   string
	Value string
	Token string
}

// Request is the structured form of the node's flat input map, produced by
// Config.parserToRequest.
type Request struct {
	URLVars            map[string]any
	Headers            map[string]string
	Params             map[string]string
	Authentication     *Authentication
	FormDataVars       map[string]string
	FormURLEncodedVars map[string]string
	JsonVars           map[string]any
	TextPlainVars      map[string]any
	FileURL            *string
}

// globalVariableReplaceRegexp rewrites `global_variable_x["y"]` references
// into dotted-path form (`global_variable_x.y`) so they can be split on ".".
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)

// MD5FieldMapping maps MD5 hex digests of field names back to the original
// names; input keys arrive carrying the digests (see the __*_ key prefixes
// used by parserToRequest).
type MD5FieldMapping struct {
	HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
	ParamMD5Mapping  map[string]string `json:"param_md_5_mapping,omitempty"`
	URLMD5Mapping    map[string]string `json:"url_md_5_mapping,omitempty"`
	BodyMD5Mapping   map[string]string `json:"body_md_5_mapping,omitempty"`
}
|
||||
|
||||
func (fm *MD5FieldMapping) SetHeaderFields(fields ...string) {
|
||||
if fm.HeaderMD5Mapping == nil && len(fields) > 0 {
|
||||
fm.HeaderMD5Mapping = make(map[string]string)
|
||||
}
|
||||
for _, field := range fields {
|
||||
fm.HeaderMD5Mapping[crypto.MD5HexValue(field)] = field
|
||||
}
|
||||
|
||||
}
|
||||
func (fm *MD5FieldMapping) SetParamFields(fields ...string) {
|
||||
if fm.ParamMD5Mapping == nil && len(fields) > 0 {
|
||||
fm.ParamMD5Mapping = make(map[string]string)
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
fm.ParamMD5Mapping[crypto.MD5HexValue(field)] = field
|
||||
}
|
||||
|
||||
}
|
||||
func (fm *MD5FieldMapping) SetURLFields(fields ...string) {
|
||||
if fm.URLMD5Mapping == nil && len(fields) > 0 {
|
||||
fm.URLMD5Mapping = make(map[string]string)
|
||||
}
|
||||
for _, field := range fields {
|
||||
fm.URLMD5Mapping[crypto.MD5HexValue(field)] = field
|
||||
}
|
||||
|
||||
}
|
||||
func (fm *MD5FieldMapping) SetBodyFields(fields ...string) {
|
||||
if fm.BodyMD5Mapping == nil && len(fields) > 0 {
|
||||
fm.BodyMD5Mapping = make(map[string]string)
|
||||
}
|
||||
for _, field := range fields {
|
||||
fm.BodyMD5Mapping[crypto.MD5HexValue(field)] = field
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Config is the HTTP requester node's full configuration: URL template,
// optional authentication, body encoding, HTTP method, timeout/retry policy
// and the MD5 field-name mappings used to decode the flat input map.
type Config struct {
	URLConfig  URLConfig
	AuthConfig *AuthenticationConfig
	BodyConfig BodyConfig
	Method     string
	// Timeout applies to the whole request; zero leaves the client default.
	Timeout time.Duration
	// RetryTimes is the number of extra attempts after a transport error.
	RetryTimes uint64

	// When IgnoreException is set, DefaultOutput is returned on failure
	// instead of propagating the error.
	IgnoreException bool
	DefaultOutput   map[string]any

	MD5FieldMapping
}

// HTTPRequester executes configured HTTP calls as a workflow node.
type HTTPRequester struct {
	client *http.Client
	config *Config
}
|
||||
|
||||
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
}
|
||||
|
||||
if len(cfg.Method) == 0 {
|
||||
return nil, fmt.Errorf("method is requried")
|
||||
}
|
||||
|
||||
hg := &HTTPRequester{}
|
||||
client := http.DefaultClient
|
||||
if cfg.Timeout > 0 {
|
||||
client.Timeout = cfg.Timeout
|
||||
}
|
||||
|
||||
hg.client = client
|
||||
hg.config = cfg
|
||||
|
||||
return hg, nil
|
||||
}
|
||||
|
||||
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
|
||||
var (
|
||||
req = &Request{}
|
||||
method = hg.config.Method
|
||||
retryTimes = hg.config.RetryTimes
|
||||
body io.ReadCloser
|
||||
contentType string
|
||||
response *http.Response
|
||||
)
|
||||
|
||||
req, err = hg.config.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpRequest := &http.Request{
|
||||
Method: method,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, value := range req.Headers {
|
||||
httpRequest.Header.Set(key, value)
|
||||
}
|
||||
|
||||
u, err := url.Parse(httpURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := u.Query()
|
||||
for key, value := range req.Params {
|
||||
params.Set(key, value)
|
||||
}
|
||||
|
||||
if hg.config.AuthConfig != nil {
|
||||
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
u.RawQuery = params.Encode()
|
||||
httpRequest.URL = u
|
||||
|
||||
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body != nil {
|
||||
httpRequest.Body = body
|
||||
}
|
||||
|
||||
if contentType != "" {
|
||||
httpRequest.Header.Add(HeaderContentType, contentType)
|
||||
}
|
||||
|
||||
for i := uint64(0); i <= retryTimes; i++ {
|
||||
response, err = hg.client.Do(httpRequest)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
|
||||
headers := func() string {
|
||||
// The structure of httpResp.Header is map[string][]string
|
||||
// If there are multiple header values, the last one will be selected by default
|
||||
hds := make(map[string]string, len(response.Header))
|
||||
for key, values := range response.Header {
|
||||
if len(values) == 0 {
|
||||
hds[key] = ""
|
||||
} else {
|
||||
hds[key] = values[len(values)-1]
|
||||
}
|
||||
}
|
||||
bs, _ := json.Marshal(hds)
|
||||
return string(bs)
|
||||
}()
|
||||
result["headers"] = headers
|
||||
var bodyBytes []byte
|
||||
|
||||
if response.Body != nil {
|
||||
defer func() {
|
||||
_ = response.Body.Close()
|
||||
}()
|
||||
|
||||
bodyBytes, err = io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return nil, fmt.Errorf("request %v failed, response status code=%d, status=%v, headers=%v, body=%v",
|
||||
httpURL, response.StatusCode, response.Status, headers, string(bodyBytes))
|
||||
}
|
||||
|
||||
result["body"] = string(bodyBytes)
|
||||
result["statusCode"] = int64(response.StatusCode)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// decodeUnicode parses \uXXXX Unicode escape sequences embedded in s,
// including UTF-16 surrogate pairs (e.g. \ud83d\ude00 -> U+1F600), which the
// previous implementation decoded as two invalid runes (rendered as U+FFFD).
// Bytes that are not part of a valid escape are copied through unchanged.
func decodeUnicode(s string) string {
	var result strings.Builder
	for i := 0; i < len(s); {
		code, width, ok := parseUnicodeEscape(s, i)
		if !ok {
			result.WriteByte(s[i])
			i++
			continue
		}

		r := rune(code)
		if utf16.IsSurrogate(r) {
			// Try to combine a high/low surrogate pair into one code point.
			if code2, width2, ok2 := parseUnicodeEscape(s, i+width); ok2 {
				if combined := utf16.DecodeRune(r, rune(code2)); combined != '\uFFFD' {
					result.WriteRune(combined)
					i += width + width2
					continue
				}
			}
		}

		result.WriteRune(r)
		i += width
	}
	return result.String()
}

// parseUnicodeEscape reports whether s[i:] starts with a \uXXXX escape and,
// if so, returns the decoded code unit and the escape's byte width (6).
func parseUnicodeEscape(s string, i int) (code int64, width int, ok bool) {
	if i+6 > len(s) || s[i] != '\\' || s[i+1] != 'u' {
		return 0, 0, false
	}
	code, err := strconv.ParseInt(s[i+2:i+6], 16, 32)
	if err != nil {
		return 0, 0, false
	}
	return code, 6, true
}
|
||||
|
||||
func (authCfg *AuthenticationConfig) addAuthentication(_ context.Context, auth *Authentication, header http.Header, params url.Values) (
|
||||
http.Header, url.Values, error) {
|
||||
|
||||
if authCfg.Type == BearToken {
|
||||
header.Set(HeaderAuthorization, HeaderBearerPrefix+auth.Token)
|
||||
return header, params, nil
|
||||
}
|
||||
if authCfg.Type == Custom && authCfg.Location == Header {
|
||||
header.Set(auth.Key, auth.Value)
|
||||
return header, params, nil
|
||||
}
|
||||
|
||||
if authCfg.Type == Custom && authCfg.Location == QueryParam {
|
||||
params.Set(auth.Key, auth.Value)
|
||||
return header, params, nil
|
||||
}
|
||||
|
||||
return header, params, nil
|
||||
}
|
||||
|
||||
func (b *BodyConfig) getBodyAndContentType(ctx context.Context, req *Request) (io.ReadCloser, string, error) {
|
||||
var (
|
||||
body io.Reader
|
||||
contentType string
|
||||
)
|
||||
|
||||
// body none return body nil
|
||||
if b.BodyType == BodyTypeNone {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
switch b.BodyType {
|
||||
case BodyTypeJSON:
|
||||
jsonString, err := nodes.TemplateRender(b.TextJsonConfig.Tpl, req.JsonVars)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
body = strings.NewReader(jsonString)
|
||||
contentType = ContentTypeJSON
|
||||
case BodyTypeFormURLEncoded:
|
||||
form := url.Values{}
|
||||
for key, value := range req.FormURLEncodedVars {
|
||||
form.Add(key, value)
|
||||
}
|
||||
|
||||
body = strings.NewReader(form.Encode())
|
||||
contentType = ContentTypeFormURLEncoded
|
||||
case BodyTypeRawText:
|
||||
textString, err := nodes.TemplateRender(b.TextPlainConfig.Tpl, req.TextPlainVars)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
|
||||
body = strings.NewReader(textString)
|
||||
contentType = ContentTypePlainText
|
||||
case BodyTypeBinary:
|
||||
if req.FileURL == nil {
|
||||
return nil, contentType, fmt.Errorf("file url is required")
|
||||
}
|
||||
|
||||
fileURL := *req.FileURL
|
||||
response, err := httpGet(ctx, fileURL)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
|
||||
body = response.Body
|
||||
contentType = ContentTypeBinary
|
||||
case BodyTypeFormData:
|
||||
var buffer = &bytes.Buffer{}
|
||||
formDataConfig := b.FormDataConfig
|
||||
writer := multipart.NewWriter(buffer)
|
||||
|
||||
total := int64(0)
|
||||
for key, value := range req.FormDataVars {
|
||||
if ok := formDataConfig.FileTypeMapping[key]; ok {
|
||||
fileWrite, err := writer.CreateFormFile(key, key)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
|
||||
response, err := httpGet(ctx, value)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, contentType, fmt.Errorf("failed to download file: %s, status code %v", value, response.StatusCode)
|
||||
}
|
||||
|
||||
size, err := io.Copy(fileWrite, response.Body)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
|
||||
total += size
|
||||
if total > maxSize {
|
||||
return nil, contentType, fmt.Errorf("too large body, total size: %d", total)
|
||||
}
|
||||
} else {
|
||||
err := writer.WriteField(key, value)
|
||||
if err != nil {
|
||||
return nil, contentType, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = writer.Close()
|
||||
contentType = writer.FormDataContentType()
|
||||
body = buffer
|
||||
default:
|
||||
return nil, contentType, fmt.Errorf("unknown content type %s", b.BodyType)
|
||||
}
|
||||
|
||||
if _, ok := body.(io.ReadCloser); ok {
|
||||
return body.(io.ReadCloser), contentType, nil
|
||||
}
|
||||
|
||||
return io.NopCloser(body), contentType, nil
|
||||
}
|
||||
|
||||
func httpGet(ctx context.Context, url string) (*http.Response, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
http.DefaultClient.Timeout = time.Second * defaultGetFileTimeout
|
||||
return http.DefaultClient.Do(request)
|
||||
}
|
||||
|
||||
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
var (
|
||||
request = &Request{}
|
||||
config = hg.config
|
||||
)
|
||||
request, err := hg.config.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
result["method"] = config.Method
|
||||
|
||||
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["url"] = u
|
||||
|
||||
params := make(map[string]any, len(request.Params))
|
||||
for k, v := range request.Params {
|
||||
params[k] = v
|
||||
}
|
||||
result["param"] = params
|
||||
|
||||
headers := make(map[string]any, len(request.Headers))
|
||||
for k, v := range request.Headers {
|
||||
headers[k] = v
|
||||
}
|
||||
result["header"] = headers
|
||||
result["auth"] = nil
|
||||
if config.AuthConfig != nil {
|
||||
if config.AuthConfig.Type == Custom {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"Key": request.Authentication.Key,
|
||||
"Value": request.Authentication.Value,
|
||||
}
|
||||
} else if config.AuthConfig.Type == BearToken {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"token": request.Authentication.Token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result["body"] = nil
|
||||
switch config.BodyConfig.BodyType {
|
||||
case BodyTypeJSON:
|
||||
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make(map[string]any)
|
||||
err = sonic.Unmarshal([]byte(js), &ret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result["body"] = ret
|
||||
case BodyTypeRawText:
|
||||
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
}
|
||||
result["body"] = tx
|
||||
case BodyTypeFormData:
|
||||
result["body"] = request.FormDataVars
|
||||
case BodyTypeFormURLEncoded:
|
||||
result["body"] = request.FormURLEncodedVars
|
||||
case BodyTypeBinary:
|
||||
result["body"] = request.FileURL
|
||||
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Input-map key prefixes. Raw input keys are namespaced by section; the
// remainder after the prefix is the MD5 hex digest of the original field
// name, resolved through Config's MD5FieldMapping (see parserToRequest).
const (
	apiInfoURLPrefix = "__apiInfo_url_"
	headersPrefix    = "__headers_"
	paramsPrefix     = "__params_"

	// Auth keys: the section prefix is followed by a scheme-specific tag
	// rather than an MD5 digest.
	authDataPrefix            = "__auth_authData_"
	authBearerTokenDataPrefix = "bearerTokenData_token"
	authCustomDataPrefix      = "customData_data"

	// Body keys: the body prefix is followed by an encoding tag, then the
	// MD5 digest of the field name (except the single binary file URL).
	bodyDataPrefix           = "__body_bodyData_"
	bodyJsonPrefix           = "json_"
	bodyFormDataPrefix       = "formData_"
	bodyFormURLEncodedPrefix = "formURLEncoded_"
	bodyRawTextPrefix        = "rawText_"
	bodyBinaryFileURLPrefix  = "binary_fileURL"
)
|
||||
|
||||
// parserToRequest dispatches the flat, prefix-namespaced input map into a
// structured Request. Each key is "<section prefix><md5-of-field-name>"; the
// digest suffix is resolved back to the real field name through cfg's
// MD5FieldMapping tables. Keys with no matching prefix or mapping are
// silently ignored.
//
// NOTE(review): values are cast with unchecked type assertions
// (value.(string)); a non-string value under a string-typed key would
// panic. Presumably upstream input conversion guarantees the dynamic
// types — confirm against callers.
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
	request := &Request{
		URLVars:            make(map[string]any),
		Headers:            make(map[string]string),
		Params:             make(map[string]string),
		Authentication:     &Authentication{},
		FormURLEncodedVars: make(map[string]string),
		JsonVars:           make(map[string]any),
		TextPlainVars:      make(map[string]any),
		FormDataVars:       map[string]string{},
	}
	for key, value := range input {
		// URL template variables.
		if strings.HasPrefix(key, apiInfoURLPrefix) {
			urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
			if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
				// Rewrite global_variable_x["y"] into dotted-path form so
				// SetMapValue can split the path on ".".
				if strings.HasPrefix(urlKey, "global_variable_") {
					urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
				}
				nodes.SetMapValue(request.URLVars, strings.Split(urlKey, "."), value.(string))
			}
		}
		// Header fields.
		if strings.HasPrefix(key, headersPrefix) {
			headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
			if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
				request.Headers[headerKey] = value.(string)
			}
		}
		// Query parameters.
		if strings.HasPrefix(key, paramsPrefix) {
			paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
			if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
				request.Params[paramKey] = value.(string)
			}
		}

		// Authentication: bearer token, or custom key/value pair addressed
		// by fixed full key names.
		if strings.HasPrefix(key, authDataPrefix) {
			authKey := strings.TrimPrefix(key, authDataPrefix)
			if strings.HasPrefix(authKey, authBearerTokenDataPrefix) {
				request.Authentication.Token = value.(string) // bear
			}
			if strings.HasPrefix(authKey, authCustomDataPrefix) {
				if key == "__auth_authData_customData_data_Key" {
					request.Authentication.Key = value.(string)
				}
				if key == "__auth_authData_customData_data_Value" {
					request.Authentication.Value = value.(string)
				}
			}
		}

		// Body values, sub-namespaced by body encoding.
		if strings.HasPrefix(key, bodyDataPrefix) {
			bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
			if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
				jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
				if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
					if strings.HasPrefix(jsonKey, "global_variable_") {
						jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
					}
					nodes.SetMapValue(request.JsonVars, strings.Split(jsonKey, "."), value)
				}

			}
			if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
				formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
				if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
					request.FormDataVars[formDataKey] = value.(string)
				}

			}

			if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
				formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
				if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
					request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
				}
			}

			if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
				rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
				if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
					if strings.HasPrefix(rawTextKey, "global_variable_") {
						rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
					}
					nodes.SetMapValue(request.TextPlainVars, strings.Split(rawTextKey, "."), value)
				}
			}

			// Binary body: a single file URL; no MD5 digest to resolve.
			if strings.HasPrefix(bodyKey, bodyBinaryFileURLPrefix) {
				request.FileURL = ptr.Of(value.(string))
			}

		}

	}

	return request, nil
}
|
||||
|
||||
func (hg *HTTPRequester) ToCallbackOutput(_ context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
if body, ok := out["body"]; !ok {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
RawOutput: out,
|
||||
Output: out,
|
||||
}, nil
|
||||
} else {
|
||||
output := maps.Clone(out)
|
||||
output["body"] = decodeUnicode(body.(string))
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
RawOutput: out,
|
||||
Output: output,
|
||||
}, nil
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package httprequester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
)
|
||||
|
||||
func TestInvoke(t *testing.T) {
|
||||
t.Run("get method", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]string{
|
||||
"message": "success",
|
||||
}
|
||||
bs, _ := json.Marshal(response)
|
||||
_, _ = w.Write(bs)
|
||||
}))
|
||||
defer ts.Close()
|
||||
urlTpl := ts.URL + "/{{url_v1}}"
|
||||
cfg := &Config{
|
||||
URLConfig: URLConfig{
|
||||
Tpl: urlTpl,
|
||||
},
|
||||
BodyConfig: BodyConfig{
|
||||
BodyType: BodyTypeNone,
|
||||
},
|
||||
Method: http.MethodGet,
|
||||
RetryTimes: 1,
|
||||
Timeout: 2 * time.Second,
|
||||
MD5FieldMapping: MD5FieldMapping{
|
||||
URLMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("url_v1"): "url_v1",
|
||||
},
|
||||
HeaderMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("h1"): "h1",
|
||||
crypto.MD5HexValue("h2"): "h2",
|
||||
},
|
||||
ParamMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("p1"): "p1",
|
||||
crypto.MD5HexValue("p2"): "p2",
|
||||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
assert.NoError(t, err)
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
|
||||
"__headers_" + crypto.MD5HexValue("h1"): "1",
|
||||
"__headers_" + crypto.MD5HexValue("h2"): "2",
|
||||
"__params_" + crypto.MD5HexValue("p1"): "v1",
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
})
|
||||
|
||||
t.Run("post method multipart/form-data", func(t *testing.T) {
|
||||
fileServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fileContent := "fileV1"
|
||||
_, _ = w.Write([]byte(fileContent))
|
||||
}))
|
||||
defer fileServer.Close()
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(10 << 20) // 10 MB
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
f1 := r.MultipartForm.Value["f1"][0]
|
||||
assert.Equal(t, "fv1", f1)
|
||||
f2 := r.MultipartForm.Value["f2"][0]
|
||||
assert.Equal(t, "fv2", f2)
|
||||
file, _, err := r.FormFile("fileURL")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
fileBs, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, "fileV1", string(fileBs))
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]string{
|
||||
"message": "success",
|
||||
}
|
||||
bs, _ := json.Marshal(response)
|
||||
_, _ = w.Write(bs)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
urlTpl := ts.URL + "/{{post_v1}}"
|
||||
|
||||
cfg := &Config{
|
||||
URLConfig: URLConfig{
|
||||
Tpl: urlTpl,
|
||||
},
|
||||
BodyConfig: BodyConfig{
|
||||
BodyType: BodyTypeFormData,
|
||||
FormDataConfig: &FormDataConfig{
|
||||
map[string]bool{
|
||||
"fileURL": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
RetryTimes: 1,
|
||||
Timeout: 2 * time.Second,
|
||||
MD5FieldMapping: MD5FieldMapping{
|
||||
URLMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("post_v1"): "post_v1",
|
||||
},
|
||||
HeaderMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("h1"): "h1",
|
||||
crypto.MD5HexValue("h2"): "h2",
|
||||
},
|
||||
ParamMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("p1"): "p1",
|
||||
crypto.MD5HexValue("p2"): "p2",
|
||||
},
|
||||
BodyMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("f1"): "f1",
|
||||
crypto.MD5HexValue("f2"): "f2",
|
||||
crypto.MD5HexValue("fileURL"): "fileURL",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建 HTTPRequest 实例
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("post_v1"): "post_v1",
|
||||
"__headers_" + crypto.MD5HexValue("h1"): "1",
|
||||
"__headers_" + crypto.MD5HexValue("h2"): "2",
|
||||
"__params_" + crypto.MD5HexValue("p1"): "v1",
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
"__body_bodyData_formData_" + crypto.MD5HexValue("f1"): "fv1",
|
||||
"__body_bodyData_formData_" + crypto.MD5HexValue("f2"): "fv2",
|
||||
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
})
|
||||
|
||||
t.Run("post method text/plain", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
assert.Equal(t, "text v1 v2", string(body))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]string{
|
||||
"message": "success",
|
||||
}
|
||||
bs, _ := json.Marshal(response)
|
||||
_, _ = w.Write(bs)
|
||||
}))
|
||||
defer ts.Close()
|
||||
urlTpl := ts.URL + "/{{post_text_plain}}"
|
||||
cfg := &Config{
|
||||
URLConfig: URLConfig{
|
||||
Tpl: urlTpl,
|
||||
},
|
||||
BodyConfig: BodyConfig{
|
||||
BodyType: BodyTypeRawText,
|
||||
TextPlainConfig: &TextPlainConfig{
|
||||
Tpl: "text {{v1}} {{v2}}",
|
||||
},
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
RetryTimes: 1,
|
||||
Timeout: 2 * time.Second,
|
||||
MD5FieldMapping: MD5FieldMapping{
|
||||
URLMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("post_text_plain"): "post_text_plain",
|
||||
},
|
||||
HeaderMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("h1"): "h1",
|
||||
crypto.MD5HexValue("h2"): "h2",
|
||||
},
|
||||
ParamMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("p1"): "p1",
|
||||
crypto.MD5HexValue("p2"): "p2",
|
||||
},
|
||||
BodyMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("v1"): "v1",
|
||||
crypto.MD5HexValue("v2"): "v2",
|
||||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("post_text_plain"): "post_text_plain",
|
||||
"__headers_" + crypto.MD5HexValue("h1"): "1",
|
||||
"__headers_" + crypto.MD5HexValue("h2"): "2",
|
||||
"__params_" + crypto.MD5HexValue("p1"): "v1",
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
"__body_bodyData_rawText_" + crypto.MD5HexValue("v1"): "v1",
|
||||
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
})
|
||||
|
||||
t.Run("post method application/json", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
assert.Equal(t, `{"v1":v1,"v2":v2}`, string(body))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]string{
|
||||
"message": "success",
|
||||
}
|
||||
bs, _ := json.Marshal(response)
|
||||
_, _ = w.Write(bs)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
urlTpl := ts.URL + "/{{application_json}}"
|
||||
|
||||
cfg := &Config{
|
||||
URLConfig: URLConfig{
|
||||
Tpl: urlTpl,
|
||||
},
|
||||
BodyConfig: BodyConfig{
|
||||
BodyType: BodyTypeJSON,
|
||||
TextJsonConfig: &TextJsonConfig{
|
||||
Tpl: `{"v1":{{v1}},"v2":{{v2}}}`,
|
||||
},
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
RetryTimes: 1,
|
||||
Timeout: 2 * time.Second,
|
||||
|
||||
MD5FieldMapping: MD5FieldMapping{
|
||||
URLMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("application_json"): "application_json",
|
||||
},
|
||||
HeaderMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("h1"): "h1",
|
||||
crypto.MD5HexValue("h2"): "h2",
|
||||
},
|
||||
ParamMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("p1"): "p1",
|
||||
crypto.MD5HexValue("p2"): "p2",
|
||||
},
|
||||
BodyMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("v1"): "v1",
|
||||
crypto.MD5HexValue("v2"): "v2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建 HTTPRequest 实例
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("application_json"): "application_json",
|
||||
"__headers_" + crypto.MD5HexValue("h1"): "1",
|
||||
"__headers_" + crypto.MD5HexValue("h2"): "2",
|
||||
"__params_" + crypto.MD5HexValue("p1"): "v1",
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
"__body_bodyData_json_" + crypto.MD5HexValue("v1"): "v1",
|
||||
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
})
|
||||
|
||||
t.Run("post method application/octet-stream", func(t *testing.T) {
|
||||
fileServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fileContent := strings.Repeat("fileV1", 100)
|
||||
_, _ = w.Write([]byte(fileContent))
|
||||
}))
|
||||
defer fileServer.Close()
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
fileContent := strings.Repeat("fileV1", 100)
|
||||
assert.Equal(t, fileContent, string(body))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := map[string]string{
|
||||
"message": "success",
|
||||
}
|
||||
bs, _ := json.Marshal(response)
|
||||
_, _ = w.Write(bs)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
urlTpl := ts.URL + "/{{binary}}"
|
||||
|
||||
cfg := &Config{
|
||||
URLConfig: URLConfig{
|
||||
Tpl: urlTpl,
|
||||
},
|
||||
BodyConfig: BodyConfig{
|
||||
BodyType: BodyTypeBinary,
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
RetryTimes: 1,
|
||||
Timeout: 2 * time.Second,
|
||||
MD5FieldMapping: MD5FieldMapping{
|
||||
URLMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("binary"): "binary",
|
||||
},
|
||||
HeaderMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("h1"): "h1",
|
||||
crypto.MD5HexValue("h2"): "h2",
|
||||
},
|
||||
ParamMD5Mapping: map[string]string{
|
||||
crypto.MD5HexValue("p1"): "p1",
|
||||
crypto.MD5HexValue("p2"): "p2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建 HTTPRequest 实例
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("application_json"): "application_json",
|
||||
"__headers_" + crypto.MD5HexValue("h1"): "1",
|
||||
"__headers_" + crypto.MD5HexValue("h2"): "2",
|
||||
"__params_" + crypto.MD5HexValue("p1"): "v1",
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package intentdetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
// Config configures an IntentDetector node.
type Config struct {
	// Intents lists the candidate intent descriptions. Classification ids are
	// assigned from their order starting at 1; id 0 is reserved for "no match"
	// (see the system prompts below).
	Intents []string
	// SystemPrompt is extra user-provided guidance rendered into the
	// "{{advance}}" slot of SystemIntentPrompt. Used in full mode only.
	SystemPrompt string
	// IsFastMode selects the prompt that asks the model for a bare numeric
	// classification id instead of a JSON object.
	IsFastMode bool
	// ChatModel is the chat model used to perform the classification.
	ChatModel model.BaseChatModel
}
|
||||
|
||||
// SystemIntentPrompt is the full-mode system prompt. It instructs the model to
// reply with a JSON object of the shape {"classificationId": N, "reason": "..."}.
// "{{intents}}" is rendered at construction time (NewIntentDetector) and
// "{{advance}}" at invocation time (Invoke).
const SystemIntentPrompt = `
# Role
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.

## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
[
{"classificationId": 0, "content": "Other intentions"},
{{intents}}
]

Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.

{{advance}}

## Reply requirements
- The answer must be returned in JSON format.
- Strictly ensure that the output is in a valid JSON format.
- Do not add prefix "json or suffix""
- The answer needs to include the following fields such as:
{
"classificationId": 0,
"reason": "Unclear intentions"
}

##Limit
- Please do not reply in text.
`
|
||||
|
||||
// FastModeSystemIntentPrompt is the fast-mode system prompt. It instructs the
// model to reply with a bare classification id (a single number), which
// parseToNodeOut parses with strconv. "{{intents}}" is rendered at
// construction time (NewIntentDetector).
const FastModeSystemIntentPrompt = `
# Role
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.

## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
[
{"classificationId": 0, "content": "Other intentions"},
{{intents}}
]

Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.


## Reply requirements
- The answer must be a number indicated classificationId.
- if not match, please just output an number 0.
- do not output json format data, just output an number.

##Limit
- Please do not reply in text.`
|
||||
|
||||
// IntentDetector classifies a user query into one of the configured intents by
// prompting a chat model and parsing its reply.
type IntentDetector struct {
	// config is the validated construction-time configuration.
	config *Config
	// runner is the compiled (chat template -> chat model) chain.
	runner compose.Runnable[map[string]any, *schema.Message]
}
|
||||
|
||||
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
}
|
||||
if !cfg.IsFastMode && cfg.ChatModel == nil {
|
||||
return nil, errors.New("config chat model is required")
|
||||
}
|
||||
|
||||
if len(cfg.Intents) == 0 {
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
chain := compose.NewChain[map[string]any, *schema.Message]()
|
||||
|
||||
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
|
||||
|
||||
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
|
||||
"intents": toIntentString(cfg.Intents),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts := prompt.FromMessages(schema.Jinja2,
|
||||
&schema.Message{Content: sptTemplate, Role: schema.System},
|
||||
&schema.Message{Content: "{{query}}", Role: schema.User})
|
||||
|
||||
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &IntentDetector{
|
||||
config: cfg,
|
||||
runner: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
|
||||
nodeOutput := make(map[string]any)
|
||||
nodeOutput["classificationId"] = 0
|
||||
if content == "" {
|
||||
return nodeOutput, errors.New("content is empty")
|
||||
}
|
||||
|
||||
if id.config.IsFastMode {
|
||||
cid, err := strconv.ParseInt(content, 10, 64)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
}
|
||||
nodeOutput["classificationId"] = cid
|
||||
return nodeOutput, nil
|
||||
}
|
||||
|
||||
leftIndex := strings.Index(content, "{")
|
||||
rightIndex := strings.Index(content, "}")
|
||||
if leftIndex == -1 || rightIndex == -1 {
|
||||
return nodeOutput, errors.New("content is invalid")
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
}
|
||||
|
||||
return nodeOutput, nil
|
||||
}
|
||||
|
||||
func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
query, ok := input["query"]
|
||||
if !ok {
|
||||
return nil, errors.New("input query field required")
|
||||
}
|
||||
|
||||
queryStr, ok := query.(string)
|
||||
if !ok {
|
||||
queryStr = cast.ToString(query)
|
||||
}
|
||||
|
||||
vars := make(map[string]any)
|
||||
vars["query"] = queryStr
|
||||
if !id.config.IsFastMode {
|
||||
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vars["advance"] = ad
|
||||
}
|
||||
o, err := id.runner.Invoke(ctx, vars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return id.parseToNodeOut(o.Content)
|
||||
}
|
||||
|
||||
// toIntentString renders the configured intents as the JSON array fragment
// substituted for "{{intents}}" in the system prompts. Classification ids are
// assigned from slice order, starting at 1 (0 is reserved for "no match").
func toIntentString(its []string) string {
	type intentItem struct {
		ClassificationID int64  `json:"classificationId"`
		Content          string `json:"content"`
	}

	items := make([]intentItem, len(its))
	for i := range its {
		items[i] = intentItem{
			ClassificationID: int64(i + 1),
			Content:          its[i],
		}
	}

	// Marshal of a slice of structs cannot fail here; ignore the error.
	bs, _ := json.Marshal(items)
	return string(bs)
}
|
||||
@@ -0,0 +1,88 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package intentdetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockChatModel is a stub chat model used by the tests below; it returns
// canned replies instead of calling a real model.
type mockChatModel struct {
	// topSeed selects the fast-mode style reply (a bare id string) instead of
	// the full-mode JSON reply.
	topSeed bool
}
|
||||
|
||||
// Generate returns a canned reply: "1" when topSeed is set (the shape fast
// mode expects), otherwise a JSON classification object (the shape full mode
// expects). Input messages and options are ignored.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if m.topSeed {
		return &schema.Message{
			Content: "1",
		}, nil
	}
	return &schema.Message{
		Content: `{"classificationId":1,"reason":"高兴"}`,
	}, nil
}
|
||||
|
||||
// Stream is unused by these tests and exists only so mockChatModel satisfies
// the model interface; it returns nil, nil.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	return nil, nil
}
|
||||
|
||||
// BindTools is a no-op; the mock never uses tools.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}
|
||||
|
||||
// TestNewIntentDetector exercises both detector modes end-to-end against
// mockChatModel: fast mode parses the bare-id reply into an int64, full mode
// parses the JSON reply (json.Unmarshal yields float64 for numbers, hence the
// differing expected types).
func TestNewIntentDetector(t *testing.T) {
	ctx := context.Background()
	t.Run("fast mode", func(t *testing.T) {
		dt, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: true,
			ChatModel:  &mockChatModel{topSeed: true},
		})
		assert.Nil(t, err)

		ret, err := dt.Invoke(ctx, map[string]any{
			"query": "我考了100分",
		})
		assert.Nil(t, err)
		// Fast mode goes through strconv.ParseInt, so the id is an int64.
		assert.Equal(t, ret["classificationId"], int64(1))
	})

	t.Run("full mode", func(t *testing.T) {

		dt, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: false,
			ChatModel:  &mockChatModel{},
		})
		assert.Nil(t, err)

		ret, err := dt.Invoke(ctx, map[string]any{
			"query": "我考了100分",
		})
		// NOTE(review): leftover debug prints; harmless but could be removed
		// together with the fmt import.
		fmt.Println(err)
		assert.Nil(t, err)
		fmt.Println(ret)
		// Full mode goes through json.Unmarshal, so the id is a float64.
		assert.Equal(t, ret["classificationId"], float64(1))
		assert.Equal(t, ret["reason"], "高兴")
	})

}
|
||||
28
backend/domain/workflow/internal/nodes/interrupt.go
Normal file
28
backend/domain/workflow/internal/nodes/interrupt.go
Normal file
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// InterruptEventStore persists per-node interrupt events so that an
// interrupted workflow node can be resumed later.
type InterruptEventStore interface {
	// GetInterruptEvent returns the event stored for nodeKey and whether one
	// exists.
	GetInterruptEvent(nodeKey vo.NodeKey) (*entity.InterruptEvent, bool, error)
	// SetInterruptEvent stores value as the pending event for nodeKey.
	SetInterruptEvent(nodeKey vo.NodeKey, value *entity.InterruptEvent) error
	// DeleteInterruptEvent removes any stored event for nodeKey.
	DeleteInterruptEvent(nodeKey vo.NodeKey) error
}
|
||||
@@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
const (
	// InputKeyDeserialization is the input-map key holding the JSON string to
	// decode.
	InputKeyDeserialization = "input"
	// OutputKeyDeserialization is the output-map key holding the decoded
	// value; it is also the key looked up in DeserializationConfig.OutputFields.
	OutputKeyDeserialization = "output"
	// warningsKey is the ctxcache key under which Invoke stashes non-fatal
	// conversion warnings for ToCallbackOutput to pick up.
	warningsKey = "deserialization_warnings"
)
|
||||
|
||||
// DeserializationConfig configures a Deserializer node.
type DeserializationConfig struct {
	// OutputFields declares the type info of the output; the entry under
	// OutputKeyDeserialization ("output") is required.
	OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
}
|
||||
|
||||
// Deserializer decodes a JSON string input into a typed value according to
// the configured output type info.
type Deserializer struct {
	// config is the validated construction-time configuration.
	config *DeserializationConfig
	// typeInfo is config.OutputFields[OutputKeyDeserialization], cached at
	// construction.
	typeInfo *vo.TypeInfo
}
|
||||
|
||||
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.OutputFields == nil {
|
||||
return nil, fmt.Errorf("OutputFields is required for deserialization")
|
||||
}
|
||||
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
|
||||
if typeInfo == nil {
|
||||
return nil, fmt.Errorf("no output field specified in deserialization config")
|
||||
}
|
||||
return &Deserializer{
|
||||
config: cfg,
|
||||
typeInfo: typeInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
jsonStrValue := input[InputKeyDeserialization]
|
||||
|
||||
jsonStr, ok := jsonStrValue.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input is not a string, got %T", jsonStrValue)
|
||||
}
|
||||
|
||||
typeInfo := jd.typeInfo
|
||||
|
||||
var rawValue any
|
||||
var err error
|
||||
|
||||
// Unmarshal based on the root type
|
||||
switch typeInfo.Type {
|
||||
case vo.DataTypeString, vo.DataTypeInteger, vo.DataTypeNumber, vo.DataTypeBoolean, vo.DataTypeTime, vo.DataTypeFile:
|
||||
// Scalar types - unmarshal to generic any
|
||||
err = sonic.Unmarshal([]byte(jsonStr), &rawValue)
|
||||
case vo.DataTypeArray:
|
||||
// Array type - unmarshal to []any
|
||||
var arr []any
|
||||
err = sonic.Unmarshal([]byte(jsonStr), &arr)
|
||||
rawValue = arr
|
||||
case vo.DataTypeObject:
|
||||
// Object type - unmarshal to map[string]any
|
||||
var obj map[string]any
|
||||
err = sonic.Unmarshal([]byte(jsonStr), &obj)
|
||||
rawValue = obj
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported root data type: %s", typeInfo.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JSON unmarshaling failed: %w", err)
|
||||
}
|
||||
|
||||
convertedValue, ws, err := nodes.Convert(ctx, rawValue, OutputKeyDeserialization, typeInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ws != nil && len(*ws) > 0 {
|
||||
ctxcache.Store(ctx, warningsKey, *ws)
|
||||
}
|
||||
|
||||
return map[string]any{OutputKeyDeserialization: convertedValue}, nil
|
||||
}
|
||||
|
||||
func (jd *Deserializer) ToCallbackOutput(ctx context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
var wfe vo.WorkflowError
|
||||
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, warningsKey); ok {
|
||||
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
|
||||
}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: out,
|
||||
RawOutput: out,
|
||||
Error: wfe,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// TestNewJsonDeserializer covers constructor validation: nil config, missing
// OutputFields, missing "output" key, and the happy path.
func TestNewJsonDeserializer(t *testing.T) {
	ctx := context.Background()

	// Test with nil config
	_, err := NewJsonDeserializer(ctx, nil)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "config required")

	// Test with missing OutputFields config
	_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "OutputFields is required")

	// Test with missing output key in OutputFields
	_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			"testKey": {Type: vo.DataTypeString},
		},
	})
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no output field specified in deserialization config")

	// Test with valid config
	validConfig := &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			OutputKeyDeserialization: {Type: vo.DataTypeString},
		},
	}
	processor, err := NewJsonDeserializer(ctx, validConfig)
	assert.NoError(t, err)
	assert.NotNil(t, processor)
}
|
||||
|
||||
// TestJsonDeserializer_Invoke is a table-driven test over Deserializer.Invoke:
// scalar/object/array/nested roots, malformed JSON, null input, lenient type
// coercions (string->int, float->int, string->bool), and coercion failures
// that must surface as warnings (counted via ctxcache) rather than errors.
func TestJsonDeserializer_Invoke(t *testing.T) {
	ctx := context.Background()

	// Base type test config
	baseTypeConfig := &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			"output": {Type: vo.DataTypeString},
		},
	}

	// Object type test config
	objectTypeConfig := &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			"output": {
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					"name": {Type: vo.DataTypeString, Required: true},
					"age":  {Type: vo.DataTypeInteger},
				},
			},
		},
	}

	// Array type test config
	arrayTypeConfig := &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			"output": {
				Type:         vo.DataTypeArray,
				ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
			},
		},
	}

	// Nested array object test config
	nestedArrayConfig := &DeserializationConfig{
		OutputFields: map[string]*vo.TypeInfo{
			"output": {
				Type: vo.DataTypeArray,
				ElemTypeInfo: &vo.TypeInfo{
					Type: vo.DataTypeObject,
					Properties: map[string]*vo.TypeInfo{
						"id":   {Type: vo.DataTypeInteger},
						"name": {Type: vo.DataTypeString},
					},
				},
			},
		},
	}

	// Test cases
	tests := []struct {
		name           string
		config         *DeserializationConfig
		inputJSON      string
		expectedOutput any
		expectErr      bool
		expectWarnings int
	}{{
		name:           "Test string deserialization",
		config:         baseTypeConfig,
		inputJSON:      `"test string"`,
		expectedOutput: "test string",
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test integer deserialization",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `123`,
		expectedOutput: 123,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test boolean deserialization",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeBoolean},
			},
		},
		inputJSON:      `true`,
		expectedOutput: true,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name:           "Test object deserialization",
		config:         objectTypeConfig,
		inputJSON:      `{"name":"test","age":20}`,
		expectedOutput: map[string]any{"name": "test", "age": 20.0},
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name:           "Test array deserialization",
		config:         arrayTypeConfig,
		inputJSON:      `[1,2,3]`,
		expectedOutput: []any{1.0, 2.0, 3.0},
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name:      "Test nested array object deserialization",
		config:    nestedArrayConfig,
		inputJSON: `[{"id":1,"name":"a"},{"id":2,"name":"b"}]`,
		expectedOutput: []any{
			map[string]any{"id": 1.0, "name": "a"},
			map[string]any{"id": 2.0, "name": "b"},
		},
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name:           "Test invalid JSON format",
		config:         baseTypeConfig,
		inputJSON:      `{invalid json}`,
		expectedOutput: nil,
		expectErr:      true,
		expectWarnings: 0,
	}, {
		name: "Test type mismatch warning",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `"not a number"`,
		expectedOutput: nil,
		expectErr:      false,
		expectWarnings: 1,
	}, {
		name:           "Test null JSON input",
		config:         baseTypeConfig,
		inputJSON:      `null`,
		expectedOutput: nil,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test string to integer conversion",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `"123"`,
		expectedOutput: 123,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test float to integer conversion (integer part)",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `123.0`,
		expectedOutput: 123,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test float to integer conversion (non-integer part)",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `123.5`,
		expectedOutput: 123,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test boolean to integer conversion",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `true`,
		expectedOutput: nil,
		expectErr:      false,
		expectWarnings: 1,
	}, {
		name: "Test string to boolean conversion",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeBoolean},
			},
		},
		inputJSON:      `"true"`,
		expectedOutput: true,
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test string to integer conversion in nested object",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {
					Type: vo.DataTypeObject,
					Properties: map[string]*vo.TypeInfo{
						"age": {Type: vo.DataTypeInteger},
					},
				},
			},
		},
		inputJSON:      `{"age":"456"}`,
		expectedOutput: map[string]any{"age": 456},
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test string to integer conversion for array elements",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {
					Type:         vo.DataTypeArray,
					ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
				},
			},
		},
		inputJSON:      `["1", "2", "3"]`,
		expectedOutput: []any{1, 2, 3},
		expectErr:      false,
		expectWarnings: 0,
	}, {
		name: "Test string with non-numeric characters to integer conversion",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {Type: vo.DataTypeInteger},
			},
		},
		inputJSON:      `"123abc"`,
		expectedOutput: nil,
		expectErr:      false,
		expectWarnings: 1,
	}, {
		name: "Test type mismatch in nested object field",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {
					Type: vo.DataTypeObject,
					Properties: map[string]*vo.TypeInfo{
						"score": {Type: vo.DataTypeInteger},
					},
				},
			},
		},
		inputJSON:      `{"score":"invalid"}`,
		expectedOutput: map[string]any{"score": nil},
		expectErr:      false,
		expectWarnings: 1,
	}, {
		name: "Test partial conversion failure in array elements",
		config: &DeserializationConfig{
			OutputFields: map[string]*vo.TypeInfo{
				"output": {
					Type:         vo.DataTypeArray,
					ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
				},
			},
		},
		inputJSON:      `["1", "two", 3]`,
		expectedOutput: []any{1, 3},
		expectErr:      false,
		expectWarnings: 1,
	}}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			processor, err := NewJsonDeserializer(ctx, tt.config)
			assert.NoError(t, err)

			// Invoke stores warnings via ctxcache, so each case needs a
			// cache-initialized context.
			ctxWithCache := ctxcache.Init(ctx)
			input := map[string]any{"input": tt.inputJSON}
			result, err := processor.Invoke(ctxWithCache, input)

			if tt.expectErr {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			assert.Contains(t, result, OutputKeyDeserialization)

			// Verify the output
			output := result[OutputKeyDeserialization]
			if tt.expectedOutput == nil {
				assert.Nil(t, output)
			} else {
				// Serialize expected and actual output to JSON for comparison, ignoring type differences (e.g., float64 vs. int)
				actualJSON, _ := sonic.Marshal(output)
				expectedJSON, _ := sonic.Marshal(tt.expectedOutput)
				assert.JSONEq(t, string(expectedJSON), string(actualJSON))
			}

			// Verify the number of warnings
			warnings, _ := ctxcache.Get[nodes.ConversionWarnings](ctxWithCache, warningsKey)
			assert.Equal(t, tt.expectWarnings, len(warnings))
		})
	}
}
|
||||
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// Keys under which the serializer node reads its single input value and
// writes its serialized output string.
const (
	InputKeySerialization  = "input"
	OutputKeySerialization = "output"
)

// SerializationConfig configures a JsonSerializer node.
type SerializationConfig struct {
	// InputTypes declares the types of the node's inputs. It is required
	// at construction time, although Invoke marshals the raw value directly.
	InputTypes map[string]*vo.TypeInfo
}

// JsonSerializer is a workflow node that JSON-encodes the value found
// under InputKeySerialization into a string under OutputKeySerialization.
type JsonSerializer struct {
	config *SerializationConfig
}
|
||||
|
||||
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.InputTypes == nil {
|
||||
return nil, fmt.Errorf("InputTypes is required for serialization")
|
||||
}
|
||||
|
||||
return &JsonSerializer{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
// Directly use the input map for serialization
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("input data for serialization cannot be nil")
|
||||
}
|
||||
|
||||
originData := input[InputKeySerialization]
|
||||
serializedData, err := sonic.Marshal(originData) // Serialize the entire input map
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialization error: %w", err)
|
||||
}
|
||||
return map[string]any{OutputKeySerialization: string(serializedData)}, nil
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func TestNewJsonSerialize(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil config
|
||||
_, err := NewJsonSerializer(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "config required")
|
||||
|
||||
// Test with missing InputTypes config
|
||||
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "InputTypes is required")
|
||||
|
||||
// Test with valid config
|
||||
validConfig := &SerializationConfig{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"testKey": {Type: "string"},
|
||||
},
|
||||
}
|
||||
processor, err := NewJsonSerializer(ctx, validConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, processor)
|
||||
}
|
||||
|
||||
func TestJsonSerialize_Invoke(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
config := &SerializationConfig{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"stringKey": {Type: "string"},
|
||||
"intKey": {Type: "integer"},
|
||||
"boolKey": {Type: "boolean"},
|
||||
"objKey": {Type: "object"},
|
||||
},
|
||||
}
|
||||
|
||||
processor, err := NewJsonSerializer(ctx, config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]any
|
||||
expected string
|
||||
expectErr bool
|
||||
}{{
|
||||
name: "Test string serialization",
|
||||
input: map[string]any{
|
||||
"input": "test",
|
||||
},
|
||||
expected: `"test"`,
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "Test integer serialization",
|
||||
input: map[string]any{
|
||||
"input": 123,
|
||||
},
|
||||
expected: `123`,
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "Test boolean serialization",
|
||||
input: map[string]any{
|
||||
"input": true,
|
||||
},
|
||||
expected: `true`,
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "Test object serialization",
|
||||
input: map[string]any{
|
||||
"input": map[string]any{
|
||||
"nestedKey": "nestedValue",
|
||||
},
|
||||
},
|
||||
expected: `{"nestedKey":"nestedValue"}`,
|
||||
expectErr: false,
|
||||
}, {
|
||||
name: "Test nil input",
|
||||
input: nil,
|
||||
expected: "",
|
||||
expectErr: true,
|
||||
}, {
|
||||
name: "Test special character handling",
|
||||
input: map[string]any{
|
||||
"input": "\"test\"\nwith\twhitespace",
|
||||
},
|
||||
expected: `"\"test\"\nwith\twhitespace"`,
|
||||
expectErr: false,
|
||||
}}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processor.Invoke(ctx, tt.input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, result, OutputKeySerialization)
|
||||
|
||||
jsonStr, ok := result[OutputKeySerialization].(string)
|
||||
assert.True(t, ok, "The output should be of type string")
|
||||
|
||||
assert.JSONEq(t, tt.expected, jsonStr)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
// DeleterConfig configures a KnowledgeDeleter node.
type DeleterConfig struct {
	// KnowledgeID identifies the knowledge base. It is not referenced by
	// Delete in this file; presumably consumed elsewhere — TODO confirm.
	KnowledgeID      int64
	// KnowledgeDeleter performs the actual document deletion.
	KnowledgeDeleter knowledge.KnowledgeOperator
}

// KnowledgeDeleter deletes a document from a knowledge base via the
// configured KnowledgeOperator.
type KnowledgeDeleter struct {
	config *DeleterConfig
}
|
||||
|
||||
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
|
||||
if cfg.KnowledgeDeleter == nil {
|
||||
return nil, errors.New("knowledge deleter is required")
|
||||
}
|
||||
return &KnowledgeDeleter{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
documentID, ok := input["documentID"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("documentID is required and must be a string")
|
||||
}
|
||||
|
||||
req := &knowledge.DeleteDocumentRequest{
|
||||
DocumentID: documentID,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
result["isSuccess"] = response.IsSuccess
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
// IndexerConfig configures a KnowledgeIndexer node.
type IndexerConfig struct {
	// KnowledgeID is the knowledge base that documents are stored into.
	KnowledgeID      int64
	// ParsingStrategy controls how uploaded documents are parsed. Required.
	ParsingStrategy  *knowledge.ParsingStrategy
	// ChunkingStrategy controls how parsed documents are chunked. Required.
	ChunkingStrategy *knowledge.ChunkingStrategy
	// KnowledgeIndexer performs the actual document storage. Required.
	KnowledgeIndexer knowledge.KnowledgeOperator
}

// KnowledgeIndexer stores a document (referenced by URL) into a
// knowledge base via the configured KnowledgeOperator.
type KnowledgeIndexer struct {
	config *IndexerConfig
}
|
||||
|
||||
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
|
||||
if cfg.ParsingStrategy == nil {
|
||||
return nil, errors.New("parsing strategy is required")
|
||||
}
|
||||
if cfg.ChunkingStrategy == nil {
|
||||
return nil, errors.New("chunking strategy is required")
|
||||
}
|
||||
if cfg.KnowledgeIndexer == nil {
|
||||
return nil, errors.New("knowledge indexer is required")
|
||||
}
|
||||
return &KnowledgeIndexer{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
||||
fileURL, ok := input["knowledge"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("knowledge is required")
|
||||
}
|
||||
|
||||
fileName, ext, err := parseToFileNameAndFileExtension(fileURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &knowledge.CreateDocumentRequest{
|
||||
KnowledgeID: k.config.KnowledgeID,
|
||||
ParsingStrategy: k.config.ParsingStrategy,
|
||||
ChunkingStrategy: k.config.ChunkingStrategy,
|
||||
FileURL: fileURL,
|
||||
FileName: fileName,
|
||||
FileExtension: ext,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
result["documentId"] = response.DocumentID
|
||||
result["fileName"] = response.FileName
|
||||
result["fileUrl"] = response.FileURL
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func parseToFileNameAndFileExtension(fileURL string) (string, parser.FileExtension, error) {
|
||||
|
||||
u, err := url.Parse(fileURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
fileName := u.Query().Get("x-wf-file_name")
|
||||
if len(fileName) == 0 {
|
||||
return "", "", errors.New("file name is required")
|
||||
}
|
||||
|
||||
fileExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(fileName), "."))
|
||||
|
||||
ext, support := parser.ValidateFileExtension(fileExt)
|
||||
if !support {
|
||||
return "", "", fmt.Errorf("unsupported file type: %s", fileExt)
|
||||
}
|
||||
return fileName, ext, nil
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
// outputList is the output key under which retrieved slices are returned.
const outputList = "outputList"

// RetrieveConfig configures a KnowledgeRetrieve node. All fields are
// required (validated by NewKnowledgeRetrieve).
type RetrieveConfig struct {
	KnowledgeIDs      []int64
	RetrievalStrategy *knowledge.RetrievalStrategy
	Retriever         knowledge.KnowledgeOperator
}

// KnowledgeRetrieve retrieves slices matching a query from one or more
// knowledge bases via the configured KnowledgeOperator.
type KnowledgeRetrieve struct {
	config *RetrieveConfig
}
|
||||
|
||||
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
}
|
||||
|
||||
if cfg.Retriever == nil {
|
||||
return nil, errors.New("retriever is required")
|
||||
}
|
||||
|
||||
if len(cfg.KnowledgeIDs) == 0 {
|
||||
return nil, errors.New("knowledgeI ids is required")
|
||||
}
|
||||
|
||||
if cfg.RetrievalStrategy == nil {
|
||||
return nil, errors.New("retrieval strategy is required")
|
||||
}
|
||||
|
||||
return &KnowledgeRetrieve{
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
||||
query, ok := input["Query"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("capital query key is required")
|
||||
}
|
||||
|
||||
req := &knowledge.RetrieveRequest{
|
||||
Query: query,
|
||||
KnowledgeIDs: kr.config.KnowledgeIDs,
|
||||
RetrievalStrategy: kr.config.RetrievalStrategy,
|
||||
}
|
||||
|
||||
response, err := kr.config.Retriever.Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
result[outputList] = slices.Transform(response.Slices, func(m *knowledge.Slice) any {
|
||||
return map[string]any{
|
||||
"documentId": m.DocumentID,
|
||||
"output": m.Output,
|
||||
}
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
847
backend/domain/workflow/internal/nodes/llm/llm.go
Normal file
847
backend/domain/workflow/internal/nodes/llm/llm.go
Normal file
@@ -0,0 +1,847 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// Format enumerates the LLM node's requested output formats.
type Format int

const (
	FormatText Format = iota
	FormatMarkdown
	FormatJSON
)

const (
	// jsonPromptFormat is appended to the user prompt when FormatJSON is
	// requested; %s receives the JSON schema of the output fields.
	jsonPromptFormat = `
Strictly reply in valid JSON format.
- Ensure the output strictly conforms to the JSON schema below
- Do not include explanations, comments, or any text outside the JSON.

Here is the output JSON schema:
'''
%s
'''
`
	// markdownPrompt is appended to the user prompt when FormatMarkdown is
	// requested.
	markdownPrompt = `
Strictly reply in valid Markdown format.
- For headings, use number signs (#).
- For list items, start with dashes (-).
- To emphasize text, wrap it with asterisks (*).
- For code or commands, surround them with backticks (` + "`" + `).
- For quoted text, use greater than signs (>).
- For links, wrap the text in square brackets [], followed by the URL in parentheses ().
- For images, use square brackets [] for the alt text, followed by the image URL in parentheses ().

`
)

const (
	// ReasoningOutputKey is the output field carrying a reasoning model's
	// chain-of-thought content.
	ReasoningOutputKey = "reasoning_content"
)
|
||||
|
||||
// knowledgeUserPromptTemplate prefixes the user prompt with recalled
// knowledge. The Chinese instructions tell the model to answer based on
// the quoted content and to surface any <img src=""> tags found in it as
// Markdown images; %s receives the recalled content and the user's
// question follows the template.
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
例如:
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:。
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为: 和  和 
you can refer to the following content and do relevant searches to improve:
---
%s

question is:

`

// knowledgeIntentPrompt instructs an intent-recognition step to pick
// which knowledge bases apply. The Chinese prompt requires the model to
// output only comma-joined knowledge_id values (or 0 when none match),
// giving the system prompt precedence over the user's request. The first
// %s receives the knowledge-base list, the second the system prompt.
const knowledgeIntentPrompt = `
# 角色:
你是一个知识库意图识别AI Agent。
## 目标:
- 按照「系统提示词」、用户需求、最新的聊天记录选择应该使用的知识库。
## 工作流程:
1. 分析「系统提示词」以确定用户的具体需求。
2. 如果「系统提示词」明确指明了要使用的知识库,则直接返回这些知识库,只输出它们的knowledge_id,不需要再判断用户的输入
3. 检查每个知识库的knowledge_name和knowledge_description,以了解它们各自的功能。
4. 根据用户需求,选择最符合的知识库。
5. 如果找到一个或多个合适的知识库,输出它们的knowledge_id。如果没有合适的知识库,输出0。
## 约束:
- 严格按照「系统提示词」和用户的需求选择知识库。「系统提示词」的优先级大于用户的需求
- 如果有多个合适的知识库,将它们的knowledge_id用英文逗号连接后输出。
- 输出必须仅为knowledge_id或0,不得包括任何其他内容或解释,不要在id后面输出知识库名称。

## 输出示例
123,456

## 输出格式:
输出应该是一个纯数字或者由英文逗号连接的数字序列,具体取决于选择的知识库数量。不应包含任何其他文本或格式。
## 知识库列表如下
%s
## 「系统提示词」如下
%s
`
|
||||
|
||||
// Graph node keys and template placeholders used when wiring the LLM
// node's internal eino graph.
const (
	knowledgeTemplateKey           = "knowledge_template"
	knowledgeChatModelKey          = "knowledge_chat_model"
	knowledgeLambdaKey             = "knowledge_lambda"
	knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
	templateNodeKey                = "template"
	llmNodeKey                     = "llm"
	outputConvertNodeKey           = "output_convert"
)

// NoReCallReplyMode selects how the node replies when knowledge recall
// returns nothing.
type NoReCallReplyMode int64

const (
	NoReCallReplyModeOfDefault   NoReCallReplyMode = 0
	NoReCallReplyModeOfCustomize NoReCallReplyMode = 1
)

// RetrievalStrategy bundles the cross-domain retrieval strategy with the
// no-recall reply behavior.
type RetrievalStrategy struct {
	RetrievalStrategy            *crossknowledge.RetrievalStrategy
	NoReCallReplyMode            NoReCallReplyMode
	// NoReCallReplyCustomizePrompt is used when NoReCallReplyMode is
	// NoReCallReplyModeOfCustomize.
	NoReCallReplyCustomizePrompt string
}

// KnowledgeRecallConfig enables knowledge recall for the LLM node.
type KnowledgeRecallConfig struct {
	// ChatModel performs knowledge-base intent recognition.
	ChatModel                model.BaseChatModel
	// Retriever fetches slices from the selected knowledge bases.
	Retriever                crossknowledge.KnowledgeOperator
	RetrievalStrategy        *RetrievalStrategy
	SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
}

// Config configures an LLM workflow node.
type Config struct {
	ChatModel             ModelWithInfo
	Tools                 []tool.BaseTool
	SystemPrompt          string
	UserPrompt            string
	OutputFormat          Format
	InputFields           map[string]*vo.TypeInfo
	OutputFields          map[string]*vo.TypeInfo
	// ToolsReturnDirectly lists tool names whose results end the agent
	// loop immediately.
	ToolsReturnDirectly   map[string]bool
	KnowledgeRecallConfig *KnowledgeRecallConfig
	FullSources           map[string]*nodes.SourceInfo
}

// LLM is the compiled workflow LLM node.
type LLM struct {
	// r is the compiled inner graph: prompt template -> model/agent ->
	// output conversion.
	r                 compose.Runnable[map[string]any, map[string]any]
	outputFormat      Format
	outputFields      map[string]*vo.TypeInfo
	// canStream is false when structured (JSON) output forces invoke-only.
	canStream         bool
	// requireCheckpoint is true when tools are configured (interrupt/resume).
	requireCheckpoint bool
	fullSources       map[string]*nodes.SourceInfo
}

// Context-cache key formats (parameterized by node key) used to stash a
// raw model output and its parse warning when JSON parsing fails.
const (
	rawOutputKey = "llm_raw_output_%s"
	warningKey   = "llm_warning_%s"
)
|
||||
|
||||
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
data = nodes.ExtractJSONString(data)
|
||||
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
|
||||
ctxcache.Store(ctx, rawOutputK, data)
|
||||
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func getReasoningContent(message *schema.Message) string {
|
||||
c, ok := deepseek.GetReasoningContent(message)
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
c, ok = ark.GetReasoningContent(message)
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Options collects the optional settings accepted by the LLM node at
// invocation time.
type Options struct {
	// nested carries options forwarded to workflows invoked as tools.
	nested         []nodes.NestedWorkflowOption
	// toolWorkflowSW receives messages produced by tool workflows.
	toolWorkflowSW *schema.StreamWriter[*entity.Message]
}

// Option mutates an Options value.
type Option func(o *Options)

// WithNestedWorkflowOptions appends options forwarded to nested (tool)
// workflow executions.
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
	return func(o *Options) {
		o.nested = append(o.nested, nested...)
	}
}

// WithToolWorkflowMessageWriter sets the stream writer that receives
// messages produced by tool workflows.
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
	return func(o *Options) {
		o.toolWorkflowSW = sw
	}
}
|
||||
|
||||
type llmState = map[string]any
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
|
||||
func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
|
||||
return llmState{}
|
||||
}))
|
||||
|
||||
var (
|
||||
hasReasoning bool
|
||||
canStream = true
|
||||
)
|
||||
|
||||
format := cfg.OutputFormat
|
||||
if format == FormatJSON {
|
||||
if len(cfg.OutputFields) == 1 {
|
||||
for _, v := range cfg.OutputFields {
|
||||
if v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if len(cfg.OutputFields) == 2 {
|
||||
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
|
||||
for k, v := range cfg.OutputFields {
|
||||
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userPrompt := cfg.UserPrompt
|
||||
switch format {
|
||||
case FormatJSON:
|
||||
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonPrompt := fmt.Sprintf(jsonPromptFormat, jsonSchema)
|
||||
userPrompt = userPrompt + jsonPrompt
|
||||
case FormatMarkdown:
|
||||
userPrompt = userPrompt + markdownPrompt
|
||||
case FormatText:
|
||||
}
|
||||
|
||||
if cfg.KnowledgeRecallConfig != nil {
|
||||
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
|
||||
|
||||
inputs := maps.Clone(cfg.InputFields)
|
||||
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
}
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template,
|
||||
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
|
||||
for k, v := range state {
|
||||
in[k] = v
|
||||
}
|
||||
return in, nil
|
||||
}))
|
||||
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
|
||||
|
||||
} else {
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template)
|
||||
|
||||
_ = g.AddEdge(compose.START, templateNodeKey)
|
||||
}
|
||||
|
||||
if len(cfg.Tools) > 0 {
|
||||
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
reactConfig := react.AgentConfig{
|
||||
ToolCallingModel: m,
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
|
||||
ModelNodeName: agentModelName,
|
||||
}
|
||||
|
||||
if len(cfg.ToolsReturnDirectly) > 0 {
|
||||
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
|
||||
for k := range cfg.ToolsReturnDirectly {
|
||||
reactConfig.ToolReturnDirectly[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
reactAgent, err := react.NewAgent(ctx, &reactConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentNode, opts := reactAgent.ExportGraph()
|
||||
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
|
||||
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
|
||||
} else {
|
||||
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
|
||||
}
|
||||
|
||||
_ = g.AddEdge(templateNodeKey, llmNodeKey)
|
||||
|
||||
if format == FormatJSON {
|
||||
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
|
||||
return jsonParse(ctx, msg.Content, cfg.OutputFields)
|
||||
}
|
||||
|
||||
convertNode := compose.InvokableLambda(iConvert)
|
||||
|
||||
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
|
||||
|
||||
canStream = false
|
||||
} else {
|
||||
var outputKey string
|
||||
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
for k, v := range cfg.OutputFields {
|
||||
if v.Type != vo.DataTypeString {
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
if k == ReasoningOutputKey {
|
||||
hasReasoning = true
|
||||
} else {
|
||||
outputKey = k
|
||||
}
|
||||
}
|
||||
|
||||
iConvert := func(_ context.Context, msg *schema.Message, _ ...struct{}) (map[string]any, error) {
|
||||
out := map[string]any{outputKey: msg.Content}
|
||||
if hasReasoning {
|
||||
out[ReasoningOutputKey] = getReasoningContent(msg)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
tConvert := func(_ context.Context, s *schema.StreamReader[*schema.Message], _ ...struct{}) (*schema.StreamReader[map[string]any], error) {
|
||||
sr, sw := schema.Pipe[map[string]any](0)
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
reasoningDone := false
|
||||
for {
|
||||
msg, err := s.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
sw.Send(map[string]any{
|
||||
outputKey: nodes.KeyIsFinished,
|
||||
}, nil)
|
||||
sw.Close()
|
||||
return
|
||||
}
|
||||
|
||||
sw.Send(nil, err)
|
||||
sw.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if hasReasoning {
|
||||
reasoning := getReasoningContent(msg)
|
||||
if len(reasoning) > 0 {
|
||||
sw.Send(map[string]any{ReasoningOutputKey: reasoning}, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if len(msg.Content) > 0 {
|
||||
if !reasoningDone && hasReasoning {
|
||||
reasoningDone = true
|
||||
sw.Send(map[string]any{
|
||||
ReasoningOutputKey: nodes.KeyIsFinished,
|
||||
}, nil)
|
||||
}
|
||||
sw.Send(map[string]any{outputKey: msg.Content}, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
convertNode, err := compose.AnyLambda(iConvert, nil, nil, tConvert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
|
||||
}
|
||||
|
||||
_ = g.AddEdge(llmNodeKey, outputConvertNodeKey)
|
||||
_ = g.AddEdge(outputConvertNodeKey, compose.END)
|
||||
|
||||
requireCheckpoint := false
|
||||
if len(cfg.Tools) > 0 {
|
||||
requireCheckpoint = true
|
||||
}
|
||||
|
||||
var opts []compose.GraphCompileOption
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
|
||||
}
|
||||
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
|
||||
|
||||
r, err := g.Compile(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
llm := &LLM{
|
||||
r: r,
|
||||
outputFormat: format,
|
||||
canStream: canStream,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
fullSources: cfg.FullSources,
|
||||
}
|
||||
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
resumingEvent = c.NodeCtx.ResumingEvent
|
||||
}
|
||||
var previousToolES map[string]*entity.ToolInterruptEvent
|
||||
|
||||
if c != nil && c.RootCtx.ResumeEvent != nil {
|
||||
// check if we are not resuming, but previously interrupted. Interrupt immediately.
|
||||
if resumingEvent == nil {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
var e error
|
||||
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(previousToolES) > 0 {
|
||||
return nil, nil, compose.InterruptAndRerun
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l.requireCheckpoint && c != nil {
|
||||
checkpointID := fmt.Sprintf("%d_%s", c.RootCtx.RootExecuteID, c.NodeCtx.NodeKey)
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
|
||||
}
|
||||
|
||||
llmOpts := &Options{}
|
||||
for _, opt := range opts {
|
||||
opt(llmOpts)
|
||||
}
|
||||
|
||||
nestedOpts := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range llmOpts.nested {
|
||||
opt(nestedOpts)
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
|
||||
|
||||
if resumingEvent != nil {
|
||||
var (
|
||||
resumeData string
|
||||
e error
|
||||
allIEs = make(map[string]*entity.ToolInterruptEvent)
|
||||
)
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
allIEs, e = state.GetToolInterruptEvents(c.NodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
allIEs = maps.Clone(allIEs)
|
||||
|
||||
resumeData, e = state.ResumeToolInterruptEvent(c.NodeKey, resumingEvent.ToolInterruptEvent.ToolCallID)
|
||||
|
||||
return e
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(
|
||||
compose.WithToolOption(
|
||||
execute.WithResume(&entity.ResumeRequest{
|
||||
ExecuteID: resumingEvent.ToolInterruptEvent.ExecuteID,
|
||||
EventID: resumingEvent.ToolInterruptEvent.ID,
|
||||
ResumeData: resumeData,
|
||||
}, allIEs))))
|
||||
|
||||
chatModelHandler := callbacks2.NewHandlerHelper().ChatModel(&callbacks2.ModelCallbackHandler{
|
||||
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
|
||||
if runInfo.Name != agentModelName {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// react agent loops back to chat model after resuming,
|
||||
// pop the previous interrupt event immediately
|
||||
ie, deleted, e := workflow.GetRepository().PopFirstInterruptEvent(ctx, c.RootExecuteID)
|
||||
if e != nil {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: %v", err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: not deleted")
|
||||
return ctx
|
||||
}
|
||||
|
||||
if ie.ID != resumingEvent.ID {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start, "+
|
||||
"deleted ID: %d, resumingEvent ID: %d", ie.ID, resumingEvent.ID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
return ctx
|
||||
},
|
||||
}).Handler()
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCallbacks(chatModelHandler))
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
exeCfg := c.ExeCfg
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
|
||||
}
|
||||
|
||||
if llmOpts.toolWorkflowSW != nil {
|
||||
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
|
||||
composeOpts = append(composeOpts, toolMsgOpt)
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer toolMsgSR.Close()
|
||||
for {
|
||||
msg, err := toolMsgSR.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logs.Infof("received message from tool workflow: %+v", msg)
|
||||
|
||||
llmOpts.toolWorkflowSW.Send(msg, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var nodeKey vo.NodeKey
|
||||
if c != nil && c.NodeCtx != nil {
|
||||
nodeKey = c.NodeCtx.NodeKey
|
||||
}
|
||||
ctxcache.Store(ctx, fmt.Sprintf(sourceKey, nodeKey), resolvedSources)
|
||||
|
||||
return composeOpts, resumingEvent, nil
|
||||
}
|
||||
|
||||
func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.InterruptEvent) error {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
info = info.SubGraphs["llm"] // 'llm' is the node key of the react agent
|
||||
var extra any
|
||||
for i := range info.RerunNodesExtra {
|
||||
extra = info.RerunNodesExtra[i]
|
||||
break
|
||||
}
|
||||
|
||||
toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
|
||||
if !ok {
|
||||
return fmt.Errorf("llm rerun node extra type expected to be ToolsInterruptAndRerunExtra, actual: %T", extra)
|
||||
}
|
||||
id, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
previousInterruptedCallID string
|
||||
highPriorityEvent *entity.ToolInterruptEvent
|
||||
)
|
||||
if resumingEvent != nil {
|
||||
previousInterruptedCallID = resumingEvent.ToolInterruptEvent.ToolCallID
|
||||
}
|
||||
|
||||
c := execute.GetExeCtx(ctx)
|
||||
|
||||
toolIEs := make([]*entity.ToolInterruptEvent, 0, len(toolsNodeExtra.RerunExtraMap))
|
||||
for callID := range toolsNodeExtra.RerunExtraMap {
|
||||
subIE, ok := toolsNodeExtra.RerunExtraMap[callID].(*entity.ToolInterruptEvent)
|
||||
if !ok {
|
||||
return fmt.Errorf("llm rerun node extra type expected to be ToolInterruptEvent, actual: %T", extra)
|
||||
}
|
||||
|
||||
if subIE.ExecuteID == 0 {
|
||||
subIE.ExecuteID = c.RootExecuteID
|
||||
}
|
||||
|
||||
toolIEs = append(toolIEs, subIE)
|
||||
if subIE.ToolCallID == previousInterruptedCallID {
|
||||
highPriorityEvent = subIE
|
||||
}
|
||||
}
|
||||
|
||||
ie := &entity.InterruptEvent{
|
||||
ID: id,
|
||||
NodeKey: c.NodeKey,
|
||||
NodeType: entity.NodeTypeLLM,
|
||||
NodeTitle: c.NodeName,
|
||||
NodeIcon: entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURL,
|
||||
EventType: entity.InterruptEventLLM,
|
||||
}
|
||||
|
||||
if highPriorityEvent != nil {
|
||||
ie.ToolInterruptEvent = highPriorityEvent
|
||||
} else {
|
||||
ie.ToolInterruptEvent = toolIEs[0]
|
||||
}
|
||||
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, ieStore ToolInterruptEventStore) error {
|
||||
for i := range toolIEs {
|
||||
e := ieStore.SetToolInterruptEvent(c.NodeKey, toolIEs[i].ToolCallID, toolIEs[i])
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return compose.NewInterruptAndRerunErr(ie)
|
||||
}
|
||||
|
||||
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err = l.r.Invoke(ctx, in, composeOpts...)
|
||||
if err != nil {
|
||||
err = handleInterrupt(ctx, err, resumingEvent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err = l.r.Stream(ctx, in, composeOpts...)
|
||||
if err != nil {
|
||||
err = handleInterrupt(ctx, err, resumingEvent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// injectKnowledgeTool adds a knowledge-recall sub-chain to the LLM node graph:
// a chat template (intent prompt listing the selected knowledge bases), a chat
// model that predicts which knowledge IDs to recall, and a lambda that
// retrieves matching slices and formats them into the user prompt template.
// The nodes are wired START -> template -> model -> lambda.
func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map[string]any], userPrompt string, cfg *KnowledgeRecallConfig) error {
	selectedKwDetails, err := sonic.MarshalString(cfg.SelectedKnowledgeDetails)
	if err != nil {
		return err
	}
	_ = g.AddChatTemplateNode(knowledgeTemplateKey,
		prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage(fmt.Sprintf(knowledgeIntentPrompt, selectedKwDetails, userPrompt)),
		), compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
			// copy the incoming variables into graph state for later nodes
			for k, v := range in {
				state[k] = v
			}
			return in, nil
		}))
	_ = g.AddChatModelNode(knowledgeChatModelKey, cfg.ChatModel)

	_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
		// the model is expected to output a comma-separated list of knowledge IDs
		modelPredictionIDs := strings.Split(input.Content, ",")
		selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
			return strconv.Itoa(int(e.ID)), e.ID
		})
		// keep only predicted IDs that belong to the configured knowledge bases
		recallKnowledgeIDs := make([]int64, 0)
		for _, id := range modelPredictionIDs {
			if kid, ok := selectKwIDs[id]; ok {
				recallKnowledgeIDs = append(recallKnowledgeIDs, kid)
			}
		}

		if len(recallKnowledgeIDs) == 0 {
			return make(map[string]any), nil
		}

		docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
			Query:             userPrompt,
			KnowledgeIDs:      recallKnowledgeIDs,
			RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
		})
		if err != nil {
			return nil, err
		}

		// nothing recalled and default no-recall mode: contribute nothing
		if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
			return make(map[string]any), nil
		}

		sb := strings.Builder{}
		// customize mode: inject the configured fallback prompt as slice 1
		if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
			sb.WriteString("recall slice 1: \n")
			sb.WriteString(cfg.RetrievalStrategy.NoReCallReplyCustomizePrompt + "\n")
		}

		for idx, msg := range docs.Slices {
			sb.WriteString(fmt.Sprintf("recall slice %d:\n", idx+1))
			sb.WriteString(fmt.Sprintf("%s\n", msg.Output))
		}

		output = map[string]any{
			knowledgeUserPromptTemplateKey: fmt.Sprintf(knowledgeUserPromptTemplate, sb.String()),
		}

		return output, nil
	}))
	_ = g.AddEdge(compose.START, knowledgeTemplateKey)
	_ = g.AddEdge(knowledgeTemplateKey, knowledgeChatModelKey)
	_ = g.AddEdge(knowledgeChatModelKey, knowledgeLambdaKey)
	return nil
}
|
||||
|
||||
// ToolInterruptEventStore is the slice of workflow state used by the LLM node
// to persist, look up and consume tool interrupt events, keyed by the LLM
// node's key and the interrupted tool call ID.
type ToolInterruptEventStore interface {
	// SetToolInterruptEvent records a pending tool interrupt event.
	SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error
	// GetToolInterruptEvents returns all pending events for the node, keyed by tool call ID.
	GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error)
	// ResumeToolInterruptEvent consumes the event for toolCallID and returns its resume data.
	ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error)
}
|
||||
|
||||
func (l *LLM) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
}, nil
|
||||
}
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeKey)
|
||||
rawOutput, found := ctxcache.Get[string](ctx, rawOutputK)
|
||||
if !found {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
}, nil
|
||||
}
|
||||
|
||||
warning, found := ctxcache.Get[vo.WorkflowError](ctx, warningK)
|
||||
if !found {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: map[string]any{"output": rawOutput},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: map[string]any{"output": rawOutput},
|
||||
Error: warning,
|
||||
}, nil
|
||||
}
|
||||
184
backend/domain/workflow/internal/nodes/llm/model_with_info.go
Normal file
184
backend/domain/workflow/internal/nodes/llm/model_with_info.go
Normal file
@@ -0,0 +1,184 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
)
|
||||
|
||||
// ModelWithInfo is a chat model that can also report its own metadata
// (crossdomain model info), e.g. for capability checks at prompt time.
type ModelWithInfo interface {
	model.BaseChatModel
	// Info returns the metadata of the model that would serve the given context.
	Info(ctx context.Context) *crossmodelmgr.Model
}
|
||||
|
||||
// ModelForLLM implements ModelWithInfo by pairing a primary chat model with
// an optional fallback model; UseFallback decides per request which one
// serves the call.
type ModelForLLM struct {
	Model         model.BaseChatModel  // primary chat model
	MInfo         *crossmodelmgr.Model // metadata of the primary model
	FallbackModel model.BaseChatModel  // optional fallback model
	FallbackInfo  *crossmodelmgr.Model // metadata of the fallback model
	UseFallback   func(ctx context.Context) bool // predicate selecting the fallback

	// whether each underlying model fires its own callbacks; when false,
	// Generate/Stream emit OnStart/OnEnd/OnError manually
	modelEnableCallback    bool
	fallbackEnableCallback bool
}
|
||||
|
||||
func NewModel(m model.BaseChatModel, info *crossmodelmgr.Model) *ModelForLLM {
|
||||
return &ModelForLLM{
|
||||
Model: m,
|
||||
MInfo: info,
|
||||
UseFallback: func(ctx context.Context) bool {
|
||||
return false
|
||||
},
|
||||
|
||||
modelEnableCallback: components.IsCallbacksEnabled(m),
|
||||
}
|
||||
}
|
||||
|
||||
func NewModelWithFallback(m, f model.BaseChatModel, info, fInfo *crossmodelmgr.Model) *ModelForLLM {
|
||||
return &ModelForLLM{
|
||||
Model: m,
|
||||
MInfo: info,
|
||||
FallbackModel: f,
|
||||
FallbackInfo: fInfo,
|
||||
UseFallback: func(ctx context.Context) bool {
|
||||
exeCtx := execute.GetExeCtx(ctx)
|
||||
if exeCtx == nil || exeCtx.NodeCtx == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return exeCtx.CurrentRetryCount > 0
|
||||
},
|
||||
|
||||
modelEnableCallback: components.IsCallbacksEnabled(m),
|
||||
fallbackEnableCallback: components.IsCallbacksEnabled(f),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelForLLM) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (
|
||||
output *schema.Message, err error) {
|
||||
if m.UseFallback(ctx) {
|
||||
if !m.fallbackEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_ = callbacks.OnEnd(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.FallbackModel.Generate(ctx, input, opts...)
|
||||
}
|
||||
|
||||
if !m.modelEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_ = callbacks.OnEnd(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.Model.Generate(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *ModelForLLM) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (
|
||||
output *schema.StreamReader[*schema.Message], err error) {
|
||||
if m.UseFallback(ctx) {
|
||||
if !m.fallbackEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.FallbackModel.Stream(ctx, input, opts...)
|
||||
}
|
||||
|
||||
if !m.modelEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.Model.Stream(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *ModelForLLM) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
toolModel, ok := m.Model.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
|
||||
var err error
|
||||
toolModel, err = toolModel.WithTools(tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fallbackToolModel model.ToolCallingChatModel
|
||||
if m.FallbackModel != nil {
|
||||
fallbackToolModel, ok = m.FallbackModel.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
|
||||
fallbackToolModel, err = fallbackToolModel.WithTools(tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ModelForLLM{
|
||||
Model: toolModel,
|
||||
MInfo: m.MInfo,
|
||||
FallbackModel: fallbackToolModel,
|
||||
FallbackInfo: m.FallbackInfo,
|
||||
UseFallback: m.UseFallback,
|
||||
modelEnableCallback: m.modelEnableCallback,
|
||||
fallbackEnableCallback: m.fallbackEnableCallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsCallbacksEnabled always reports true: the wrapper guarantees callbacks
// fire either via the underlying model or manually in Generate/Stream.
func (m *ModelForLLM) IsCallbacksEnabled() bool {
	return true
}
|
||||
|
||||
func (m *ModelForLLM) Info(ctx context.Context) *crossmodelmgr.Model {
|
||||
if m.UseFallback(ctx) {
|
||||
return m.FallbackInfo
|
||||
}
|
||||
|
||||
return m.MInfo
|
||||
}
|
||||
287
backend/domain/workflow/internal/nodes/llm/prompt.go
Normal file
287
backend/domain/workflow/internal/nodes/llm/prompt.go
Normal file
@@ -0,0 +1,287 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// prompts renders the system and user prompt templates of an LLM node into
// chat messages, consulting the model's capabilities for multi-modal parts.
type prompts struct {
	sp  *promptTpl    // system prompt template, may be nil
	up  *promptTpl    // user prompt template, may be nil
	mwi ModelWithInfo // model whose input modal capabilities gate multi-modal parts
}
|
||||
|
||||
// promptTpl is a parsed prompt template for one chat role.
type promptTpl struct {
	role          schema.RoleType // role of the rendered message
	tpl           string          // original template text
	parts         []promptPart    // parsed template parts, in order
	hasMultiModal bool            // true when any variable part is a file-typed input
	reservedKeys  []string        // variable roots kept even when their source is invalid
}
|
||||
|
||||
// promptPart is a single parsed template part, optionally tagged with the
// file sub-type when the part is a file-typed variable.
type promptPart struct {
	part     nodes.TemplatePart
	fileType *vo.FileSubType // nil for literals and non-file variables
}
|
||||
|
||||
func newPromptTpl(role schema.RoleType,
|
||||
tpl string,
|
||||
inputTypes map[string]*vo.TypeInfo,
|
||||
reservedKeys []string,
|
||||
) *promptTpl {
|
||||
if len(tpl) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := nodes.ParseTemplate(tpl)
|
||||
promptParts := make([]promptPart, 0, len(parts))
|
||||
hasMultiModal := false
|
||||
for _, part := range parts {
|
||||
if !part.IsVariable {
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
tInfo := part.TypeInfo(inputTypes)
|
||||
if tInfo == nil || tInfo.Type != vo.DataTypeFile {
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
fileType: tInfo.FileType,
|
||||
})
|
||||
|
||||
hasMultiModal = true
|
||||
}
|
||||
|
||||
return &promptTpl{
|
||||
role: role,
|
||||
tpl: tpl,
|
||||
parts: promptParts,
|
||||
hasMultiModal: hasMultiModal,
|
||||
reservedKeys: reservedKeys,
|
||||
}
|
||||
}
|
||||
|
||||
const sourceKey = "sources_%s"
|
||||
|
||||
func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
|
||||
return &prompts{
|
||||
sp: sp,
|
||||
up: up,
|
||||
mwi: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
|
||||
sources map[string]*nodes.SourceInfo,
|
||||
supportedModals map[modelmgr.Modal]bool) (*schema.Message, error) {
|
||||
if !pl.hasMultiModal || len(supportedModals) == 0 {
|
||||
var opts []nodes.RenderOption
|
||||
if len(pl.reservedKeys) > 0 {
|
||||
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
|
||||
}
|
||||
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &schema.Message{
|
||||
Role: pl.role,
|
||||
Content: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
|
||||
m, err := sonic.Marshal(vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, part := range pl.parts {
|
||||
if !part.part.IsVariable {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: part.part.Value,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipped, invalid := part.part.Skipped(sources)
|
||||
if invalid {
|
||||
var reserved bool
|
||||
for _, k := range pl.reservedKeys {
|
||||
if k == part.part.Root {
|
||||
reserved = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !reserved {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if skipped {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err := part.part.Render(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if part.fileType == nil {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
switch *part.fileType {
|
||||
case vo.FileTypeImage, vo.FileTypeSVG:
|
||||
if _, ok := supportedModals[modelmgr.ModalImage]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
case vo.FileTypeAudio, vo.FileTypeVoice:
|
||||
if _, ok := supportedModals[modelmgr.ModalAudio]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeAudioURL,
|
||||
AudioURL: &schema.ChatMessageAudioURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
case vo.FileTypeVideo:
|
||||
if _, ok := supportedModals[modelmgr.ModalVideo]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeVideoURL,
|
||||
VideoURL: &schema.ChatMessageVideoURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
if _, ok := supportedModals[modelmgr.ModalFile]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeFileURL,
|
||||
FileURL: &schema.ChatMessageFileURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &schema.Message{
|
||||
Role: pl.role,
|
||||
MultiContent: multiParts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Format renders the system and user prompt templates into chat messages.
// It reads the resolved stream sources that prepare cached under sourceKey,
// gates multi-modal rendering on the current model's input modals, and
// guarantees a user message is always present (some models fail on an empty
// message list).
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
	_ []*schema.Message, err error) {
	exeCtx := execute.GetExeCtx(ctx)
	var nodeKey vo.NodeKey
	if exeCtx != nil && exeCtx.NodeCtx != nil {
		nodeKey = exeCtx.NodeCtx.NodeKey
	}
	sk := fmt.Sprintf(sourceKey, nodeKey)

	// stored earlier in the node's prepare step; absence is a programming error
	sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
	if !ok {
		return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
	}

	// collect the model's supported input modals for multi-modal gating
	supportedModal := map[modelmgr.Modal]bool{}
	mInfo := p.mwi.Info(ctx)
	if mInfo != nil {
		for i := range mInfo.Meta.Capability.InputModal {
			supportedModal[mInfo.Meta.Capability.InputModal[i]] = true
		}
	}

	var systemMsg, userMsg *schema.Message
	if p.sp != nil {
		systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal)
		if err != nil {
			return nil, err
		}
	}

	if p.up != nil {
		userMsg, err = p.up.render(ctx, vs, sources, supportedModal)
		if err != nil {
			return nil, err
		}
	}

	if userMsg == nil {
		// give it a default empty message.
		// Some model may fail on empty message such as this one.
		userMsg = schema.UserMessage("")
	}

	if systemMsg == nil {
		return []*schema.Message{userMsg}, nil
	}

	return []*schema.Message{systemMsg, userMsg}, nil
}
|
||||
45
backend/domain/workflow/internal/nodes/loop/break.go
Normal file
45
backend/domain/workflow/internal/nodes/loop/break.go
Normal file
@@ -0,0 +1,45 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package loop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
)
|
||||
|
||||
type Break struct {
|
||||
parentIntermediateStore variable.Store
|
||||
}
|
||||
|
||||
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
|
||||
return &Break{
|
||||
parentIntermediateStore: store,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const BreakKey = "$break"
|
||||
|
||||
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
|
||||
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
424
backend/domain/workflow/internal/nodes/loop/loop.go
Normal file
424
backend/domain/workflow/internal/nodes/loop/loop.go
Normal file
@@ -0,0 +1,424 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package loop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// Loop is the loop composite node: it runs an inner workflow repeatedly and
// aggregates per-iteration outputs and intermediate variables.
type Loop struct {
	config     *Config
	outputs    map[string]*vo.FieldSource // output key -> per-iteration source reference
	outputVars map[string]string          // output key -> intermediate variable name
}
|
||||
|
||||
// Config declares a loop node: its key, loop kind, the input arrays it
// iterates over, its intermediate variables, its output field mappings, and
// the compiled inner workflow executed each iteration.
type Config struct {
	LoopNodeKey      vo.NodeKey
	LoopType         Type
	InputArrays      []string                // input keys holding the arrays iterated over (ByArray)
	IntermediateVars map[string]*vo.TypeInfo // variables shared across iterations
	Outputs          []*vo.FieldInfo         // loop output field mappings

	Inner compose.Runnable[map[string]any, map[string]any]
}
|
||||
|
||||
// Type distinguishes how a loop decides its iteration count.
type Type string

const (
	ByArray     Type = "by_array"     // iterate over the input arrays' elements
	ByIteration Type = "by_iteration" // iterate a fixed number of times
	Infinite    Type = "infinite"     // iterate until broken
)
|
||||
|
||||
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if conf.LoopType == ByArray {
|
||||
if len(conf.InputArrays) == 0 {
|
||||
return nil, errors.New("input arrays is empty when loop type is ByArray")
|
||||
}
|
||||
}
|
||||
|
||||
loop := &Loop{
|
||||
config: conf,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
outputVars: make(map[string]string),
|
||||
}
|
||||
|
||||
for _, info := range conf.Outputs {
|
||||
if len(info.Path) != 1 {
|
||||
return nil, fmt.Errorf("invalid output path: %s", info.Path)
|
||||
}
|
||||
|
||||
k := info.Path[0]
|
||||
fromPath := info.Source.Ref.FromPath
|
||||
|
||||
if info.Source.Ref != nil && info.Source.Ref.VariableType != nil &&
|
||||
*info.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
if len(fromPath) > 1 {
|
||||
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
|
||||
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
|
||||
}
|
||||
|
||||
loop.outputVars[k] = fromPath[0]
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
loop.outputs[k] = &info.Source
|
||||
}
|
||||
|
||||
return loop, nil
|
||||
}
|
||||
|
||||
const (
	// Count is the input key carrying the iteration count for ByIteration loops.
	Count = "loopCount"
)
|
||||
|
||||
// Execute runs the inner workflow once per iteration, up to the limit derived
// from the loop type (see getMaxIter). It supports checkpoint/resume: on
// entry it loads any previously saved NestedWorkflowState, skips iterations
// already done, resumes interrupted ones, and on a new interruption saves
// state plus an InterruptEvent and returns compose.InterruptAndRerun.
// Output arrays are accumulated per configured output field; intermediate
// variables are shared with the inner workflow through the context.
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
	maxIter, err := l.getMaxIter(in)
	if err != nil {
		return nil, err
	}

	// Collect the configured input arrays from the node input.
	// NOTE(review): a.([]any) panics if the value is not []any — getMaxIter
	// only validates slice-ness for the ByArray loop type; confirm upstream
	// guarantees the element type.
	arrays := make(map[string][]any, len(l.config.InputArrays))
	for _, arrayKey := range l.config.InputArrays {
		a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
		if !ok {
			return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
		}
		arrays[arrayKey] = a.([]any)
	}

	options := &nodes.NestedWorkflowOptions{}
	for _, opt := range opts {
		opt(options)
	}

	var (
		existingCState   *nodes.NestedWorkflowState // checkpoint state from a previous (interrupted) run, if any
		intermediateVars map[string]*any            // shared mutable vars, exposed to the inner workflow via ctx
		output           map[string]any
		hasBreak         = any(false) // set to true by the inner workflow's break node via BreakKey
	)
	err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
		var e error
		existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
		if e != nil {
			return e
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	if existingCState != nil {
		// Resuming: restore accumulated output and intermediate vars.
		output = existingCState.FullOutput
		intermediateVars = make(map[string]*any, len(existingCState.IntermediateVars))
		for k := range existingCState.IntermediateVars {
			intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
		}
		intermediateVars[BreakKey] = &hasBreak
	} else {
		// Fresh run: empty output arrays and intermediate vars seeded from input.
		output = make(map[string]any, len(l.outputs))
		for k := range l.outputs {
			output[k] = make([]any, 0)
		}

		intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
		for varKey := range l.config.IntermediateVars {
			v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
			if !ok {
				return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
			}

			intermediateVars[varKey] = &v
		}
		intermediateVars[BreakKey] = &hasBreak
	}

	ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)

	// getIthInput builds the inner-workflow input for iteration i: pass-through
	// keys, the current index, and the i-th element of each array (maps are
	// flattened into "nodeKey#array#field" keys). Also returns the raw items.
	getIthInput := func(i int) (map[string]any, map[string]any, error) {
		input := make(map[string]any)

		for k, v := range in { // carry over other values
			if k == Count {
				continue
			}

			if _, ok := arrays[k]; ok {
				continue
			}

			if _, ok := intermediateVars[k]; ok {
				continue
			}

			input[k] = v
		}

		input[string(l.config.LoopNodeKey)+"#index"] = int64(i)

		items := make(map[string]any)
		for arrayKey := range arrays {
			ele := arrays[arrayKey][i]
			items[arrayKey] = ele
			currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey

			// Recursively expand map[string]any elements
			if m, ok := ele.(map[string]any); ok {
				var expand func(prefix string, val interface{})
				expand = func(prefix string, val interface{}) {
					if nestedMap, ok := val.(map[string]any); ok {
						for k, v := range nestedMap {
							expand(prefix+"#"+k, v)
						}
					} else {
						input[prefix] = val
					}
				}
				expand(currentKey, m)
			} else {
				input[currentKey] = ele
			}
		}

		return input, items, nil
	}

	// setIthOutput appends, for each configured output field, the referenced
	// value from the i-th task output (missing values are silently skipped).
	setIthOutput := func(i int, taskOutput map[string]any) {
		for arrayKey := range l.outputs {
			source := l.outputs[arrayKey]
			fromValue, ok := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)}, source.Ref.FromPath...))
			if ok {
				output[arrayKey] = append(output[arrayKey].([]any), fromValue)
			}
		}
	}

	var (
		index2Done          = map[int]bool{}                  // iterations finished during THIS invocation
		index2InterruptInfo = map[int]*compose.InterruptInfo{} // interruptions raised during THIS invocation
		resumed             = map[int]bool{}                  // iterations resumed from a previous checkpoint
	)

	for i := 0; i < maxIter; i++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err() // canceled by Eino workflow engine
		default:
		}

		if existingCState != nil {
			if existingCState.Index2Done[i] == true {
				continue
			}

			if existingCState.Index2InterruptInfo[i] != nil {
				if len(options.GetResumeIndexes()) > 0 {
					if _, ok := options.GetResumeIndexes()[i]; !ok {
						// previously interrupted, but not resumed this time, should not happen
						panic("impossible")
					}
				}
			}

			resumed[i] = true
		}

		input, items, err := getIthInput(i)
		if err != nil {
			return nil, err
		}

		subCtx, checkpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)

		ithOpts := options.GetOptsForNested()
		ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)

		if checkpointID != "" {
			ithOpts = append(ithOpts, compose.WithCheckPointID(checkpointID))
		}

		if len(options.GetResumeIndexes()) > 0 {
			stateModifier, ok := options.GetResumeIndexes()[i]
			if ok {
				fmt.Println("has state modifier for ith run: ", i, ", checkpointID: ", checkpointID)
				ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
			}
		}

		taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
		if err != nil {
			info, ok := compose.ExtractInterruptInfo(err)
			if !ok {
				return nil, err
			}

			// Record the interruption and stop; remaining iterations run after resume.
			index2InterruptInfo[i] = info
			break
		}

		setIthOutput(i, taskOutput)

		index2Done[i] = true

		// The inner workflow flipped the break flag: stop iterating.
		if hasBreak.(bool) {
			break
		}
	}

	// delete the interruptions that have been resumed
	// (resumed is only populated when existingCState != nil, so the nil
	// dereference below is unreachable on a fresh run)
	for index := range resumed {
		delete(existingCState.Index2InterruptInfo, index)
	}

	// Merge this invocation's progress into the checkpoint state.
	compState := existingCState
	if compState == nil {
		compState = &nodes.NestedWorkflowState{
			Index2Done:          index2Done,
			Index2InterruptInfo: index2InterruptInfo,
			FullOutput:          output,
			IntermediateVars:    convertIntermediateVars(intermediateVars),
		}
	} else {
		for i := range index2Done {
			compState.Index2Done[i] = index2Done[i]
		}
		for i := range index2InterruptInfo {
			compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
		}
		compState.FullOutput = output
		compState.IntermediateVars = convertIntermediateVars(intermediateVars)
	}

	if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
		iEvent := &entity.InterruptEvent{
			NodeKey:             l.config.LoopNodeKey,
			NodeType:            entity.NodeTypeLoop,
			NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
		}

		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
				return e
			}

			return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
		})
		if err != nil {
			return nil, err
		}

		// NOTE(review): debug prints — consider replacing with the logs package.
		fmt.Println("save interruptEvent in state within loop: ", iEvent)
		fmt.Println("save composite info in state within loop: ", compState)

		return nil, compose.InterruptAndRerun
	} else {
		err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
			return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
		})
		if err != nil {
			return nil, err
		}

		fmt.Println("save composite info in state within loop: ", compState)
	}

	// Sanity check: finishing without new interruptions implies every historical
	// interruption was resumed and deleted above.
	if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
		fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
		panic("impossible")
	}

	// Surface the final value of intermediate variables mapped to outputs.
	for outputVarKey, intermediateVarKey := range l.outputVars {
		output[outputVarKey] = *(intermediateVars[intermediateVarKey])
	}

	return output, nil
}
|
||||
|
||||
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
|
||||
maxIter := math.MaxInt
|
||||
|
||||
switch l.config.LoopType {
|
||||
case ByArray:
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
}
|
||||
|
||||
if reflect.TypeOf(a).Kind() != reflect.Slice {
|
||||
return 0, fmt.Errorf("incoming array not a slice: %s. Actual type: %v", arrayKey, reflect.TypeOf(a))
|
||||
}
|
||||
|
||||
oneLen := reflect.ValueOf(a).Len()
|
||||
if oneLen < maxIter {
|
||||
maxIter = oneLen
|
||||
}
|
||||
}
|
||||
case ByIteration:
|
||||
iter, ok := nodes.TakeMapValue(in, compose.FieldPath{Count})
|
||||
if !ok {
|
||||
return 0, errors.New("incoming LoopCount not present in input when loop type is ByIteration")
|
||||
}
|
||||
|
||||
maxIter = int(iter.(int64))
|
||||
case Infinite:
|
||||
default:
|
||||
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
|
||||
}
|
||||
|
||||
return maxIter, nil
|
||||
}
|
||||
|
||||
// convertIntermediateVars dereferences every pointer-valued entry, producing a
// plain value map suitable for persisting in checkpoint state.
func convertIntermediateVars(vars map[string]*any) map[string]any {
	deref := make(map[string]any, len(vars))
	for name, valPtr := range vars {
		deref[name] = *valPtr
	}
	return deref
}
|
||||
|
||||
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(l.config.InputArrays))
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
}
|
||||
return trimmed, nil
|
||||
}
|
||||
90
backend/domain/workflow/internal/nodes/nested.go
Normal file
90
backend/domain/workflow/internal/nodes/nested.go
Normal file
@@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// NestedWorkflowOptions carries per-invocation options for nodes that run a
// nested workflow (loop/batch), collected via NestedWorkflowOption setters.
type NestedWorkflowOptions struct {
	// optsForNested applies to every nested run.
	optsForNested []compose.Option
	// toResumeIndexes maps an iteration index to the state modifier used to
	// resume that previously interrupted iteration.
	toResumeIndexes map[int]compose.StateModifier
	// optsForIndexed applies only to the run at a specific iteration index.
	optsForIndexed map[int][]compose.Option
}

// NestedWorkflowOption mutates a NestedWorkflowOptions in place.
type NestedWorkflowOption func(*NestedWorkflowOptions)
|
||||
|
||||
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
o.optsForNested = append(o.optsForNested, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// GetOptsForNested returns the compose options applied to every nested run.
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
	return c.optsForNested
}
|
||||
|
||||
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.toResumeIndexes == nil {
|
||||
o.toResumeIndexes = map[int]compose.StateModifier{}
|
||||
}
|
||||
|
||||
o.toResumeIndexes[i] = m
|
||||
}
|
||||
}
|
||||
|
||||
// GetResumeIndexes returns the iteration indexes to resume, mapped to their
// state modifiers. May be nil if nothing is being resumed.
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
	return c.toResumeIndexes
}
|
||||
|
||||
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.optsForIndexed == nil {
|
||||
o.optsForIndexed = map[int][]compose.Option{}
|
||||
}
|
||||
o.optsForIndexed[index] = opts
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
|
||||
if c.optsForIndexed == nil {
|
||||
return nil
|
||||
}
|
||||
return c.optsForIndexed[index]
|
||||
}
|
||||
|
||||
// NestedWorkflowState is the checkpoint state a loop/batch node saves so an
// interrupted nested workflow can resume where it left off.
type NestedWorkflowState struct {
	// Index2Done marks iteration indexes that completed successfully.
	Index2Done map[int]bool `json:"index_2_done,omitempty"`
	// Index2InterruptInfo holds the interrupt info for iterations that were interrupted.
	Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
	// FullOutput is the output accumulated across finished iterations.
	FullOutput map[string]any `json:"full_output,omitempty"`
	// IntermediateVars snapshots the intermediate variables at interruption time.
	IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
|
||||
|
||||
// String renders the state as indented JSON for logging. A marshal error is
// deliberately ignored and yields an empty string.
func (c *NestedWorkflowState) String() string {
	s, _ := sonic.MarshalIndent(c, "", " ")
	return string(s)
}
|
||||
|
||||
// NestedWorkflowAware is the state contract a parent workflow must satisfy so
// nested-workflow nodes (loop/batch) can checkpoint their progress and surface
// interrupt events.
type NestedWorkflowAware interface {
	// SaveNestedWorkflowState persists the checkpoint state for the node key.
	SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
	// GetNestedWorkflowState loads the checkpoint state for the node key.
	GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
	InterruptEventStore
}
|
||||
@@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// ParentIntermediateStore reads and writes a parent node's intermediate
// variables that live in the context (see InitIntermediateVars), guarding
// access with its own RWMutex.
type ParentIntermediateStore struct {
	mu sync.RWMutex
}

// IntermediateVarKey is the context key under which the intermediate
// variables are stored.
type IntermediateVarKey struct{}
|
||||
|
||||
func (p *ParentIntermediateStore) Init(_ context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func (p *ParentIntermediateStore) Get(ctx context.Context, path compose.FieldPath, opts ...variable.OptionFn) (any, error) {
|
||||
defer p.mu.RUnlock()
|
||||
p.mu.RLock()
|
||||
|
||||
if len(path) != 1 {
|
||||
return nil, fmt.Errorf("invalid path: %v", path)
|
||||
}
|
||||
|
||||
ivs := getIntermediateVars(ctx)
|
||||
v, ok := ivs.vars[path[0]]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("variable not found: %s", path[0])
|
||||
}
|
||||
|
||||
if *v == nil {
|
||||
return ivs.types[path[0]].Zero(), nil
|
||||
}
|
||||
|
||||
return *v, nil
|
||||
}
|
||||
|
||||
func (p *ParentIntermediateStore) Set(ctx context.Context, path compose.FieldPath, value any, opts ...variable.OptionFn) error {
|
||||
defer p.mu.Unlock()
|
||||
p.mu.Lock()
|
||||
|
||||
if len(path) != 1 {
|
||||
return fmt.Errorf("invalid path: %v", path)
|
||||
}
|
||||
|
||||
ivs := getIntermediateVars(ctx)
|
||||
v, ok := ivs.vars[path[0]]
|
||||
if !ok {
|
||||
return fmt.Errorf("variable not found: %s", path[0])
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
*v = ivs.types[path[0]].Zero()
|
||||
} else {
|
||||
*v = value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// intermediateVar bundles the pointers to a parent node's intermediate
// variable values with their declared type info, for storage in a context.
type intermediateVar struct {
	// vars maps variable name to a pointer to its current value.
	vars map[string]*any
	// types maps variable name to its declared type information.
	types map[string]*vo.TypeInfo
}
|
||||
|
||||
func InitIntermediateVars(ctx context.Context, vars map[string]*any, typeInfos map[string]*vo.TypeInfo) context.Context {
|
||||
return context.WithValue(ctx, IntermediateVarKey{}, &intermediateVar{
|
||||
vars: vars,
|
||||
types: typeInfos,
|
||||
})
|
||||
}
|
||||
|
||||
// getIntermediateVars retrieves the intermediate variables placed in ctx by
// InitIntermediateVars. The type assertion panics if the ctx was never
// initialized — callers are expected to guarantee initialization.
func getIntermediateVars(ctx context.Context) *intermediateVar {
	return ctx.Value(IntermediateVarKey{}).(*intermediateVar)
}
|
||||
83
backend/domain/workflow/internal/nodes/plugin/plugin.go
Normal file
83
backend/domain/workflow/internal/nodes/plugin/plugin.go
Normal file
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// Config identifies the plugin tool to invoke and the service used to run it.
type Config struct {
	// PluginID is the ID of the plugin that owns the tool.
	PluginID int64
	// ToolID is the ID of the tool to execute.
	ToolID int64
	// PluginVersion pins the plugin version to execute against.
	PluginVersion string

	// PluginService performs the actual plugin execution.
	PluginService plugin.Service
}
|
||||
|
||||
// Plugin is a workflow node that executes a single plugin tool.
type Plugin struct {
	config *Config
}
|
||||
|
||||
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
if cfg.PluginID == 0 {
|
||||
return nil, errors.New("plugin id is required")
|
||||
}
|
||||
if cfg.ToolID == 0 {
|
||||
return nil, errors.New("tool id is required")
|
||||
}
|
||||
if cfg.PluginService == nil {
|
||||
return nil, errors.New("tool service is required")
|
||||
}
|
||||
|
||||
return &Plugin{config: cfg}, nil
|
||||
}
|
||||
|
||||
// Invoke executes the configured plugin tool with the given parameters,
// propagating the workflow execution config from ctx when present. An
// interrupt-and-rerun error from the plugin is translated into an
// authorization error carrying the interrupt payload.
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
	var exeCfg vo.ExecuteConfig
	if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
		exeCfg = ctxExeCfg.ExeCfg
	}
	result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
		PluginID:      p.config.PluginID,
		PluginVersion: ptr.Of(p.config.PluginVersion),
	}, p.config.ToolID, exeCfg)
	if err != nil {
		if extra, ok := compose.IsInterruptRerunError(err); ok {
			// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now
			// NOTE(review): this assertion panics if extra is not *entity.InterruptEvent — confirm the rerun payload type.
			interruptData := extra.(*entity.InterruptEvent).InterruptData
			return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
		}
		return nil, err
	}

	return result, nil
}
|
||||
633
backend/domain/workflow/internal/nodes/qa/question_answer.go
Normal file
633
backend/domain/workflow/internal/nodes/qa/question_answer.go
Normal file
@@ -0,0 +1,633 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package qa
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// QuestionAnswer is a workflow node that asks the user a question (optionally
// with choices), interrupts until an answer arrives, and interprets the reply.
type QuestionAnswer struct {
	config *Config
	// nodeMeta is the registered metadata for the question-answer node type.
	nodeMeta entity.NodeTypeMeta
}
|
||||
|
||||
// Config declares how a QuestionAnswer node asks its question and interprets
// the user's reply.
type Config struct {
	// QuestionTpl is the template rendered against the node input to produce the question text.
	QuestionTpl string
	// AnswerType selects between free-form answers and answering by choices.
	AnswerType AnswerType

	// ChoiceType selects fixed vs. dynamic options; only used when AnswerType is AnswerByChoices.
	ChoiceType ChoiceType
	// FixedChoices are option templates rendered against the node input.
	FixedChoices []string

	// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
	Model model.BaseChatModel

	// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
	ExtractFromAnswer         bool
	AdditionalSystemPromptTpl string
	MaxAnswerCount            int
	OutputFields              map[string]*vo.TypeInfo

	NodeKey vo.NodeKey
}
|
||||
|
||||
// AnswerType distinguishes how the user is expected to answer.
type AnswerType string

const (
	// AnswerDirectly accepts a free-form reply.
	AnswerDirectly AnswerType = "directly"
	// AnswerByChoices asks the user to pick one of the presented options.
	AnswerByChoices AnswerType = "by_choices"
)

// ChoiceType distinguishes where the presented options come from.
type ChoiceType string

const (
	// FixedChoices are configured statically as templates.
	FixedChoices ChoiceType = "fixed"
	// DynamicChoices arrive at runtime in the node input under DynamicChoicesKey.
	DynamicChoices ChoiceType = "dynamic"
)

// Well-known input/output keys used by the QuestionAnswer node.
const (
	DynamicChoicesKey = "dynamic_option"
	QuestionsKey      = "$questions"
	AnswersKey        = "$answers"
	UserResponseKey   = "USER_RESPONSE"
	OptionIDKey       = "optionId"
	OptionContentKey  = "optionContent"
)
|
||||
|
||||
// Prompt templates used by the QuestionAnswer node:
//   - extractSystemPrompt: system prompt (Chinese) instructing the model to
//     extract field values from the user's replies as JSON, with %s replaced
//     by the JSON schema of the output fields.
//   - extractUserPromptSuffix: appended to the last user answer; %s slots are
//     the required field names and the optional persona text.
//   - additionalPersona: wraps the configured follow-up persona; %s is the
//     rendered AdditionalSystemPromptTpl.
//   - choiceIntentDetectPrompt: system prompt asking the model to map a
//     free-form reply to an option id (or -1); %s is the serialized options.
// These are runtime strings and are intentionally left byte-for-byte as-is.
const (
	extractSystemPrompt = `# 角色
你是一个参数提取 agent,你的工作是从用户的回答中提取出多个字段的值,每个字段遵循以下规则
# 字段说明
%s
## 输出要求
- 严格以 json 格式返回答案。
- 严格确保答案采用有效的 JSON 格式。
- 按照字段说明提取出字段的值,将已经提取到的字段放在 fields 字段
- 对于未提取到的<必填字段>生成一个新的追问问题question
- 确保在追问问题中只包含所有未提取的<必填字段>
- 不要重复问之前问过的问题
- 问题的语种请和用户的输入保持一致,如英文、中文等
- 输出按照下面结构体格式返回,包含提取到的字段或者追问的问题
- 不要回复和提取无关的问题
type Output struct {
	fields FieldInfo // 根据字段说明已经提取到的字段
	question string // 新一轮追问的问题
}`
	extractUserPromptSuffix = `
- 严格以 json 格式返回答案。
- 严格确保答案采用有效的 JSON 格式。
- - 必填字段没有获取全则继续追问
- 必填字段: %s
%s
`
	additionalPersona = `
追问人设设定: %s
`
	choiceIntentDetectPrompt = `# Role
You are a semantic matching expert, good at analyzing the option that the user wants to choose based on the current context.
##Skill
Skill 1: Clearly identify which of the following options the user's reply is semantically closest to:
%s

##Restrictions
Strictly identify the intention and select the most suitable option. You can only reply with the option_id and no other content. If you think there is no suitable option, output -1
##Output format
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
)
|
||||
|
||||
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if conf.AnswerType == AnswerDirectly {
|
||||
if conf.ExtractFromAnswer {
|
||||
if conf.Model == nil {
|
||||
return nil, errors.New("model is required when extract from answer")
|
||||
}
|
||||
if len(conf.OutputFields) == 0 {
|
||||
return nil, errors.New("output fields is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else if conf.AnswerType == AnswerByChoices {
|
||||
if conf.ChoiceType == FixedChoices {
|
||||
if len(conf.FixedChoices) == 0 {
|
||||
return nil, errors.New("fixed choices is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
|
||||
}
|
||||
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for question answer")
|
||||
}
|
||||
|
||||
return &QuestionAnswer{
|
||||
config: conf,
|
||||
nodeMeta: *nodeMeta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Question records one question asked to the user, along with the choices
// presented (empty for free-form questions).
type Question struct {
	Question string
	Choices  []string
}
|
||||
|
||||
// namedOpt is one selectable option as serialized in an interrupt message.
type namedOpt struct {
	Name string `json:"name"`
}

// optionContent is the payload of a choice-style question message.
type optionContent struct {
	Options  []namedOpt `json:"options"`
	Question string     `json:"question"`
}

// message is the serialized question/answer event emitted on interrupt.
type message struct {
	Type        string `json:"type"`
	ContentType string `json:"content_type"`
	Content     any    `json:"content"` // either optionContent or string
	ID          string `json:"id,omitempty"`
}
|
||||
|
||||
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
	var (
		questions  []*Question
		answers    []string
		isFirst    bool
		notResumed bool
	)

	// Reconstruct where this node stands in its ask/answer cycle from the
	// accumulated questions/answers carried in the input.
	questions, answers, isFirst, notResumed, err = q.extractCurrentState(in)
	if err != nil {
		return nil, err
	}

	if notResumed { // previously interrupted but not resumed this time, interrupt immediately
		return nil, compose.InterruptAndRerun
	}

	out = make(map[string]any)
	out[QuestionsKey] = questions
	out[AnswersKey] = answers

	switch q.config.AnswerType {
	case AnswerDirectly:
		if isFirst { // first execution, ask the question
			// format the question. Which is common to all use cases
			firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
			if err != nil {
				return nil, err
			}

			return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
		}

		// An answer exists: either extract structured fields from it, or
		// surface the raw reply directly.
		if q.config.ExtractFromAnswer {
			return q.extractFromAnswer(ctx, in, questions, answers)
		}

		out[UserResponseKey] = answers[0]
		return out, nil
	case AnswerByChoices:
		if !isFirst {
			// Map the latest answer back to one of the presented choices:
			// exact match first, then model-based intent detection, else "other".
			lastAnswer := answers[len(answers)-1]
			lastQuestion := questions[len(questions)-1]
			for i, choice := range lastQuestion.Choices {
				if lastAnswer == choice {
					out[OptionIDKey] = intToAlphabet(i)
					out[OptionContentKey] = choice
					return out, nil
				}
			}

			index, err := q.intentDetect(ctx, lastAnswer, lastQuestion.Choices)
			if err != nil {
				return nil, err
			}

			if index >= 0 {
				out[OptionIDKey] = intToAlphabet(index)
				out[OptionContentKey] = lastQuestion.Choices[index]
				return out, nil
			}

			out[OptionIDKey] = "other"
			out[OptionContentKey] = lastAnswer
			return out, nil
		}

		// format the question. Which is common to all use cases
		firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
		if err != nil {
			return nil, err
		}

		// Build the option list from the configured templates or the dynamic
		// choices array carried in the input.
		var formattedChoices []string
		switch q.config.ChoiceType {
		case FixedChoices:
			for _, choice := range q.config.FixedChoices {
				formattedChoice, err := nodes.TemplateRender(choice, in)
				if err != nil {
					return nil, err
				}
				formattedChoices = append(formattedChoices, formattedChoice)
			}
		case DynamicChoices:
			dynamicChoices, ok := nodes.TakeMapValue(in, compose.FieldPath{DynamicChoicesKey})
			if !ok || len(dynamicChoices.([]any)) == 0 {
				return nil, vo.NewError(errno.ErrQuestionOptionsEmpty)
			}

			// Option IDs are single letters, so at most 26 choices are kept.
			const maxDynamicChoices = 26
			for i, choice := range dynamicChoices.([]any) {
				if i >= maxDynamicChoices {
					break // take first 26 choices, discard the others
				}
				c := choice.(string)
				formattedChoices = append(formattedChoices, c)
			}
		default:
			return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
		}

		return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
	default:
		return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
	}
}
|
||||
|
||||
// extractFromAnswer asks the model to pull the configured output fields out of
// the conversation so far. If the model returns a follow-up "question", the
// node interrupts again (bounded by MaxAnswerCount); otherwise the extracted
// "fields" are converted to the declared output types and returned together
// with the raw user response and the question/answer history.
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
	// Describe the expected output fields as a JSON schema for the system prompt.
	fieldInfo := "FieldInfo"
	s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
	if err != nil {
		return nil, err
	}

	sysPrompt := fmt.Sprintf(extractSystemPrompt, s)

	var requiredFields []string
	for fName, tInfo := range q.config.OutputFields {
		if tInfo.Required {
			requiredFields = append(requiredFields, fName)
		}
	}

	// Optional persona text appended to the extraction instructions.
	var formattedAdditionalPrompt string
	if len(q.config.AdditionalSystemPromptTpl) > 0 {
		additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
		if err != nil {
			return nil, err
		}

		formattedAdditionalPrompt = fmt.Sprintf(additionalPersona, additionalPrompt)
	}

	userPromptSuffix := fmt.Sprintf(extractUserPromptSuffix, requiredFields, formattedAdditionalPrompt)

	// Replay the whole question/answer history as an alternating
	// assistant/user transcript; the suffix is attached to the last answer.
	var (
		messages     = make([]*schema.Message, 0, len(questions)*2+1)
		userResponse string
	)
	messages = append(messages, schema.SystemMessage(sysPrompt))
	for i := range questions {
		messages = append(messages, schema.AssistantMessage(questions[i].Question, nil))

		answer := answers[i]
		if i == len(questions)-1 {
			userResponse = answer
			answer = answer + userPromptSuffix
		}
		messages = append(messages, schema.UserMessage(answer))
	}

	out, err := q.config.Model.Generate(ctx, messages)
	if err != nil {
		return nil, err
	}

	// Strip any surrounding text the model may have emitted around the JSON.
	content := nodes.ExtractJSONString(out.Content)

	var outMap = make(map[string]any)
	err = sonic.UnmarshalString(content, &outMap)
	if err != nil {
		return nil, err
	}

	// A non-empty "question" means required fields are still missing: ask again.
	nextQuestion, ok := outMap["question"]
	if ok {
		nextQuestionStr, ok := nextQuestion.(string)
		if ok && len(nextQuestionStr) > 0 {
			if len(answers) >= q.config.MaxAnswerCount {
				return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
			}

			return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
		}
	}

	fields, ok := outMap["fields"]
	if !ok {
		return nil, fmt.Errorf("field %s not found", fieldInfo)
	}

	// NOTE(review): fields.(map[string]any) panics if the model returned a
	// non-object "fields" value — confirm this is acceptable here.
	realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
	if err != nil {
		return nil, err
	}

	if ws != nil {
		logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
	}

	realOutput[UserResponseKey] = userResponse
	realOutput[QuestionsKey] = questions
	realOutput[AnswersKey] = answers
	return realOutput, nil
}
|
||||
|
||||
func (q *QuestionAnswer) extractCurrentState(in map[string]any) (
|
||||
qResult []*Question,
|
||||
aResult []string,
|
||||
isFirst bool, // whether this execution if the first ever execution for this node
|
||||
notResumed bool, // whether this node is previously interrupted, but not resumed this time, because another node is resumed
|
||||
err error) {
|
||||
questions, ok := in[QuestionsKey]
|
||||
if ok {
|
||||
qResult = questions.([]*Question)
|
||||
}
|
||||
|
||||
answers, ok := in[AnswersKey]
|
||||
if ok {
|
||||
aResult = answers.([]string)
|
||||
}
|
||||
|
||||
if len(qResult) == 0 && len(aResult) == 0 {
|
||||
return nil, nil, true, false, nil
|
||||
}
|
||||
|
||||
if len(qResult) != len(aResult) && len(qResult) != len(aResult)+1 {
|
||||
return nil, nil, false, false,
|
||||
fmt.Errorf("invalid state, question count is expected to be equal to answer count or 1 more than answer count: %v", in)
|
||||
}
|
||||
|
||||
return qResult, aResult, false, len(qResult) == len(aResult)+1, nil
|
||||
}
|
||||
|
||||
func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choices []string) (int, error) {
|
||||
type option struct {
|
||||
Option string `json:"option"`
|
||||
OptionID int `json:"option_id"`
|
||||
}
|
||||
|
||||
options := make([]option, 0, len(choices))
|
||||
for i := range choices {
|
||||
options = append(options, option{Option: choices[i], OptionID: i})
|
||||
}
|
||||
|
||||
optionsStr, err := sonic.MarshalString(options)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
sysPrompt := fmt.Sprintf(choiceIntentDetectPrompt, optionsStr)
|
||||
messages := []*schema.Message{
|
||||
schema.SystemMessage(sysPrompt),
|
||||
schema.UserMessage(answer),
|
||||
}
|
||||
|
||||
out, err := q.config.Model.Generate(ctx, messages)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(out.Content)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// QuestionAnswerAware is the slice of workflow state the question-answer node
// depends on: it persists the questions asked and the answers received across
// interrupts, keyed by node key so multiple QA nodes can coexist in a workflow.
type QuestionAnswerAware interface {
	// AddQuestion records a question this node has asked.
	AddQuestion(nodeKey vo.NodeKey, question *Question)
	// AddAnswer records the user's answer to the latest question.
	AddAnswer(nodeKey vo.NodeKey, answer string)
	// GetQuestionsAndAnswers returns all recorded questions and answers for nodeKey.
	GetQuestionsAndAnswers(nodeKey vo.NodeKey) ([]*Question, []string)
}
|
||||
|
||||
// interrupt suspends the workflow to ask the user newQuestion. It serializes
// the full question/answer transcript (including the new, unanswered
// question) as the interrupt event's payload, records the new question into
// workflow state, and returns the interrupt-and-rerun sentinel error.
func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choices []string, oldQuestions []*Question, oldAnswers []string) error {
	history := q.generateHistory(oldQuestions, oldAnswers, &newQuestion, choices)

	historyList := map[string][]*message{
		"messages": history,
	}
	interruptData, err := sonic.MarshalString(historyList)
	if err != nil {
		return err
	}

	// Each interrupt event needs a globally unique ID from the repository.
	eventID, err := workflow.GetRepository().GenID(ctx)
	if err != nil {
		return err
	}

	event := &entity.InterruptEvent{
		ID:            eventID,
		NodeKey:       q.config.NodeKey,
		NodeType:      entity.NodeTypeQuestionAnswer,
		NodeTitle:     q.nodeMeta.Name,
		NodeIcon:      q.nodeMeta.IconURL,
		InterruptData: interruptData,
		EventType:     entity.InterruptEventQuestion,
	}

	// Persist the pending question so the node can rebuild its state on
	// resume. The error is deliberately ignored: the callback never fails,
	// and ProcessState errors here are non-fatal for the interrupt itself.
	_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
		setter.AddQuestion(q.config.NodeKey, &Question{
			Question: newQuestion,
			Choices:  choices,
		})
		return nil
	})

	return compose.NewInterruptAndRerunErr(event)
}
|
||||
|
||||
// intToAlphabet maps 0..25 to the corresponding upper-case letter "A".."Z".
// Any value outside that range yields the empty string.
func intToAlphabet(num int) string {
	if num < 0 || num > 25 {
		return ""
	}
	return string(rune('A' + num))
}
|
||||
|
||||
// AlphabetToInt maps a single letter (either case) to its 0-based position in
// the alphabet, e.g. "a" or "A" -> 0. The boolean reports whether the input
// was exactly one ASCII letter.
func AlphabetToInt(str string) (int, bool) {
	if len(str) != 1 {
		return 0, false
	}
	c := unicode.ToUpper(rune(str[0]))
	if c < 'A' || c > 'Z' {
		return 0, false
	}
	return int(c - 'A'), true
}
|
||||
|
||||
func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []string, newQuestion *string, choices []string) []*message {
|
||||
conv := func(opts []string) (namedOpts []namedOpt) {
|
||||
for _, opt := range opts {
|
||||
namedOpts = append(namedOpts, namedOpt{
|
||||
Name: opt,
|
||||
})
|
||||
}
|
||||
return namedOpts
|
||||
}
|
||||
|
||||
history := make([]*message, 0, len(oldQuestions)+len(oldAnswers)+1)
|
||||
for i := 0; i < len(oldQuestions); i++ {
|
||||
oldQuestion := oldQuestions[i]
|
||||
oldAnswer := oldAnswers[i]
|
||||
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
|
||||
questionMsg := &message{
|
||||
Type: "question",
|
||||
ContentType: contentType,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
|
||||
}
|
||||
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
questionMsg.Content = optionContent{
|
||||
Options: conv(oldQuestion.Choices),
|
||||
Question: oldQuestion.Question,
|
||||
}
|
||||
} else {
|
||||
questionMsg.Content = oldQuestion.Question
|
||||
}
|
||||
|
||||
answerMsg := &message{
|
||||
Type: "answer",
|
||||
ContentType: contentType,
|
||||
Content: oldAnswer,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
|
||||
}
|
||||
|
||||
history = append(history, questionMsg, answerMsg)
|
||||
}
|
||||
|
||||
if newQuestion != nil {
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "option",
|
||||
Content: optionContent{
|
||||
Options: conv(choices),
|
||||
Question: *newQuestion,
|
||||
},
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
} else {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "text",
|
||||
Content: *newQuestion,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// ToCallbackOutput reshapes the node's final output for callback reporting:
// Output keeps the extracted fields (minus internal bookkeeping keys) while
// RawOutput carries the question/answer transcript. Expects out to contain
// the questions/answers stored earlier by this node (the assertions panic
// otherwise, which would indicate an upstream programming error).
func (q *QuestionAnswer) ToCallbackOutput(_ context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
	questions := out[QuestionsKey].([]*Question)
	answers := out[AnswersKey].([]string)
	selected, hasSelected := out[OptionContentKey]
	history := q.generateHistory(questions, answers, nil, nil)
	for _, msg := range history {
		// For the transcript, collapse option-style questions down to their
		// plain question text and drop internal message IDs.
		optionC, ok := msg.Content.(optionContent)
		if ok {
			msg.Content = optionC.Question
		}
		msg.ID = ""
	}

	// Internal state keys are not part of the user-facing output.
	delete(out, QuestionsKey)
	delete(out, AnswersKey)

	sOut := &nodes.StructuredCallbackOutput{
		Output: out,
		RawOutput: map[string]any{
			"messages": history,
		},
	}

	if hasSelected {
		sOut.RawOutput["selected"] = selected
	}

	return sOut, nil
}
|
||||
|
||||
func AppendInterruptData(interruptData string, resumeData string) (string, error) {
|
||||
var historyList = make(map[string][]*message)
|
||||
err := sonic.UnmarshalString(interruptData, &historyList)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
lastQuestion := historyList["messages"][len(historyList["messages"])-1]
|
||||
segments := strings.Split(lastQuestion.ID, "_")
|
||||
nodeKey := segments[0]
|
||||
i, err := strconv.Atoi(segments[1])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
answerMsg := &message{
|
||||
Type: "answer",
|
||||
ContentType: lastQuestion.ContentType,
|
||||
Content: resumeData,
|
||||
ID: fmt.Sprintf("%s_%d", nodeKey, i+1),
|
||||
}
|
||||
|
||||
historyList["messages"] = append(historyList["messages"], answerMsg)
|
||||
m, err := sonic.MarshalString(historyList)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package receiver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
sonic2 "github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// Config carries the construction parameters for an InputReceiver node.
type Config struct {
	// OutputTypes describes the expected type of each output field; received
	// input is coerced to these types on resume.
	OutputTypes map[string]*vo.TypeInfo
	// NodeKey identifies this node within the workflow.
	NodeKey vo.NodeKey
	// OutputSchema is the serialized form schema surfaced to the caller when
	// the node interrupts to request input.
	OutputSchema string
}
|
||||
|
||||
// InputReceiver is a workflow node that pauses execution via an interrupt
// event until external input arrives, then parses that input into typed
// output fields.
type InputReceiver struct {
	outputTypes   map[string]*vo.TypeInfo // expected output field types
	interruptData string                  // pre-serialized interrupt payload (form schema)
	nodeKey       vo.NodeKey              // this node's key within the workflow
	nodeMeta      entity.NodeTypeMeta     // display metadata (name, icon) for interrupt events
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for input receiver")
|
||||
}
|
||||
|
||||
interruptData := map[string]string{
|
||||
"content_type": "form_schema",
|
||||
"content": cfg.OutputSchema,
|
||||
}
|
||||
|
||||
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InputReceiver{
|
||||
outputTypes: cfg.OutputTypes,
|
||||
nodeMeta: *nodeMeta,
|
||||
nodeKey: cfg.NodeKey,
|
||||
interruptData: interruptDataStr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const (
	// ReceivedDataKey is the input-map key under which resumed input arrives.
	ReceivedDataKey = "$received_data"
	// receiverWarningKey formats the ctxcache key (executeID, nodeKey) used to
	// pass conversion warnings from Invoke to ToCallbackOutput.
	receiverWarningKey = "receiver_warning_%d_%s"
)
|
||||
|
||||
// Invoke implements the node: if resumed input is present under
// ReceivedDataKey it is parsed into the typed output map; otherwise the node
// registers (at most one) interrupt event and suspends the workflow.
func (i *InputReceiver) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
	var input string
	if in != nil {
		receivedData, ok := in[ReceivedDataKey]
		if ok {
			// The resumed payload arrives as a JSON string.
			input = receivedData.(string)
		}
	}

	if len(input) == 0 {
		// No input yet: interrupt. Reuse an already-recorded event so
		// repeated invocations don't mint duplicate event IDs.
		err := compose.ProcessState(ctx, func(ctx context.Context, ieStore nodes.InterruptEventStore) error {
			_, found, e := ieStore.GetInterruptEvent(i.nodeKey) // TODO: try not use InterruptEventStore or state in general
			if e != nil {
				return e
			}

			if !found { // only generate a new event if it doesn't exist
				eventID, err := workflow.GetRepository().GenID(ctx)
				if err != nil {
					return err
				}
				return ieStore.SetInterruptEvent(i.nodeKey, &entity.InterruptEvent{
					ID:            eventID,
					NodeKey:       i.nodeKey,
					NodeType:      entity.NodeTypeInputReceiver,
					NodeTitle:     i.nodeMeta.Name,
					NodeIcon:      i.nodeMeta.IconURL,
					InterruptData: i.interruptData,
					EventType:     entity.InterruptEventInput,
				})
			}

			return nil
		})
		if err != nil {
			return nil, err
		}
		return nil, compose.InterruptAndRerun
	}

	// Input present: parse and coerce it to the declared output types.
	out, err := jsonParseRelaxed(ctx, input, i.outputTypes)
	if err != nil {
		return nil, err
	}

	return out, nil
}
|
||||
|
||||
func jsonParseRelaxed(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
var result map[string]any
|
||||
|
||||
err := sonic2.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_, nodes.SkipUnknownFields())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ws != nil && len(*ws) > 0 {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
var (
|
||||
executeID int64
|
||||
nodeKey vo.NodeKey
|
||||
)
|
||||
if c := execute.GetExeCtx(ctx); c != nil {
|
||||
executeID = c.RootExecuteID
|
||||
nodeKey = c.NodeKey
|
||||
}
|
||||
|
||||
warningKey := fmt.Sprintf(receiverWarningKey, executeID, nodeKey)
|
||||
ctxcache.Store(ctx, warningKey, *ws)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (i *InputReceiver) ToCallbackOutput(ctx context.Context, output map[string]any) (
|
||||
*nodes.StructuredCallbackOutput, error) {
|
||||
var (
|
||||
executeID int64
|
||||
nodeKey vo.NodeKey
|
||||
)
|
||||
if c := execute.GetExeCtx(ctx); c != nil {
|
||||
executeID = c.RootExecuteID
|
||||
nodeKey = c.NodeKey
|
||||
}
|
||||
|
||||
warningKey := fmt.Sprintf(receiverWarningKey, executeID, nodeKey)
|
||||
|
||||
var wfe vo.WorkflowError
|
||||
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, warningKey); ok {
|
||||
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
|
||||
}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
Error: wfe,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package receiver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// Test_jsonParseRelaxed verifies the "relaxed" behavior: fields declared in
// the type info but absent from the payload are simply omitted from the
// result (no error), and provided fields pass through with their values.
func Test_jsonParseRelaxed(t *testing.T) {
	tInfos := map[string]*vo.TypeInfo{
		"str_key": {
			Type: vo.DataTypeString,
		},
		"obj_key": {
			Type: vo.DataTypeObject,
			Properties: map[string]*vo.TypeInfo{
				"field1": {
					Type: vo.DataTypeString,
				},
			},
		},
	}

	// Payload deliberately omits obj_key.
	data := `{"str_key": "val"}`

	result, err := jsonParseRelaxed(context.Background(), data, tInfos)
	assert.NoError(t, err)
	assert.Equal(t, map[string]any{"str_key": "val"}, result)
}
|
||||
342
backend/domain/workflow/internal/nodes/selector/clause.go
Normal file
342
backend/domain/workflow/internal/nodes/selector/clause.go
Normal file
@@ -0,0 +1,342 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Predicate is anything that evaluates to a boolean outcome, such as a single
// comparison clause or a composite of clauses.
type Predicate interface {
	// Resolve evaluates the predicate.
	Resolve() (bool, error)
}
|
||||
|
||||
// Clause is a single comparison between two operands.
// NOTE(review): the field names spell "Operant" rather than "Operand"; they
// are exported, so renaming would break callers — left as-is.
type Clause struct {
	LeftOperant  any      // left-hand value; nil is treated as "absent" by most operators
	Op           Operator // the comparison operator to apply
	RightOperant any      // right-hand value; unused by unary operators (empty, is_true, ...)
}
|
||||
|
||||
// MultiClause combines several clauses under a single AND/OR relation.
type MultiClause struct {
	Clauses  []*Clause      // the individual comparisons
	Relation ClauseRelation // how clause results are combined (AND / OR)
}
|
||||
|
||||
// Resolve evaluates the clause. The operator first validates the operand
// types via WillAccept; nil operands are then given operator-specific
// meanings (roughly: nil compares as "absent"/zero). Numeric operands of
// mixed int64/float64 types are aligned to float64 before comparison.
func (c *Clause) Resolve() (bool, error) {
	leftV := c.LeftOperant
	rightV := c.RightOperant

	leftT := reflect.TypeOf(leftV)
	rightT := reflect.TypeOf(rightV)

	// Reject operand types the operator cannot handle up front.
	if err := c.Op.WillAccept(leftT, rightT); err != nil {
		return false, err
	}

	switch c.Op {
	case OperatorEqual:
		// nil == nil; nil never equals a non-nil value.
		if leftV == nil && rightV == nil {
			return true, nil
		}

		if leftV == nil || rightV == nil {
			return false, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		return leftV == rightV, nil
	case OperatorNotEqual:
		if leftV == nil && rightV == nil {
			return false, nil
		}

		if leftV == nil || rightV == nil {
			return true, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		return leftV != rightV, nil
	case OperatorEmpty:
		// Emptiness by type: nil, zero-length array/object/string, zero
		// number, or false boolean.
		if leftV == nil {
			return true, nil
		}

		if leftArray, ok := leftV.([]any); ok {
			return len(leftArray) == 0, nil
		}

		if leftObj, ok := leftV.(map[string]any); ok {
			return len(leftObj) == 0, nil
		}

		if leftStr, ok := leftV.(string); ok {
			// "None" is also treated as empty — presumably a stringified
			// Python null leaking through upstream; TODO confirm.
			return len(leftStr) == 0 || leftStr == "None", nil
		}

		if leftInt, ok := leftV.(int64); ok {
			return leftInt == 0, nil
		}

		if leftFloat, ok := leftV.(float64); ok {
			return leftFloat == 0, nil
		}

		if leftBool, ok := leftV.(bool); ok {
			return !leftBool, nil
		}

		// Unrecognized non-nil types count as non-empty.
		return false, nil
	case OperatorNotEmpty:
		// Defined as the negation of OperatorEmpty.
		empty, err := (&Clause{LeftOperant: leftV, Op: OperatorEmpty}).Resolve()
		return !empty, err
	case OperatorGreater:
		// nil sorts below every value: nil > x is false, x > nil is true.
		if leftV == nil {
			return false, nil
		}

		if rightV == nil {
			return true, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
			return leftV.(float64) > rightV.(float64), nil
		}
		return leftV.(int64) > rightV.(int64), nil
	case OperatorGreaterOrEqual:
		if leftV == nil {
			if rightV == nil {
				return true, nil
			}
			return false, nil
		}

		if rightV == nil {
			return true, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
			return leftV.(float64) >= rightV.(float64), nil
		}
		return leftV.(int64) >= rightV.(int64), nil
	case OperatorLesser:
		if leftV == nil {
			if rightV == nil {
				return false, nil
			}
			return true, nil
		}

		if rightV == nil {
			return false, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
			return leftV.(float64) < rightV.(float64), nil
		}
		return leftV.(int64) < rightV.(int64), nil
	case OperatorLesserOrEqual:
		if leftV == nil {
			return true, nil
		}

		if rightV == nil {
			return false, nil
		}

		leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
		if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
			return leftV.(float64) <= rightV.(float64), nil
		}
		return leftV.(int64) <= rightV.(int64), nil
	case OperatorIsTrue:
		// nil is not true.
		if leftV == nil {
			return false, nil
		}

		return leftV.(bool), nil
	case OperatorIsFalse:
		// nil counts as false.
		if leftV == nil {
			return true, nil
		}

		return !leftV.(bool), nil
	case OperatorLengthGreater:
		// Length operators treat nil as length 0.
		if leftV == nil {
			return false, nil
		}

		return int64(reflect.ValueOf(leftV).Len()) > rightV.(int64), nil
	case OperatorLengthGreaterOrEqual:
		if leftV == nil {
			if rightV.(int64) == 0 {
				return true, nil
			}
			return false, nil
		}

		return int64(reflect.ValueOf(leftV).Len()) >= rightV.(int64), nil
	case OperatorLengthLesser:
		if leftV == nil {
			if rightV.(int64) == 0 {
				return false, nil
			}
			return true, nil
		}

		return int64(reflect.ValueOf(leftV).Len()) < rightV.(int64), nil
	case OperatorLengthLesserOrEqual:
		if leftV == nil {
			return true, nil
		}

		return int64(reflect.ValueOf(leftV).Len()) <= rightV.(int64), nil
	case OperatorContain:
		if leftV == nil { // treat it as empty slice
			return false, nil
		}

		// Strings use substring containment; everything else is element-wise.
		if leftT.Kind() == reflect.String {
			return strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
		}

		leftValue := reflect.ValueOf(leftV)
		for i := 0; i < leftValue.Len(); i++ {
			elem := leftValue.Index(i).Interface()
			if elem == rightV {
				return true, nil
			}
		}

		return false, nil
	case OperatorNotContain:
		if leftV == nil { // treat it as empty slice
			// NOTE(review): this returns false for nil, the same as
			// OperatorContain — an empty collection arguably does NOT
			// contain anything, so true may have been intended. Behavior
			// kept as-is; confirm against product expectations.
			return false, nil
		}

		if leftT.Kind() == reflect.String {
			return !strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
		}

		leftValue := reflect.ValueOf(leftV)
		for i := 0; i < leftValue.Len(); i++ {
			elem := leftValue.Index(i).Interface()
			if elem == rightV {
				return false, nil
			}
		}

		return true, nil
	case OperatorContainKey:
		if leftV == nil { // treat it as empty map
			return false, nil
		}

		if leftT.Kind() == reflect.Map {
			leftValue := reflect.ValueOf(leftV)
			for _, key := range leftValue.MapKeys() {
				if key.Interface() == rightV {
					return true, nil
				}
			}
		} else { // struct, unreachable now
			// Struct fields match on their json tag.
			for i := 0; i < leftT.NumField(); i++ {
				field := leftT.Field(i)
				if field.IsExported() {
					tag := field.Tag.Get("json")
					if tag == rightV {
						return true, nil
					}
				}
			}
		}

		return false, nil
	case OperatorNotContainKey:
		if leftV == nil { // treat it as empty map
			// NOTE(review): same nil asymmetry as OperatorNotContain above.
			return false, nil
		}

		if leftT.Kind() == reflect.Map {
			leftValue := reflect.ValueOf(leftV)
			for _, key := range leftValue.MapKeys() {
				if key.Interface() == rightV {
					return false, nil
				}
			}
		} else { // struct, unreachable now
			for i := 0; i < leftT.NumField(); i++ {
				field := leftT.Field(i)
				if field.IsExported() {
					tag := field.Tag.Get("json")
					if tag == rightV {
						return false, nil
					}
				}
			}
		}

		return true, nil
	default:
		return false, fmt.Errorf("unknown operator: %v", c.Op)
	}
}
|
||||
|
||||
func (mc *MultiClause) Resolve() (bool, error) {
|
||||
if mc.Relation == ClauseRelationAND {
|
||||
for _, clause := range mc.Clauses {
|
||||
isTrue, err := clause.Resolve()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !isTrue {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
} else if mc.Relation == ClauseRelationOR {
|
||||
for _, clause := range mc.Clauses {
|
||||
isTrue, err := clause.Resolve()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if isTrue {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
return false, fmt.Errorf("unknown relation: %v", mc.Relation)
|
||||
}
|
||||
}
|
||||
|
||||
func alignNumberTypes(leftV, rightV any, leftT, rightT reflect.Type) (any, any) {
|
||||
if leftT == reflect.TypeOf(int64(0)) {
|
||||
if rightT == reflect.TypeOf(float64(0)) {
|
||||
leftV = float64(leftV.(int64))
|
||||
}
|
||||
} else if leftT == reflect.TypeOf(float64(0)) {
|
||||
if rightT == reflect.TypeOf(int64(0)) {
|
||||
rightV = float64(rightV.(int64))
|
||||
}
|
||||
}
|
||||
|
||||
return leftV, rightV
|
||||
}
|
||||
310
backend/domain/workflow/internal/nodes/selector/clause_test.go
Normal file
310
backend/domain/workflow/internal/nodes/selector/clause_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestClauseResolve tests the Resolve method of the Clause struct.
|
||||
// TestClauseResolve tests the Resolve method of the Clause struct.
// Table-driven: one case per operator family, exercising the operand types
// each operator accepts (int64/float64 for numeric ops, string/slice/map for
// containment and length ops).
func TestClauseResolve(t *testing.T) {
	// Test cases for different operators, considering acceptable operand types
	testCases := []struct {
		name    string
		clause  Clause
		want    bool
		wantErr bool
	}{
		// OperatorEqual
		{
			name: "OperatorEqual_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorEqual,
				RightOperant: int64(10),
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorEqual_IntMismatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorEqual,
				RightOperant: int64(20),
			},
			want:    false,
			wantErr: false,
		},
		{
			name: "OperatorEqual_FloatMatch",
			clause: Clause{
				LeftOperant:  10.5,
				Op:           OperatorEqual,
				RightOperant: 10.5,
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorEqual_StringMatch",
			clause: Clause{
				LeftOperant:  "test",
				Op:           OperatorEqual,
				RightOperant: "test",
			},
			want:    true,
			wantErr: false,
		},
		// OperatorNotEqual
		{
			name: "OperatorNotEqual_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorNotEqual,
				RightOperant: int64(20),
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorNotEqual_StringMatch",
			clause: Clause{
				LeftOperant:  "test",
				Op:           OperatorNotEqual,
				RightOperant: "xyz",
			},
			want:    true,
			wantErr: false,
		},
		// OperatorEmpty
		{
			name: "OperatorEmpty_NilValue",
			clause: Clause{
				LeftOperant: nil,
				Op:          OperatorEmpty,
			},
			want:    true,
			wantErr: false,
		},
		// OperatorNotEmpty
		{
			name: "OperatorNotEmpty_MapValue",
			clause: Clause{
				LeftOperant: map[string]any{"key1": "value1"},
				Op:          OperatorNotEmpty,
			},
			want:    true,
			wantErr: false,
		},
		// OperatorGreater
		{
			name: "OperatorGreater_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorGreater,
				RightOperant: int64(5),
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorGreater_FloatMatch",
			clause: Clause{
				LeftOperant:  10.5,
				Op:           OperatorGreater,
				RightOperant: 5.0,
			},
			want:    true,
			wantErr: false,
		},
		// OperatorGreaterOrEqual
		{
			name: "OperatorGreaterOrEqual_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorGreaterOrEqual,
				RightOperant: int64(10),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorLesser
		{
			name: "OperatorLesser_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorLesser,
				RightOperant: int64(15),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorLesserOrEqual
		{
			name: "OperatorLesserOrEqual_IntMatch",
			clause: Clause{
				LeftOperant:  int64(10),
				Op:           OperatorLesserOrEqual,
				RightOperant: int64(10),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorIsTrue
		{
			name: "OperatorIsTrue_BoolTrue",
			clause: Clause{
				LeftOperant: true,
				Op:          OperatorIsTrue,
			},
			want:    true,
			wantErr: false,
		},
		// OperatorIsFalse
		{
			name: "OperatorIsFalse_BoolFalse",
			clause: Clause{
				LeftOperant: true,
				Op:          OperatorIsFalse,
			},
			want:    false,
			wantErr: false,
		},
		// OperatorLengthGreater
		{
			name: "OperatorLengthGreater_Slice",
			clause: Clause{
				LeftOperant:  []int{1, 2, 3},
				Op:           OperatorLengthGreater,
				RightOperant: int64(2),
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorLengthGreater_String",
			clause: Clause{
				LeftOperant:  "test",
				Op:           OperatorLengthGreater,
				RightOperant: int64(2),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorLengthGreaterOrEqual
		{
			name: "OperatorLengthGreaterOrEqual_Slice",
			clause: Clause{
				LeftOperant:  []int{1, 2, 3},
				Op:           OperatorLengthGreaterOrEqual,
				RightOperant: int64(3),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorLengthLesser
		{
			name: "OperatorLengthLesser_Slice",
			clause: Clause{
				LeftOperant:  []int{1, 2, 3},
				Op:           OperatorLengthLesser,
				RightOperant: int64(4),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorLengthLesserOrEqual
		{
			name: "OperatorLengthLesserOrEqual_Slice",
			clause: Clause{
				LeftOperant:  []int{1, 2, 3},
				Op:           OperatorLengthLesserOrEqual,
				RightOperant: int64(3),
			},
			want:    true,
			wantErr: false,
		},
		// OperatorContain
		{
			name: "OperatorContain_String",
			clause: Clause{
				LeftOperant:  "test",
				Op:           OperatorContain,
				RightOperant: "es",
			},
			want:    true,
			wantErr: false,
		},
		{
			name: "OperatorContain_Slice",
			clause: Clause{
				LeftOperant:  []int{1, 2, 3},
				Op:           OperatorContain,
				RightOperant: 2,
			},
			want:    true,
			wantErr: false,
		},
		// OperatorNotContain
		{
			name: "OperatorNotContain_String",
			clause: Clause{
				LeftOperant:  "test2",
				Op:           OperatorNotContain,
				RightOperant: "xyz",
			},
			want:    true,
			wantErr: false,
		},
		// OperatorContainKey
		{
			name: "OperatorContainKey_Map",
			clause: Clause{
				LeftOperant:  map[string]any{"key1": "value1"},
				Op:           OperatorContainKey,
				RightOperant: "key1",
			},
			want:    true,
			wantErr: false,
		},
		// OperatorNotContainKey
		{
			name: "OperatorNotContainKey_Map",
			clause: Clause{
				LeftOperant:  map[string]any{"key1": "value1"},
				Op:           OperatorNotContainKey,
				RightOperant: "key2",
			},
			want:    true,
			wantErr: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tc.clause.Resolve()
			if (err != nil) != tc.wantErr {
				t.Errorf("Clause.Resolve() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			assert.Equal(t, tc.want, got)
		})
	}
}
|
||||
182
backend/domain/workflow/internal/nodes/selector/operator.go
Normal file
182
backend/domain/workflow/internal/nodes/selector/operator.go
Normal file
@@ -0,0 +1,182 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// Operator is a comparison/predicate operator used by selector clauses.
type Operator string

const (
	OperatorEqual                Operator = "="
	OperatorNotEqual             Operator = "!="
	OperatorEmpty                Operator = "empty"
	OperatorNotEmpty             Operator = "not_empty"
	OperatorGreater              Operator = ">"
	OperatorGreaterOrEqual       Operator = ">="
	OperatorLesser               Operator = "<"
	OperatorLesserOrEqual        Operator = "<="
	OperatorIsTrue               Operator = "true"
	OperatorIsFalse              Operator = "false"
	OperatorLengthGreater        Operator = "len >"
	OperatorLengthGreaterOrEqual Operator = "len >="
	OperatorLengthLesser         Operator = "len <"
	OperatorLengthLesserOrEqual  Operator = "len <="
	OperatorContain              Operator = "contain"
	OperatorNotContain           Operator = "not_contain"
	OperatorContainKey           Operator = "contain_key"
	OperatorNotContainKey        Operator = "not_contain_key"
)

// WillAccept reports (as an error) whether this operator can be applied to
// operands of the given reflect types. A nil type generally means "unknown at
// compile time" and is accepted, deferring validation to request time.
func (o *Operator) WillAccept(leftT, rightT reflect.Type) error {
	switch *o {
	case OperatorEqual, OperatorNotEqual:
		// Unknown operand types are resolved at request time.
		if leftT == nil || rightT == nil {
			return nil
		}

		if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) && leftT.Kind() != reflect.Bool && leftT.Kind() != reflect.String {
			return fmt.Errorf("operator %v only accepts int64, float64, bool or string, not %v", *o, leftT)
		}

		// bool and string operands require an exact type match.
		// BUGFIX: was `leftT.Kind() != reflect.String`, which also forced an
		// exact match for numeric operands and made the int64/float64 mixing
		// branch below unreachable.
		if leftT.Kind() == reflect.Bool || leftT.Kind() == reflect.String {
			if leftT != rightT {
				return fmt.Errorf("operator %v left operant and right operant must be same type: %v, %v", *o, leftT, rightT)
			}
		}

		// Numeric operands may freely mix int64 and float64.
		if leftT == reflect.TypeOf(int64(0)) || leftT == reflect.TypeOf(float64(0)) {
			if rightT != reflect.TypeOf(int64(0)) && rightT != reflect.TypeOf(float64(0)) {
				return fmt.Errorf("operator %v right operant must be int64 or float64, not %v", *o, rightT)
			}
		}
	case OperatorEmpty, OperatorNotEmpty:
		// Unary operators: no right operand allowed.
		if rightT != nil {
			return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
		}
	case OperatorGreater, OperatorGreaterOrEqual, OperatorLesser, OperatorLesserOrEqual:
		if leftT == nil {
			return nil
		}

		// NOTE(review): the right operand's type is not validated here —
		// presumably checked at resolve time; confirm before tightening.
		if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) {
			return fmt.Errorf("operator %v only accepts float64 or int64, not %v", *o, leftT)
		}
	case OperatorIsTrue, OperatorIsFalse:
		if leftT == nil {
			return nil
		}

		if rightT != nil {
			return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
		}

		if leftT.Kind() != reflect.Bool {
			return fmt.Errorf("operator %v only accepts boolean, not %v", *o, leftT)
		}
	case OperatorLengthGreater, OperatorLengthGreaterOrEqual, OperatorLengthLesser, OperatorLengthLesserOrEqual:
		if leftT == nil {
			return nil
		}

		if leftT.Kind() != reflect.String && leftT.Kind() != reflect.Slice {
			return fmt.Errorf("operator %v left operant only accepts string or slice, not %v", *o, leftT)
		}
		// A nil rightT also fails this check, which is intended: length
		// comparisons require an explicit int64 right operand.
		if rightT != reflect.TypeOf(int64(0)) {
			return fmt.Errorf("operator %v right operant only accepts int64, not %v", *o, rightT)
		}
	case OperatorContain, OperatorNotContain:
		if leftT == nil {
			return nil
		}

		// BUGFIX: guard against nil rightT before calling rightT.Kind(),
		// which would otherwise panic.
		if rightT == nil {
			return fmt.Errorf("operator %v requires a non-nil right operant", *o)
		}

		switch leftT.Kind() {
		case reflect.String:
			if rightT.Kind() != reflect.String {
				return fmt.Errorf("operator %v whose left operant is string only accepts right operant of string, not %v", *o, rightT)
			}
		case reflect.Slice:
			elemType := leftT.Elem()
			if !rightT.AssignableTo(elemType) {
				return fmt.Errorf("operator %v whose left operant is slice only accepts right operant of corresponding element type %v, not %v", *o, elemType, rightT)
			}
		default:
			return fmt.Errorf("operator %v only accepts left operant of string or slice, not %v", *o, leftT)
		}
	case OperatorContainKey, OperatorNotContainKey:
		if leftT == nil { // treat it as empty map
			return nil
		}
		if leftT.Kind() != reflect.Map {
			return fmt.Errorf("operator %v only accepts left operant of map, not %v", *o, leftT)
		}
		// BUGFIX: guard against nil rightT before calling rightT.Kind().
		if rightT == nil || rightT.Kind() != reflect.String {
			return fmt.Errorf("operator %v only accepts right operant of string, not %v", *o, rightT)
		}
	default:
		// BUGFIX: was %d on a *Operator (pointer to a string-based type),
		// which go vet rejects and which printed an address, not the value.
		return fmt.Errorf("unknown operator: %s", *o)
	}

	return nil
}
|
||||
|
||||
func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
|
||||
switch *o {
|
||||
case OperatorEqual:
|
||||
return vo.Equal
|
||||
case OperatorNotEqual:
|
||||
return vo.NotEqual
|
||||
case OperatorEmpty:
|
||||
return vo.Empty
|
||||
case OperatorNotEmpty:
|
||||
return vo.NotEmpty
|
||||
case OperatorGreater:
|
||||
return vo.GreaterThan
|
||||
case OperatorGreaterOrEqual:
|
||||
return vo.GreaterThanEqual
|
||||
case OperatorLesser:
|
||||
return vo.LessThan
|
||||
case OperatorLesserOrEqual:
|
||||
return vo.LessThanEqual
|
||||
case OperatorIsTrue:
|
||||
return vo.True
|
||||
case OperatorIsFalse:
|
||||
return vo.False
|
||||
case OperatorLengthGreater:
|
||||
return vo.LengthGreaterThan
|
||||
case OperatorLengthGreaterOrEqual:
|
||||
return vo.LengthGreaterThanEqual
|
||||
case OperatorLengthLesser:
|
||||
return vo.LengthLessThan
|
||||
case OperatorLengthLesserOrEqual:
|
||||
return vo.LengthLessThanEqual
|
||||
case OperatorContain:
|
||||
return vo.Contain
|
||||
case OperatorNotContain:
|
||||
return vo.NotContain
|
||||
case OperatorContainKey:
|
||||
return vo.Contain
|
||||
case OperatorNotContainKey:
|
||||
return vo.NotContain
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown operator: %+v", o))
|
||||
}
|
||||
}
|
||||
397
backend/domain/workflow/internal/nodes/selector/operator_test.go
Normal file
397
backend/domain/workflow/internal/nodes/selector/operator_test.go
Normal file
@@ -0,0 +1,397 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOperatorWillAccept is a table-driven test covering Operator.WillAccept:
// for every operator it checks one or more accepted type pairs and one or
// more rejected pairs. A nil leftType/rightType models an operand whose type
// is unknown at compile time (or, for unary operators, an absent right
// operand).
func TestOperatorWillAccept(t *testing.T) {
	testCases := []struct {
		name      string
		operator  Operator
		leftType  reflect.Type // type of the left operand; nil = unknown
		rightType reflect.Type // type of the right operand; nil = absent/unknown
		wantErr   bool         // whether WillAccept should reject the pair
	}{
		// OperatorEqual
		{
			name:      "OperatorEqual_Int64Match",
			operator:  OperatorEqual,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorEqual_InvalidType",
			operator:  OperatorEqual,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorNotEqual
		{
			name:      "OperatorNotEqual_Float64Match",
			operator:  OperatorNotEqual,
			leftType:  reflect.TypeOf(float64(0)),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorNotEqual_InvalidType",
			operator:  OperatorNotEqual,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorEmpty: unary, any left type, right must be absent.
		{
			name:      "OperatorEmpty_Struct",
			operator:  OperatorEmpty,
			leftType:  reflect.TypeOf(map[string]int{}),
			rightType: nil,
			wantErr:   false,
		},
		{
			name:      "OperatorEmpty_NonNilRight",
			operator:  OperatorEmpty,
			leftType:  reflect.TypeOf(map[string]int{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		// OperatorNotEmpty
		{
			name:      "OperatorNotEmpty_Slice",
			operator:  OperatorNotEmpty,
			leftType:  reflect.TypeOf([]int{}),
			rightType: nil,
			wantErr:   false,
		},
		{
			name:      "OperatorNotEmpty_NonNilRight",
			operator:  OperatorNotEmpty,
			leftType:  reflect.TypeOf([]int{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		// OperatorGreater: numeric left operand only.
		{
			name:      "OperatorGreater_Int64Match",
			operator:  OperatorGreater,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorGreater_InvalidType",
			operator:  OperatorGreater,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorGreaterOrEqual
		{
			name:      "OperatorGreaterOrEqual_Float64Match",
			operator:  OperatorGreaterOrEqual,
			leftType:  reflect.TypeOf(float64(0)),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorGreaterOrEqual_InvalidType",
			operator:  OperatorGreaterOrEqual,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorLesser
		{
			name:      "OperatorLesser_Int64Match",
			operator:  OperatorLesser,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLesser_InvalidType",
			operator:  OperatorLesser,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorLesserOrEqual
		{
			name:      "OperatorLesserOrEqual_Float64Match",
			operator:  OperatorLesserOrEqual,
			leftType:  reflect.TypeOf(float64(0)),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLesserOrEqual_InvalidType",
			operator:  OperatorLesserOrEqual,
			leftType:  reflect.TypeOf(struct{}{}),
			rightType: reflect.TypeOf(struct{}{}),
			wantErr:   true,
		},
		// OperatorIsTrue: unary, boolean left operand only.
		{
			name:      "OperatorIsTrue_Bool",
			operator:  OperatorIsTrue,
			leftType:  reflect.TypeOf(true),
			rightType: nil,
			wantErr:   false,
		},
		{
			name:      "OperatorIsTrue_NonNilRight",
			operator:  OperatorIsTrue,
			leftType:  reflect.TypeOf(true),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorIsTrue_InvalidType",
			operator:  OperatorIsTrue,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: nil,
			wantErr:   true,
		},
		// OperatorIsFalse
		{
			name:      "OperatorIsFalse_Bool",
			operator:  OperatorIsFalse,
			leftType:  reflect.TypeOf(false),
			rightType: nil,
			wantErr:   false,
		},
		{
			name:      "OperatorIsFalse_NonNilRight",
			operator:  OperatorIsFalse,
			leftType:  reflect.TypeOf(false),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorIsFalse_InvalidType",
			operator:  OperatorIsFalse,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: nil,
			wantErr:   true,
		},
		// OperatorLengthGreater: string/slice left, int64 right.
		{
			name:      "OperatorLengthGreater_String",
			operator:  OperatorLengthGreater,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLengthGreater_InvalidLeft",
			operator:  OperatorLengthGreater,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorLengthGreater_InvalidRight",
			operator:  OperatorLengthGreater,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   true,
		},
		// OperatorLengthGreaterOrEqual
		{
			name:      "OperatorLengthGreaterOrEqual_Slice",
			operator:  OperatorLengthGreaterOrEqual,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLengthGreaterOrEqual_InvalidLeft",
			operator:  OperatorLengthGreaterOrEqual,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorLengthGreaterOrEqual_InvalidRight",
			operator:  OperatorLengthGreaterOrEqual,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   true,
		},
		// OperatorLengthLesser
		{
			name:      "OperatorLengthLesser_String",
			operator:  OperatorLengthLesser,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLengthLesser_InvalidLeft",
			operator:  OperatorLengthLesser,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorLengthLesser_InvalidRight",
			operator:  OperatorLengthLesser,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   true,
		},
		// OperatorLengthLesserOrEqual
		{
			name:      "OperatorLengthLesserOrEqual_Slice",
			operator:  OperatorLengthLesserOrEqual,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorLengthLesserOrEqual_InvalidLeft",
			operator:  OperatorLengthLesserOrEqual,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorLengthLesserOrEqual_InvalidRight",
			operator:  OperatorLengthLesserOrEqual,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(float64(0)),
			wantErr:   true,
		},
		// OperatorContain: string needs string right; slice needs
		// element-assignable right.
		{
			name:      "OperatorContain_String",
			operator:  OperatorContain,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(""),
			wantErr:   false,
		},
		{
			name:      "OperatorContain_Slice",
			operator:  OperatorContain,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorContain_InvalidLeft",
			operator:  OperatorContain,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(0),
			wantErr:   true,
		},
		{
			name:      "OperatorContain_InvalidRight",
			operator:  OperatorContain,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		// OperatorNotContain
		{
			name:      "OperatorNotContain_String",
			operator:  OperatorNotContain,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(""),
			wantErr:   false,
		},
		{
			name:      "OperatorNotContain_Slice",
			operator:  OperatorNotContain,
			leftType:  reflect.TypeOf([]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   false,
		},
		{
			name:      "OperatorNotContain_InvalidLeft",
			operator:  OperatorNotContain,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		{
			name:      "OperatorNotContain_InvalidRight",
			operator:  OperatorNotContain,
			leftType:  reflect.TypeOf(""),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		// OperatorContainKey: map left, string right.
		{
			name:      "OperatorContainKey_Map",
			operator:  OperatorContainKey,
			leftType:  reflect.TypeOf(map[string]any{}),
			rightType: reflect.TypeOf(""),
			wantErr:   false,
		},
		{
			name:      "OperatorContainKey_InvalidLeft",
			operator:  OperatorContainKey,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(""),
			wantErr:   true,
		},
		{
			name:      "OperatorContainKey_InvalidRight",
			operator:  OperatorContainKey,
			leftType:  reflect.TypeOf(map[string]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
		// OperatorNotContainKey
		{
			name:      "OperatorNotContainKey_Map",
			operator:  OperatorNotContainKey,
			leftType:  reflect.TypeOf(map[string]any{}),
			rightType: reflect.TypeOf(""),
			wantErr:   false,
		},
		{
			name:      "OperatorNotContainKey_InvalidLeft",
			operator:  OperatorNotContainKey,
			leftType:  reflect.TypeOf(int64(0)),
			rightType: reflect.TypeOf(""),
			wantErr:   true,
		},
		{
			name:      "OperatorNotContainKey_InvalidRight",
			operator:  OperatorNotContainKey,
			leftType:  reflect.TypeOf(map[string]any{}),
			rightType: reflect.TypeOf(int64(0)),
			wantErr:   true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.operator.WillAccept(tc.leftType, tc.rightType)
			if (err != nil) != tc.wantErr {
				t.Errorf("Operator.WillAccept() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
|
||||
54
backend/domain/workflow/internal/nodes/selector/schema.go
Normal file
54
backend/domain/workflow/internal/nodes/selector/schema.go
Normal file
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// ClauseRelation describes how the sub-clauses of a multi clause are combined.
type ClauseRelation string

const (
	ClauseRelationAND ClauseRelation = "and" // all sub-clauses must hold
	ClauseRelationOR  ClauseRelation = "or"  // at least one sub-clause must hold
)

// Config is the compile-time configuration of a selector node: an ordered
// list of clauses, evaluated first-match-wins at runtime.
type Config struct {
	Clauses []*OneClauseSchema `json:"clauses"`
}

// OneClauseSchema is either a single operator or a multi-clause group.
// Exactly one of Single/Multi must be non-nil (enforced by NewSelector).
type OneClauseSchema struct {
	Single *Operator          `json:"single,omitempty"`
	Multi  *MultiClauseSchema `json:"multi,omitempty"`
}

// MultiClauseSchema groups several operators under one AND/OR relation.
type MultiClauseSchema struct {
	Clauses  []*Operator    `json:"clauses"`
	Relation ClauseRelation `json:"relation"`
}
|
||||
|
||||
func (c ClauseRelation) ToVOLogicType() vo.LogicType {
|
||||
if c == ClauseRelationAND {
|
||||
return vo.AND
|
||||
} else if c == ClauseRelationOR {
|
||||
return vo.OR
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unknown clause relation: %s", c))
|
||||
}
|
||||
209
backend/domain/workflow/internal/nodes/selector/selector.go
Normal file
209
backend/domain/workflow/internal/nodes/selector/selector.go
Normal file
@@ -0,0 +1,209 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type Selector struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
if len(config.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("config clauses are empty")
|
||||
}
|
||||
|
||||
for _, clause := range config.Clauses {
|
||||
if clause.Single == nil && clause.Multi == nil {
|
||||
return nil, fmt.Errorf("single clause and multi clause are both nil")
|
||||
}
|
||||
|
||||
if clause.Single != nil && clause.Multi != nil {
|
||||
return nil, fmt.Errorf("multi clause and single clause are both non-nil")
|
||||
}
|
||||
|
||||
if clause.Multi != nil {
|
||||
if len(clause.Multi.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("multi clause's single clauses are empty")
|
||||
}
|
||||
|
||||
if clause.Multi.Relation != ClauseRelationAND && clause.Multi.Relation != ClauseRelationOR {
|
||||
return nil, fmt.Errorf("multi clause and clauses are both non-AND-OR: %v", clause.Multi.Relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Selector{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Operants holds the resolved runtime operand values for one clause.
// For a single clause, Left/Right are used (Right may be nil for unary
// operators); for a multi clause, Multi holds one entry per sub-clause.
type Operants struct {
	Left  any
	Right any
	Multi []*Operants
}

const (
	// LeftKey / RightKey are the input-map keys for a clause's operands.
	LeftKey  = "left"
	RightKey = "right"
	// SelectKey is the output-map key carrying the chosen branch index.
	SelectKey = "selected"
)
|
||||
|
||||
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.SelectorInputConverter(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
predicates := make([]Predicate, 0, len(s.config.Clauses))
|
||||
for i, oneConf := range s.config.Clauses {
|
||||
if oneConf.Single != nil {
|
||||
left := in[i].Left
|
||||
right := in[i].Right
|
||||
if right != nil {
|
||||
predicates = append(predicates, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *oneConf.Single,
|
||||
RightOperant: right,
|
||||
})
|
||||
} else {
|
||||
predicates = append(predicates, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *oneConf.Single,
|
||||
})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := &MultiClause{
|
||||
Relation: oneConf.Multi.Relation,
|
||||
}
|
||||
for j, singleConf := range oneConf.Multi.Clauses {
|
||||
left := in[i].Multi[j].Left
|
||||
right := in[i].Multi[j].Right
|
||||
if right != nil {
|
||||
multiClause.Clauses = append(multiClause.Clauses, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *singleConf,
|
||||
RightOperant: right,
|
||||
})
|
||||
} else {
|
||||
multiClause.Clauses = append(multiClause.Clauses, &Clause{
|
||||
LeftOperant: left,
|
||||
Op: *singleConf,
|
||||
})
|
||||
}
|
||||
}
|
||||
predicates = append(predicates, multiClause)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
for i, p := range predicates {
|
||||
isTrue, err := p.Resolve()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isTrue {
|
||||
return map[string]any{SelectKey: i}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{SelectKey: len(in)}, nil // default choice
|
||||
}
|
||||
|
||||
// GetType returns the node type name for this node implementation.
func (s *Selector) GetType() string {
	return "Selector"
}

// ConditionCount returns the number of configured condition branches
// (the default "else" branch is not counted).
func (s *Selector) ConditionCount() int {
	return len(s.config.Clauses)
}
|
||||
|
||||
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.config.Clauses
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), RightKey})
|
||||
if ok {
|
||||
out = append(out, Operants{Left: left, Right: right})
|
||||
} else {
|
||||
out = append(out, Operants{Left: left})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := make([]*Operants, 0)
|
||||
for j := range oneConf.Multi.Clauses {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
|
||||
}
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), RightKey})
|
||||
if ok {
|
||||
multiClause = append(multiClause, &Operants{Left: left, Right: right})
|
||||
} else {
|
||||
multiClause = append(multiClause, &Operants{Left: left})
|
||||
}
|
||||
}
|
||||
out = append(out, Operants{Multi: multiClause})
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
count := len(s.config.Clauses)
|
||||
out := output[SelectKey].(int)
|
||||
if out == count {
|
||||
cOutput := map[string]any{"result": "pass to else branch"}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: cOutput,
|
||||
RawOutput: cOutput,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if out >= 0 && out < count {
|
||||
cOutput := map[string]any{"result": fmt.Sprintf("pass to condition %d branch", out+1)}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: cOutput,
|
||||
RawOutput: cOutput,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("out of range: %d", out)
|
||||
}
|
||||
185
backend/domain/workflow/internal/nodes/stream.go
Normal file
185
backend/domain/workflow/internal/nodes/stream.go
Normal file
@@ -0,0 +1,185 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// KeyIsFinished is a sentinel map key (wrapped in unit-separator control
// characters so it cannot collide with user field names) used to mark a
// stream as finished.
var KeyIsFinished = "\x1FKey is finished\x1F"

// Mode distinguishes streaming from non-streaming node execution.
type Mode string

const (
	Streaming    Mode = "streaming"
	NonStreaming Mode = "non-streaming"
)

// FieldStreamType classifies whether a field source carries a stream.
type FieldStreamType string

const (
	FieldIsStream    FieldStreamType = "yes"     // absolutely a stream
	FieldNotStream   FieldStreamType = "no"      // absolutely not a stream
	FieldMaybeStream FieldStreamType = "maybe"   // maybe a stream, requires request-time resolution
	FieldSkipped     FieldStreamType = "skipped" // the field source's node is skipped
)

// SourceInfo contains stream type for a input field source of a node.
type SourceInfo struct {
	// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
	IsIntermediate bool
	// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
	FieldType FieldStreamType
	// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
	FromNodeKey vo.NodeKey
	// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
	FromPath compose.FieldPath
	// TypeInfo describes the declared type of the field source.
	TypeInfo *vo.TypeInfo
	// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
	SubSources map[string]*SourceInfo
}

// DynamicStreamContainer stores and answers request-time stream-type choices
// for nodes whose output stream-ness is only known during execution.
type DynamicStreamContainer interface {
	SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
	GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
	GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
	GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
}
|
||||
|
||||
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
// For each source it consults workflow state to determine whether the
// producing node was skipped, and resolves FieldMaybeStream entries to a
// definite stream type via the DynamicStreamContainer. Intermediate sources
// are resolved recursively through their SubSources. The input map is not
// mutated; a fully resolved copy is returned.
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
	resolved := make(map[string]*SourceInfo, len(sources))

	// cache of "was this producing node skipped?" so state is queried once per node
	nodeKey2Skipped := make(map[vo.NodeKey]bool)

	var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
	resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
		resolvedNode := &SourceInfo{
			IsIntermediate: sInfo.IsIntermediate,
			FieldType:      sInfo.FieldType,
			FromNodeKey:    sInfo.FromNodeKey,
			FromPath:       sInfo.FromPath,
			TypeInfo:       sInfo.TypeInfo,
		}

		// intermediate source: resolve every sub source recursively and stop
		if len(sInfo.SubSources) > 0 {
			resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))

			for k, subInfo := range sInfo.SubSources {
				resolvedSub, err := resolver(k, subInfo)
				if err != nil {
					return nil, err
				}

				resolvedNode.SubSources[k] = resolvedSub
			}

			return resolvedNode, nil
		}

		if sInfo.FromNodeKey == "" { // static values and variables, always non-streaming and available
			return resolvedNode, nil
		}

		// ask workflow state (once per node) whether the producer executed;
		// the ProcessState error is deliberately ignored: without state the
		// node is treated as not skipped
		var skipped, ok bool
		if skipped, ok = nodeKey2Skipped[sInfo.FromNodeKey]; !ok {
			_ = compose.ProcessState(ctx, func(ctx context.Context, state NodeExecuteStatusAware) error {
				skipped = !state.NodeExecuted(sInfo.FromNodeKey)
				return nil
			})
			nodeKey2Skipped[sInfo.FromNodeKey] = skipped
		}

		if skipped {
			resolvedNode.FieldType = FieldSkipped
			return resolvedNode, nil
		}

		if sInfo.FieldType == FieldMaybeStream {
			if len(sInfo.SubSources) > 0 {
				panic("a maybe stream field should not have sub sources")
			}

			// resolve the dynamic choice from state; FromPath[0] is used as the
			// group name (assumes FromPath is non-empty for maybe-stream fields
			// — NOTE(review): confirm this invariant at the call sites)
			var streamType FieldStreamType
			err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
				var e error
				streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
				return e
			})
			if err != nil {
				return nil, err
			}

			return &SourceInfo{
				IsIntermediate: sInfo.IsIntermediate,
				FieldType:      streamType,
				FromNodeKey:    sInfo.FromNodeKey,
				FromPath:       sInfo.FromPath,
				SubSources:     sInfo.SubSources,
				TypeInfo:       sInfo.TypeInfo,
			}, nil
		}

		return resolvedNode, nil
	}

	for k, sInfo := range sources {
		resolvedInfo, err := resolver(k, sInfo)
		if err != nil {
			return nil, err
		}
		resolved[k] = resolvedInfo
	}

	return resolved, nil
}
|
||||
|
||||
// NodeExecuteStatusAware reports whether a node has executed in the current
// workflow run. Implemented by the workflow state.
type NodeExecuteStatusAware interface {
	// NodeExecuted returns true if the node identified by key has run.
	NodeExecuted(key vo.NodeKey) bool
}
|
||||
|
||||
func (s *SourceInfo) Skipped() bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FieldType == FieldSkipped
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if !sub.Skipped() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FromNodeKey == nodeKey
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if sub.FromNode(nodeKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package subworkflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// Config carries the dependencies of a SubWorkflow node.
type Config struct {
	// Runner is the compiled nested workflow to execute.
	Runner compose.Runnable[map[string]any, map[string]any]
}

// SubWorkflow is a workflow node that executes a nested (child) workflow.
type SubWorkflow struct {
	cfg *Config
}
|
||||
|
||||
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("runnable is nil")
|
||||
}
|
||||
|
||||
return &SubWorkflow{cfg: cfg}, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: nodeKey,
|
||||
NodeType: entity.NodeTypeSubWorkflow,
|
||||
SubWorkflowInterruptInfo: interruptInfo,
|
||||
}
|
||||
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.InterruptEventStore) error {
|
||||
return setter.SetInterruptEvent(nodeKey, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, compose.InterruptAndRerun
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: nodeKey,
|
||||
NodeType: entity.NodeTypeSubWorkflow,
|
||||
SubWorkflowInterruptInfo: interruptInfo,
|
||||
}
|
||||
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.InterruptEventStore) error {
|
||||
return setter.SetInterruptEvent(nodeKey, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, compose.InterruptAndRerun
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
nestedOpts := options.GetOptsForNested()
|
||||
|
||||
exeCtx := execute.GetExeCtx(ctx)
|
||||
if exeCtx == nil {
|
||||
panic("impossible. exeCtx in sub workflow is nil")
|
||||
}
|
||||
|
||||
checkPointID := exeCtx.CheckPointID
|
||||
if len(checkPointID) > 0 {
|
||||
newCheckpointID := checkPointID
|
||||
if exeCtx.SubWorkflowCtx != nil {
|
||||
newCheckpointID += "_" + strconv.Itoa(int(exeCtx.SubWorkflowCtx.SubExecuteID))
|
||||
}
|
||||
newCheckpointID += "_" + strconv.Itoa(int(exeCtx.NodeCtx.NodeExecuteID))
|
||||
nestedOpts = append(nestedOpts, compose.WithCheckPointID(newCheckpointID))
|
||||
}
|
||||
|
||||
if len(options.GetResumeIndexes()) > 0 {
|
||||
if len(options.GetResumeIndexes()) != 1 {
|
||||
return nil, "", fmt.Errorf("resume indexes for sub workflow length must be 1")
|
||||
}
|
||||
if _, ok := options.GetResumeIndexes()[0]; !ok {
|
||||
return nil, "", fmt.Errorf("resume indexes for sub workflow must resume index 0")
|
||||
}
|
||||
stateModifier, ok := options.GetResumeIndexes()[0]
|
||||
if ok {
|
||||
nestedOpts = append(nestedOpts, compose.WithStateModifier(stateModifier))
|
||||
}
|
||||
}
|
||||
|
||||
return nestedOpts, exeCtx.NodeKey, nil
|
||||
}
|
||||
431
backend/domain/workflow/internal/nodes/template.go
Normal file
431
backend/domain/workflow/internal/nodes/template.go
Normal file
@@ -0,0 +1,431 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/bytedance/sonic/ast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// TemplatePart is one parsed segment of a "{{...}}" template: either a run
// of literal text or a variable placeholder.
type TemplatePart struct {
	// IsVariable is true when this part is a {{...}} placeholder.
	IsVariable bool
	// Value is the literal text, or the raw placeholder content for variables.
	Value string
	// Root is the first path segment with any [index] suffix removed.
	Root string
	// SubPathsBeforeSlice are the dot-separated segments after Root, up to
	// (not including) the first segment containing an array index.
	SubPathsBeforeSlice []string
	// JsonPath is the full access path: string elements are object keys,
	// int elements are array indexes.
	JsonPath []any

	// literal is the original placeholder "{{...}}", emitted verbatim when
	// the path cannot be resolved against the input.
	literal string
}
|
||||
|
||||
// re matches "{{ placeholder }}" variables; the capture group holds the
// placeholder content. Note: because [^}]+ is greedy, trailing whitespace
// before "}}" is captured as part of the placeholder — NOTE(review):
// confirm callers never write "{{ name }}" with inner spaces.
var re = regexp.MustCompile(`{{\s*([^}]+)\s*}}`)

// ParseTemplate splits template into a sequence of TemplateParts: literal
// runs and variable placeholders. Placeholders support dotted object paths
// and [n] array accessors (e.g. "a.b[1].c"); malformed accessors demote the
// placeholder to a literal part.
func ParseTemplate(template string) []TemplatePart {
	matches := re.FindAllStringSubmatchIndex(template, -1)
	parts := make([]TemplatePart, 0)
	lastEnd := 0

loop:
	for _, match := range matches {
		start, end := match[0], match[1]
		placeholderStart, placeholderEnd := match[2], match[3]

		// Add the literal part before the current variable placeholder
		if start > lastEnd {
			parts = append(parts, TemplatePart{
				IsVariable: false,
				Value:      template[lastEnd:start],
			})
		}

		// Add the variable placeholder
		val := template[placeholderStart:placeholderEnd]
		segments := strings.Split(val, ".")
		// collect dotted sub paths only while no array accessor has appeared,
		// and only when the root itself carries no accessor
		var subPaths []string
		if !strings.Contains(segments[0], "[") {
			for i := 1; i < len(segments); i++ {
				if strings.Contains(segments[i], "[") {
					break
				}
				subPaths = append(subPaths, segments[i])
			}
		}

		// build the full JSON access path: keys as strings, indexes as ints
		var jsonPath []any
		for _, segment := range segments {
			// find the first '[' to separate the initial key from array accessors
			firstBracket := strings.Index(segment, "[")
			if firstBracket == -1 {
				// No brackets, the whole segment is a key
				jsonPath = append(jsonPath, segment)
				continue
			}

			// Add the initial key part
			key := segment[:firstBracket]
			if key != "" {
				jsonPath = append(jsonPath, key)
			}

			// Now, parse the array accessors like [1][2]
			rest := segment[firstBracket:]
			for strings.HasPrefix(rest, "[") {
				closeBracket := strings.Index(rest, "]")
				if closeBracket == -1 {
					// Malformed, treat as literal
					parts = append(parts, TemplatePart{IsVariable: false, Value: val})
					continue loop
				}

				idxStr := rest[1:closeBracket]
				idx, err := strconv.Atoi(idxStr)
				if err != nil {
					// Malformed, treat as literal
					parts = append(parts, TemplatePart{IsVariable: false, Value: val})
					continue loop
				}

				jsonPath = append(jsonPath, idx)
				rest = rest[closeBracket+1:]
			}

			if rest != "" {
				// Malformed, treat as literal
				parts = append(parts, TemplatePart{IsVariable: false, Value: val})
				continue loop
			}
		}

		parts = append(parts, TemplatePart{
			IsVariable:          true,
			Value:               val,
			Root:                removeSlice(segments[0]),
			SubPathsBeforeSlice: subPaths,
			JsonPath:            jsonPath,

			literal: "{{" + val + "}}",
		})

		lastEnd = end
	}

	// Add the remaining literal part if there is any
	if lastEnd < len(template) {
		parts = append(parts, TemplatePart{
			IsVariable: false,
			Value:      template[lastEnd:],
		})
	}

	return parts
}
|
||||
|
||||
// removeSlice strips a trailing array accessor from a path segment,
// e.g. "items[0]" -> "items". Segments without '[' are returned unchanged.
func removeSlice(s string) string {
	before, _, _ := strings.Cut(s, "[")
	return before
}
|
||||
|
||||
// renderOptions holds optional behaviors for template rendering,
// configured via RenderOption functions.
type renderOptions struct {
	// type2CustomRenderer maps a concrete value type to a custom formatter.
	type2CustomRenderer map[reflect.Type]func(any) (string, error)
	// reservedKey lists template roots that are always rendered from the
	// input, bypassing the skipped/invalid source checks.
	reservedKey map[string]struct{}
	// nilRenderer, when set, formats nil values instead of the default "".
	nilRenderer func() (string, error)
}
|
||||
|
||||
func WithNilRender(fn func() (string, error)) RenderOption {
|
||||
return func(opts *renderOptions) {
|
||||
opts.nilRenderer = fn
|
||||
}
|
||||
}
|
||||
|
||||
// RenderOption customizes template rendering behavior.
type RenderOption func(options *renderOptions)
|
||||
func WithCustomRender(rType reflect.Type, fn func(any) (string, error)) RenderOption {
|
||||
return func(opts *renderOptions) {
|
||||
if opts.type2CustomRenderer == nil {
|
||||
opts.type2CustomRenderer = make(map[reflect.Type]func(any) (string, error))
|
||||
}
|
||||
opts.type2CustomRenderer[rType] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func WithReservedKey(keys ...string) RenderOption {
|
||||
return func(opts *renderOptions) {
|
||||
if opts.reservedKey == nil {
|
||||
opts.reservedKey = make(map[string]struct{})
|
||||
}
|
||||
for _, key := range keys {
|
||||
opts.reservedKey[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renderConfig marshals with sorted map keys so composite values render
// deterministically.
var renderConfig = sonic.Config{
	SortMapKeys: true,
}.Froze()
|
||||
|
||||
// joinJsonPath formats a JSON access path for error messages: string keys
// are joined with '.' (only between two adjacent keys) and int indexes are
// rendered as "[n]", e.g. ["a","b",1,"c"] -> "a.b[1]c".
func joinJsonPath(p []any) string {
	var b strings.Builder
	for i, seg := range p {
		if field, isKey := seg.(string); isKey {
			if i > 0 {
				if _, prevIsKey := p[i-1].(string); prevIsKey {
					b.WriteByte('.')
				}
			}
			b.WriteString(field)
			continue
		}
		fmt.Fprintf(&b, "[%d]", seg)
	}
	return b.String()
}
|
||||
|
||||
// Render resolves this part's JsonPath against the JSON document m and
// formats the value as a string. Unresolvable paths render as the original
// "{{...}}" literal; an out-of-range array index or indexing into a nil
// array returns a typed error. Custom renderers and nil rendering can be
// supplied via opts.
func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
	options := &renderOptions{
		type2CustomRenderer: make(map[reflect.Type]func(any) (string, error)),
	}
	for _, opt := range opts {
		opt(options)
	}

	n, err := sonic.Get(m, tp.JsonPath...)
	if err != nil {
		notExist := errors.Is(err, ast.ErrNotExist)
		var syntaxErr ast.SyntaxError
		if notExist || errors.As(err, &syntaxErr) {
			// get each path segments one by one until the first not found error
			var segParent, current ast.Node
			for i := range tp.JsonPath {
				current, err = sonic.Get(m, tp.JsonPath[:i+1]...)
				if err != nil {
					if errors.Is(err, ast.ErrNotExist) { // first not found segment
						segmentI, ok := tp.JsonPath[i].(int)
						if ok {
							// an int segment failing means an array index missed
							if !segParent.Exists() {
								panic("impossible")
							} else {
								segArr, err := segParent.Array()
								if err != nil { // not taking elements from array
									return tp.literal, nil
								}

								return "", vo.NewError(errno.ErrArrIndexOutOfRange,
									errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
									errorx.KV("req_index", strconv.Itoa(segmentI)),
									errorx.KV("arr_len", strconv.Itoa(len(segArr))))
							}
						}
						return tp.literal, nil // not array element not found, but object field, just print
					} else if errors.As(err, &syntaxErr) {
						// a syntax error on an int segment indicates indexing a nil array
						segmentI, ok := tp.JsonPath[i].(int)
						if ok {
							return "", vo.NewError(errno.ErrIndexingNilArray,
								errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
								errorx.KV("req_index", strconv.Itoa(segmentI)))
						}
						return tp.literal, nil // not array element not found, but object field, just print
					}
					return tp.literal, nil // not ErrNotExist, just print
				} else {
					segParent = current
				}
			}
		}
		return tp.literal, nil
	}

	// decode with json.Number to avoid float64 precision surprises
	i, err := n.InterfaceUseNumber()
	if err != nil {
		return tp.literal, nil
	}

	if i == nil {
		if options.nilRenderer != nil {
			return options.nilRenderer()
		}
		return "", nil
	}

	// custom renderer takes precedence over the default formatting below
	if len(options.type2CustomRenderer) > 0 {
		rType := reflect.TypeOf(i)
		if fn, ok := options.type2CustomRenderer[rType]; ok {
			return fn(i)
		}
	}

	switch i.(type) {
	case string:
		return i.(string), nil
	case json.Number:
		return i.(json.Number).String(), nil
	case bool:
		return strconv.FormatBool(i.(bool)), nil
	default:
		ms, err := renderConfig.MarshalToString(i) // keep order of the map keys
		if err != nil {
			return "", err
		}
		return ms, nil
	}
}
|
||||
|
||||
// Skipped decides, from the resolved field sources, whether this template
// part should be omitted from the rendered output (skipped) or rendered as
// its raw "{{...}}" literal because it references a non-existing source
// (invalid). With no source information the part is neither.
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
	if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
		return false, false
	}

	// examine along the TemplatePart's root and sub paths,
	// trying to find a matching SourceInfo as far as possible.
	// the result would be one of two cases:
	// - a REAL field source is matched, just check if that field source is skipped
	// - otherwise an INTERMEDIATE field source is matched, it can only be skipped if ALL its sub sources are skipped
	matchingSource, ok := resolvedSources[tp.Root]
	if !ok { // the user specified a non-existing source, it can never have any value, just skip it
		return false, true
	}

	if !matchingSource.IsIntermediate {
		return matchingSource.FieldType == FieldSkipped, false
	}

	for _, subPath := range tp.SubPathsBeforeSlice {
		subSource, ok := matchingSource.SubSources[subPath]
		if !ok { // has gone deeper than the field source
			// NOTE(review): matchingSource is always intermediate here (leaf
			// sources return above / at loop bottom), so this reports invalid
			if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
				return false, true
			}
			return matchingSource.FieldType == FieldSkipped, false
		}

		matchingSource = subSource
	}

	if !matchingSource.IsIntermediate {
		return matchingSource.FieldType == FieldSkipped, false
	}

	// intermediate match: skipped only if every (transitive) sub source is skipped
	var checkSourceSkipped func(sInfo *SourceInfo) bool
	checkSourceSkipped = func(sInfo *SourceInfo) bool {
		if !sInfo.IsIntermediate {
			return sInfo.FieldType == FieldSkipped
		}
		for _, subSource := range sInfo.SubSources {
			if !checkSourceSkipped(subSource) {
				return false
			}
		}
		return true
	}

	return checkSourceSkipped(matchingSource), false
}
|
||||
|
||||
func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
|
||||
if len(tp.SubPathsBeforeSlice) == 0 {
|
||||
return types[tp.Root]
|
||||
}
|
||||
rootType, ok := types[tp.Root]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
currentType := rootType
|
||||
for _, subPath := range tp.SubPathsBeforeSlice {
|
||||
if len(currentType.Properties) == 0 {
|
||||
return nil
|
||||
}
|
||||
subType, ok := currentType.Properties[subPath]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
currentType = subType
|
||||
}
|
||||
return currentType
|
||||
}
|
||||
|
||||
// Render renders template tpl against input. Field sources in sources are
// resolved first: parts whose sources were skipped are omitted, parts that
// reference non-existing sources render as their raw "{{...}}" literal, and
// roots marked reserved (via WithReservedKey) always render from the input.
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
	// parts are rendered against the JSON encoding of the input map
	mi, err := sonic.Marshal(input)
	if err != nil {
		return "", err
	}

	resolvedSources, err := ResolveStreamSources(ctx, sources)
	if err != nil {
		return "", err
	}

	options := &renderOptions{}
	for _, opt := range opts {
		opt(options)
	}

	var sb strings.Builder
	parts := ParseTemplate(tpl)
	for _, part := range parts {
		if !part.IsVariable {
			sb.WriteString(part.Value)
			continue
		}

		// reserved roots bypass the skipped/invalid checks below
		if options.reservedKey != nil {
			if _, ok := options.reservedKey[part.Root]; ok {
				i, err := part.Render(mi, opts...)
				if err != nil {
					return "", err
				}

				sb.WriteString(i)
				continue
			}
		}

		skipped, invalid := part.Skipped(resolvedSources)
		if skipped {
			continue
		}

		if invalid {
			sb.WriteString(part.literal)
			continue
		}

		i, err := part.Render(mi, opts...)
		if err != nil {
			return "", err
		}

		sb.WriteString(i)
	}

	return sb.String(), nil
}
|
||||
@@ -0,0 +1,128 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package textprocessor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// Type selects the text-processor operation.
type Type string

const (
	// ConcatText renders a template, joining array values with ConcatChar.
	ConcatText Type = "concat"
	// SplitText splits an input string by each configured separator.
	SplitText Type = "split"
)

// Config configures a TextProcessor node.
type Config struct {
	// Type selects concat or split mode.
	Type Type `json:"type"`
	// Tpl is the template rendered in concat mode.
	Tpl string `json:"tpl"`
	// ConcatChar joins array elements when rendering in concat mode.
	ConcatChar string `json:"concatChar"`
	// Separators are applied in order when splitting.
	Separators []string `json:"separator"`
	// FullSources describes the field sources feeding the template.
	FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
}
|
||||
|
||||
// TextProcessor is a workflow node that concatenates (template-renders) or
// splits text according to its Config.
type TextProcessor struct {
	config *Config
}
|
||||
|
||||
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config requried")
|
||||
}
|
||||
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
|
||||
return nil, fmt.Errorf("config tpl requried")
|
||||
}
|
||||
|
||||
return &TextProcessor{
|
||||
config: cfg,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// OutputKey is the map key under which Invoke returns its result.
const OutputKey = "output"
|
||||
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
switch t.config.Type {
|
||||
case ConcatText:
|
||||
arrayRenderer := func(i any) (string, error) {
|
||||
vs := i.([]any)
|
||||
return join(vs, t.config.ConcatChar)
|
||||
}
|
||||
|
||||
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
|
||||
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]any{OutputKey: result}, nil
|
||||
case SplitText:
|
||||
value, ok := input["String"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input string requried")
|
||||
}
|
||||
|
||||
valueString, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
|
||||
}
|
||||
values := strings.Split(valueString, t.config.Separators[0])
|
||||
// 对每个分隔符进行迭代处理
|
||||
for _, sep := range t.config.Separators[1:] {
|
||||
var tempParts []string
|
||||
for _, part := range values {
|
||||
tempParts = append(tempParts, strings.Split(part, sep)...)
|
||||
}
|
||||
values = tempParts
|
||||
}
|
||||
anyValues := make([]any, 0, len(values))
|
||||
for _, v := range values {
|
||||
anyValues = append(anyValues, v)
|
||||
}
|
||||
|
||||
return map[string]any{OutputKey: anyValues}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("not support type %s", t.config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func join(vs []any, concatChar string) (string, error) {
|
||||
as := make([]string, 0, len(vs))
|
||||
for _, v := range vs {
|
||||
if v == nil {
|
||||
as = append(as, "")
|
||||
continue
|
||||
}
|
||||
if _, ok := v.(map[string]any); ok {
|
||||
bs, err := sonic.Marshal(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
as = append(as, string(bs))
|
||||
continue
|
||||
}
|
||||
|
||||
as = append(as, fmt.Sprintf("%v", v))
|
||||
}
|
||||
return strings.Join(as, concatChar), nil
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package textprocessor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("split", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Type: SplitText,
|
||||
Separators: []string{",", "|", "."},
|
||||
}
|
||||
p, err := NewTextProcessor(ctx, cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := p.Invoke(ctx, map[string]any{
|
||||
"String": "a,b|c.d,e|f|g",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["output"], []any{"a", "b", "c", "d", "e", "f", "g"})
|
||||
})
|
||||
|
||||
t.Run("concat", func(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"a": []any{"1", map[string]any{
|
||||
"1": 1,
|
||||
}, 3},
|
||||
"b": map[string]any{
|
||||
"b1": []string{"1", "2", "3"},
|
||||
"b2": []any{"1", 2, "3"},
|
||||
},
|
||||
"c": map[string]any{
|
||||
"c1": "1",
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Type: ConcatText,
|
||||
ConcatChar: `\t`,
|
||||
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
|
||||
}
|
||||
p, err := NewTextProcessor(context.Background(), cfg)
|
||||
|
||||
result, err := p.Invoke(ctx, in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
|
||||
})
|
||||
}
|
||||
283
backend/domain/workflow/internal/nodes/utils.go
Normal file
283
backend/domain/workflow/internal/nodes/utils.go
Normal file
@@ -0,0 +1,283 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// TakeMapValue extracts the value for specified path from input map.
|
||||
// Returns false if map key not exist for specified path.
|
||||
func TakeMapValue(m map[string]any, path compose.FieldPath) (any, bool) {
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
container := m
|
||||
for _, p := range path[:len(path)-1] {
|
||||
if _, ok := container[p]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
container = container[p].(map[string]any)
|
||||
}
|
||||
|
||||
if v, ok := container[path[len(path)-1]]; ok {
|
||||
return v, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func SetMapValue(m map[string]any, path compose.FieldPath, v any) {
|
||||
container := m
|
||||
for _, p := range path[:len(path)-1] {
|
||||
if _, ok := container[p]; !ok {
|
||||
container[p] = make(map[string]any)
|
||||
}
|
||||
container = container[p].(map[string]any)
|
||||
}
|
||||
|
||||
container[path[len(path)-1]] = v
|
||||
}
|
||||
|
||||
func TemplateRender(template string, vals map[string]interface{}) (string, error) {
|
||||
sb := strings.Builder{}
|
||||
valsBytes, err := sonic.Marshal(vals)
|
||||
if err != nil {
|
||||
return "", vo.WrapError(errno.ErrSerializationDeserializationFail, err)
|
||||
}
|
||||
parts := ParseTemplate(template)
|
||||
for idx := range parts {
|
||||
part := parts[idx]
|
||||
if !part.IsVariable {
|
||||
sb.WriteString(part.Value)
|
||||
} else {
|
||||
renderString, err := part.Render(valsBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(renderString)
|
||||
}
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// ExtractJSONString strips a markdown code fence ("```...```", optionally
// tagged "json") from content, returning the inner text. Content without a
// fence is returned unchanged.
//
// The length guard fixes a slice-bounds panic in the previous version:
// inputs such as "```" or "````" satisfy both HasPrefix and HasSuffix with
// overlapping characters, making content[3:len-3] an invalid slice.
func ExtractJSONString(content string) string {
	if len(content) >= 6 && strings.HasPrefix(content, "```") && strings.HasSuffix(content, "```") {
		content = content[3 : len(content)-3]
	}

	if strings.HasPrefix(content, "json") {
		content = content[4:]
	}

	return content
}
|
||||
|
||||
// ConcatTwoMaps merges m2 into a clone of m1. Keys present in both maps are
// concatenated: strings are appended (with the KeyIsFinished sentinel
// stripped), nested maps are merged recursively, and other pairs are
// combined through the reflection-based concat helpers.
func ConcatTwoMaps(m1, m2 map[string]any) (map[string]any, error) {
	merged := maps.Clone(m1)
	for k, v := range m2 {
		current, ok := merged[k]
		if !ok || current == nil {
			// new key: take v, unless it is only the finish sentinel
			if vStr, ok := v.(string); ok {
				if vStr == KeyIsFinished {
					continue
				}
			}
			merged[k] = v
			continue
		}

		// both strings: append, dropping a trailing finish sentinel
		vStr, ok1 := v.(string)
		currentStr, ok2 := current.(string)
		if ok1 && ok2 {
			if strings.HasSuffix(vStr, KeyIsFinished) {
				vStr = strings.TrimSuffix(vStr, KeyIsFinished)
			}
			merged[k] = currentStr + vStr
			continue
		}

		// both maps: merge recursively
		vMap, ok1 := v.(map[string]any)
		currentMap, ok2 := current.(map[string]any)
		if ok1 && ok2 {
			concatenated, err := ConcatTwoMaps(currentMap, vMap)
			if err != nil {
				return nil, err
			}

			merged[k] = concatenated
			continue
		}

		// otherwise combine the pair via the reflection-based helpers
		items, err := toSliceValue([]any{current, v})
		if err != nil {
			logs.Errorf("failed to convert to slice value: %v", err)
			return nil, err
		}

		var cv reflect.Value
		if reflect.TypeOf(v).Kind() == reflect.Map {
			cv, err = concatMaps(items)
		} else {
			cv, err = concatSliceValue(items)
		}
		if err != nil {
			return nil, err
		}

		merged[k] = cv.Interface()
	}
	return merged, nil
}
|
||||
|
||||
// the following codes are copied from github.com/cloudwego/eino
|
||||
|
||||
// concatMaps merges a slice (a reflect.Value of kind Slice) of same-typed
// maps into a single map of that type. Values sharing a key across the input
// maps are collected per key and then combined: nested maps recursively via
// concatMaps, anything else via concatSliceValue.
//
// (copied from github.com/cloudwego/eino)
func concatMaps(ms reflect.Value) (reflect.Value, error) {
	typ := ms.Type().Elem()

	// rms accumulates, per key, every value seen across all input maps
	// (as a []any); ret receives the final combined value per key.
	rms := reflect.MakeMap(reflect.MapOf(typ.Key(), reflect.TypeOf((*[]any)(nil)).Elem()))
	ret := reflect.MakeMap(typ)

	n := ms.Len()
	for i := 0; i < n; i++ {
		m := ms.Index(i)

		for _, key := range m.MapKeys() {
			vals := rms.MapIndex(key)
			if !vals.IsValid() {
				// First occurrence of this key: start a fresh collection slice.
				var s []any
				vals = reflect.ValueOf(s)
			}

			val := m.MapIndex(key)
			vals = reflect.Append(vals, val)
			rms.SetMapIndex(key, vals)
		}
	}

	for _, key := range rms.MapKeys() {
		vals := rms.MapIndex(key)

		// Re-type the collected []any into a concretely-typed slice so the
		// appropriate concat path can be chosen from its element kind.
		anyVals := vals.Interface().([]any)
		v, err := toSliceValue(anyVals)
		if err != nil {
			return reflect.Value{}, err
		}

		var cv reflect.Value

		if v.Type().Elem().Kind() == reflect.Map {
			cv, err = concatMaps(v)
		} else {
			cv, err = concatSliceValue(v)
		}

		if err != nil {
			return reflect.Value{}, err
		}

		ret.SetMapIndex(key, cv)
	}

	return ret, nil
}
|
||||
|
||||
func concatSliceValue(val reflect.Value) (reflect.Value, error) {
|
||||
elmType := val.Type().Elem()
|
||||
|
||||
if val.Len() == 1 {
|
||||
return val.Index(0), nil
|
||||
}
|
||||
|
||||
f := GetConcatFunc(elmType)
|
||||
if f != nil {
|
||||
return f(val)
|
||||
}
|
||||
|
||||
// if all elements in the slice are empty, return an empty value
|
||||
// if there is exactly one non-empty element in the slice, return that non-empty element
|
||||
// otherwise, throw an error.
|
||||
var filtered reflect.Value
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
oneVal := val.Index(i)
|
||||
if !oneVal.IsZero() {
|
||||
if filtered.IsValid() {
|
||||
return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType)
|
||||
}
|
||||
|
||||
filtered = oneVal
|
||||
}
|
||||
}
|
||||
if !filtered.IsValid() {
|
||||
filtered = reflect.New(elmType).Elem()
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// toSliceValue converts vs into a typed slice whose element type is taken
// from vs[0]. It fails if any element's dynamic type differs from the
// first's. vs must be non-empty (an empty input would panic on vs[0]).
func toSliceValue(vs []any) (reflect.Value, error) {
	typ := reflect.TypeOf(vs[0])

	ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs))
	ret.Index(0).Set(reflect.ValueOf(vs[0]))

	for i := 1; i < len(vs); i++ {
		v := vs[i]
		vt := reflect.TypeOf(v)
		if typ != vt {
			// Report the mismatching element's type as "Got" and the first
			// element's type as "expected" (the original message had the two
			// format arguments swapped).
			return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", vt, typ)
		}

		ret.Index(i).Set(reflect.ValueOf(v))
	}

	return ret, nil
}
|
||||
|
||||
var (
|
||||
concatFunctions = map[reflect.Type]any{}
|
||||
)
|
||||
|
||||
func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) {
|
||||
concatFunctions[reflect.TypeOf((*T)(nil)).Elem()] = fn
|
||||
}
|
||||
|
||||
func GetConcatFunc(typ reflect.Type) func(reflect.Value) (reflect.Value, error) {
|
||||
if fn, ok := concatFunctions[typ]; ok {
|
||||
return func(a reflect.Value) (reflect.Value, error) {
|
||||
rvs := reflect.ValueOf(fn).Call([]reflect.Value{a})
|
||||
var err error
|
||||
if !rvs[1].IsNil() {
|
||||
err = rvs[1].Interface().(error)
|
||||
}
|
||||
return rvs[0], err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,596 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package variableaggregator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"math"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
// MergeStrategy selects how each group's candidate values are merged into a
// single output value.
type MergeStrategy uint

const (
	// FirstNotNullValue picks the first non-nil value within each group.
	FirstNotNullValue MergeStrategy = 1
)

// Config describes a variable-aggregator node.
type Config struct {
	MergeStrategy MergeStrategy                // currently only FirstNotNullValue is supported
	GroupLen      map[string]int               // group name -> number of candidate values in that group
	FullSources   map[string]*nodes.SourceInfo // per-group source info (stream/skip status per candidate), resolved in Init
	NodeKey       vo.NodeKey
	InputSources  []*vo.FieldInfo
	GroupOrder    []string // the order the groups are declared in frontend canvas
}

// VariableAggregator merges grouped input values according to the configured
// strategy, in both invoke (batch) and transform (stream) modes.
type VariableAggregator struct {
	config *Config
}
|
||||
|
||||
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
|
||||
}
|
||||
return &VariableAggregator{config: cfg}, nil
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
|
||||
in, err := inputConverter(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
groupToChoice := make(map[string]int)
|
||||
for group, length := range v.config.GroupLen {
|
||||
for i := 0; i < length; i++ {
|
||||
if value, ok := in[group][i]; ok {
|
||||
if value != nil {
|
||||
result[group] = value
|
||||
groupToChoice[group] = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := result[group]; !ok {
|
||||
groupToChoice[group] = -1
|
||||
}
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
} else {
|
||||
groupChoices = append(groupChoices, choice)
|
||||
}
|
||||
}
|
||||
|
||||
ctxcache.Store(ctx, groupChoiceCacheKey, groupChoices)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ctx-cache keys shared between Init, Invoke/Transform and the callback
// converters of this node.
const (
	resolvedSourcesCacheKey = "resolved_sources"  // map[string]*nodes.SourceInfo resolved in Init
	groupChoiceTypeCacheKey = "group_choice_type" // map[string]nodes.FieldStreamType: groups whose chosen value is a stream
	groupChoiceCacheKey     = "group_choice"      // []any: chosen index per group in GroupOrder (nil = no choice)
)
|
||||
|
||||
// Transform picks the first non-nil value from each group from a stream of map[group]items.
//
// It first classifies every candidate from the resolved source info (skipped
// vs. stream vs. unknown), then — only if some groups remain undecided —
// reads from a copy of the input stream until every group has a choice.
// The output stream forwards only each group's chosen item; the per-group
// choices and their stream-ness are recorded in workflow state and ctx cache
// for callbacks.OnEnd.
func (v *VariableAggregator) Transform(ctx context.Context, input *schema.StreamReader[map[string]any]) (
	_ *schema.StreamReader[map[string]any], err error) {
	inStream := streamInputConverter(input)

	resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
	if !ok {
		panic("unable to get resolvesSources from ctx cache.")
	}

	// groupToItems[group][i] holds one of: nil (undecided), skipped{},
	// null{} (observed nil payload), stream{} (stream-typed source), or the
	// concrete non-nil payload observed for a non-stream source.
	groupToItems := make(map[string][]any)
	// groupToChoice[group] is the chosen index, or -1 for "no non-nil value".
	groupToChoice := make(map[string]int)
	type skipped struct{}
	type null struct{}
	type stream struct{}

	// On success, publish which chosen values are streams and the ordered
	// choice list so ToCallbackOutput can consume them.
	defer func() {
		if err == nil {
			groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
			for group, choice := range groupToChoice {
				if choice != -1 {
					item := groupToItems[group][choice]
					if _, ok := item.(stream); ok {
						groupChoiceToStreamType[group] = nodes.FieldIsStream
					}
				}
			}

			groupChoices := make([]any, 0, len(v.config.GroupOrder))
			for _, group := range v.config.GroupOrder {
				choice := groupToChoice[group]
				if choice == -1 {
					groupChoices = append(groupChoices, nil)
				} else {
					groupChoices = append(groupChoices, choice)
				}
			}

			// store group -> field type for use in callbacks.OnEnd
			ctxcache.Store(ctx, groupChoiceTypeCacheKey, groupChoiceToStreamType)
			ctxcache.Store(ctx, groupChoiceCacheKey, groupChoices)
		}
	}()

	// goal: find the first non-nil element in each group. 'First' means the smallest index in each group's slice.
	// For a stream element, if the stream source is not skipped, then this stream element is non-nil,
	// even if there's no content in the stream.
	// steps:
	// - for each group, iterate over each element in order, check the element's stream type
	// - if an element is skipped, move on
	// - if an element is stream, pick it
	// - if an element is not stream, actually receive from the stream to check if it's non-nil

	groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
	for group, length := range v.config.GroupLen {
		groupToItems[group] = make([]any, length)
		groupToCurrentIndex[group] = math.MaxInt
		for i := 0; i < length; i++ {
			fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
			if fType == nodes.FieldSkipped {
				groupToItems[group][i] = skipped{}
				continue
			}
			if fType == nodes.FieldIsStream {
				groupToItems[group][i] = stream{}
				if ci, _ := groupToCurrentIndex[group]; i < ci {
					groupToCurrentIndex[group] = i
				}
			}
		}

		// Try to decide the group purely from the static classification.
		hasUndecided := false
		for i := 0; i < length; i++ {
			if groupToItems[group][i] == nil {
				hasUndecided = true
				break
			}

			_, ok := groupToItems[group][i].(stream)
			if ok { // if none of the elements before this one is none-stream, pick this first stream
				groupToChoice[group] = i
				break
			}
		}

		if _, ok := groupToChoice[group]; !ok && !hasUndecided {
			groupToChoice[group] = -1 // all of this group's elements are skipped, won't have any non-nil ones
		}
	}

	// allDone reports whether every group has made its choice.
	allDone := func() bool {
		for group := range v.config.GroupLen {
			_, ok := groupToChoice[group]
			if !ok {
				return false
			}
		}

		return true
	}

	alreadyDone := allDone()
	if alreadyDone { // all groups have made their choices, no need to actually read input streams
		result := make(map[string]any, len(v.config.GroupLen))
		allSkip := true
		for group := range groupToChoice {
			choice := groupToChoice[group]
			if choice == -1 {
				result[group] = nil // all elements of this group are skipped
			} else {
				// NOTE(review): this stores the chosen index, not the value;
				// result is only returned on the allSkip path below (where
				// every entry is nil), so the index is never surfaced —
				// confirm this is intentional.
				result[group] = choice
				allSkip = false
			}
		}

		if allSkip { // no need to convert input streams for the output, because all groups are skipped
			_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
				state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
				return nil
			})
			return schema.StreamReaderFromArray([]map[string]any{result}), nil
		}
	}

	outS := inStream
	if !alreadyDone {
		// Copy the stream: inCopy[0] is consumed here to decide the
		// remaining choices, inCopy[1] feeds the converted output stream.
		inCopy := inStream.Copy(2)
		defer inCopy[0].Close()
		outS = inCopy[1]

	recvLoop:
		for {
			chunk, err := inCopy[0].Recv()
			if err != nil {
				if err == io.EOF {
					panic("EOF reached before making choices for all groups")
				}

				return nil, err
			}

			for group, items := range chunk {
				if _, ok := groupToChoice[group]; ok {
					continue // already made the decision for the group.
				}

				for i := range items {
					if i >= groupToCurrentIndex[group] {
						continue
					}

					existing := groupToItems[group][i]
					if existing != nil { // belongs to a stream element
						continue
					}

					// now the item is always a non-stream element
					item := items[i]
					if item == nil {
						groupToItems[group][i] = null{}
					} else {
						groupToItems[group][i] = item
					}

					groupToCurrentIndex[group] = i

					// The group's choice is final only once every earlier
					// index has been classified.
					finalized := true
					for j := 0; j < i; j++ {
						indexedItem := groupToItems[group][j]
						if indexedItem == nil { // there exists non-finalized elements in front of the current item
							finalized = false
							break
						}
					}

					if finalized {
						if item == nil { // current item is nil, we need to find the first non-nil element in the group
							foundNonNil := false
							hasUndecided := false
							for j := 0; j < len(groupToItems[group]); j++ {
								indexedItem := groupToItems[group][j]
								if indexedItem != nil {
									_, ok := indexedItem.(skipped)
									if ok {
										continue
									}

									_, ok = indexedItem.(null)
									if ok {
										continue
									}

									groupToChoice[group] = j
									foundNonNil = true
									break
								} else {
									hasUndecided = true
									break
								}
							}
							if !foundNonNil && !hasUndecided {
								groupToChoice[group] = -1 // this group does not have any non-nil value
							}
						} else {
							groupToChoice[group] = i
						}
						if allDone() {
							break recvLoop
						}
					}
				}
			}
		}
	}

	_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
		state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
		return nil
	})

	// Forward only each group's chosen item from the remaining stream.
	actualStream := schema.StreamReaderWithConvert(outS, func(in map[string]map[int]any) (map[string]any, error) {
		out := make(map[string]any)
		for group, items := range in {
			choice, ok := groupToChoice[group]
			if !ok {
				panic(fmt.Sprintf("group %s does not have choice", group))
			}

			if choice < 0 {
				panic(fmt.Sprintf("group %s choice = %d, less than zero, but found actual item in stream", group, choice))
			}

			if _, ok := items[choice]; ok {
				out[group] = items[choice]
			}
		}

		if len(out) == 0 {
			return nil, schema.ErrNoValue
		}

		return out, nil
	})

	// Groups with no non-nil value still need an explicit nil entry in the
	// output; merge them in as a one-chunk stream.
	nullGroups := make(map[string]any)
	for group, choice := range groupToChoice {
		if choice < 0 {
			nullGroups[group] = nil
		}
	}
	if len(nullGroups) > 0 {
		nullStream := schema.StreamReaderFromArray([]map[string]any{nullGroups})
		return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{actualStream, nullStream}), nil
	}

	return actualStream, nil
}
|
||||
|
||||
// inputConverter re-keys a map[group]map[stringIndex]value input into
// map[group]map[intIndex]value, converting each inner key to its integer
// form. It fails if a group value is not a map[string]any or an inner key is
// not a valid integer.
func inputConverter(in map[string]any) (converted map[string]map[int]any, err error) {
	converted = make(map[string]map[int]any, len(in))

	for group, value := range in {
		m, ok := value.(map[string]any)
		if !ok {
			return nil, errors.New("value is not a map[string]any")
		}
		indexed := make(map[int]any, len(m))
		for key, sv := range m {
			index, err := strconv.Atoi(key)
			if err != nil {
				// Fixed the stray leading space in the original message and
				// wrap the cause with %w so callers can inspect it.
				return nil, fmt.Errorf("converting %s to int failed, err=%w", key, err)
			}
			indexed[index] = sv
		}
		converted[group] = indexed
	}

	return converted, nil
}
|
||||
|
||||
func streamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
|
||||
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = safego.NewPanicErr(r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
return inputConverter(input)
|
||||
}
|
||||
return schema.StreamReaderWithConvert(in, converter)
|
||||
}
|
||||
|
||||
// vaCallbackInput is the per-group callback payload: the group name plus its
// member variables ordered by index. Stream-typed members are replaced with
// streamMarker before being emitted (see ToCallbackInput).
type vaCallbackInput struct {
	Name      string `json:"name"`
	Variables []any  `json:"variables"`
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// streamMarkerType is a distinct string type used to stand in for stream
// values in callback payloads; the distinct type lets concat logic recognize
// the marker via type assertion.
type streamMarkerType string

// streamMarker is the placeholder emitted instead of actual stream data,
// which is never read, persisted, or displayed in callbacks.
const streamMarker streamMarkerType = "<Stream Data...>"
|
||||
|
||||
// ToCallbackInput shapes the node input for callbacks: it groups the input
// variables by group name (ordered by index within each group), replaces
// stream-typed members with streamMarker, and returns them sorted by group
// name under the "mergeGroups" key.
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
	resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
	if !ok {
		panic("unable to get resolved_sources from ctx cache")
	}

	in, err := inputConverter(input)
	if err != nil {
		return nil, err
	}

	merged := make([]vaCallbackInput, 0, len(in))

	groupLen := v.config.GroupLen

	for groupName, vars := range in {
		// Indexes absent from vars stay nil in the ordered slice.
		orderedVars := make([]any, groupLen[groupName])
		for index := range vars {
			orderedVars[index] = vars[index]
			if len(resolvedSources) > 0 {
				// NOTE(review): assumes resolvedSources has an entry for
				// every group present in the input — confirm against Init.
				if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
					// replace the streams with streamMarker,
					// because we won't read, save to execution history, or display these streams to user
					orderedVars[index] = streamMarker
				}
			}
		}

		merged = append(merged, vaCallbackInput{
			Name:      groupName,
			Variables: orderedVars,
		})
	}

	// Sort merged slice by Name
	sort.Slice(merged, func(i, j int) bool {
		return merged[i].Name < merged[j].Name
	})

	return map[string]any{
		"mergeGroups": merged,
	}, nil
}
|
||||
|
||||
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get dynamic stream types from ctx cache")
|
||||
}
|
||||
|
||||
groupChoices, ok := ctxcache.Get[[]any](ctx, groupChoiceCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get group choices from ctx cache")
|
||||
}
|
||||
|
||||
if len(dynamicStreamType) == 0 {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
Extra: map[string]any{
|
||||
"variable_select": groupChoices,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
newOut := maps.Clone(output)
|
||||
for k := range output {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
|
||||
newOut[k] = streamMarker
|
||||
}
|
||||
}
|
||||
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: newOut,
|
||||
RawOutput: newOut,
|
||||
Extra: map[string]any{
|
||||
"variable_select": groupChoices,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// concatVACallbackInputs merges successive streamed chunks of callback-input
// groups into one slice, keeping groups sorted by name. New groups are
// inserted at their sorted position; for an existing group, each variable
// slot is filled from the first non-nil chunk value, with the KeyIsFinished
// sentinel suffix stripped from strings. A slot may only be written more than
// once when both the existing and incoming values are the stream marker.
func concatVACallbackInputs(vs [][]vaCallbackInput) ([]vaCallbackInput, error) {
	if len(vs) == 0 {
		return nil, nil
	}

	init := slices.Clone(vs[0])
	for i := 1; i < len(vs); i++ {
		next := vs[i]
		for j := 0; j < len(next); j++ {
			oneGroup := next[j]
			groupName := oneGroup.Name
			var (
				existingGroup *vaCallbackInput // copy of the matching group in init, if any
				nextIndex     = len(init)      // sorted insertion point for a new group
				currentIndex  int              // position of the matching group in init
			)
			for k := 0; k < len(init); k++ {
				if init[k].Name == groupName {
					// ptr.Of copies the element; mutations below are written
					// back to init[currentIndex] at the end.
					existingGroup = ptr.Of(init[k])
					currentIndex = k
				} else if init[k].Name > groupName && k < nextIndex {
					nextIndex = k
				}
			}

			if existingGroup == nil {
				// Insert the new group at its sorted position.
				after := slices.Clone(init[nextIndex:])
				init = append(init[:nextIndex], oneGroup)
				init = append(init, after...)
			} else {
				for vi := 0; vi < len(oneGroup.Variables); vi++ {
					newV := oneGroup.Variables[vi]
					if newV == nil {
						// Pad with nils so the slot exists, but leave it unset.
						if vi >= len(existingGroup.Variables) {
							for i := len(existingGroup.Variables); i <= vi; i++ {
								existingGroup.Variables = append(existingGroup.Variables, nil)
							}
						}
						continue
					}
					if newStr, ok := newV.(string); ok {
						if strings.HasSuffix(newStr, nodes.KeyIsFinished) {
							newStr = strings.TrimSuffix(newStr, nodes.KeyIsFinished)
						}
						newV = newStr
					}
					for ei := len(existingGroup.Variables); ei <= vi; ei++ {
						existingGroup.Variables = append(existingGroup.Variables, nil)
					}
					ev := existingGroup.Variables[vi]
					if ev == nil {
						existingGroup.Variables[vi] = oneGroup.Variables[vi]
					} else {
						// A slot written twice is only legal for stream markers.
						if evStr, ok := ev.(streamMarkerType); !ok {
							return nil, fmt.Errorf("multiple stream chunk when concating VACallbackInputs, variable %s is not string", ev)
						} else {
							// NOTE(review): newV.(streamMarkerType) panics if
							// newV is a plain string here rather than a
							// marker — confirm this cannot occur upstream.
							if evStr != streamMarker || newV.(streamMarkerType) != streamMarker {
								return nil, fmt.Errorf("multiple stream chunk when concating VACallbackInputs, variable %s is not streamMarker", ev)
							}
							existingGroup.Variables[vi] = evStr
						}
					}
				}
				init[currentIndex] = *existingGroup
			}
		}
	}

	return init, nil
}
|
||||
|
||||
// concatStreamMarkers collapses any number of stream markers into a single
// marker; the marker carries no payload, so concatenation is trivial.
func concatStreamMarkers(_ []streamMarkerType) (streamMarkerType, error) {
	return streamMarker, nil
}
|
||||
|
||||
// init registers the concat functions needed to merge this node's streamed
// callback chunks (group payloads and stream markers).
func init() {
	nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
	nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
}
|
||||
@@ -0,0 +1,146 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package variableassigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// AppVariables is a concurrency-safe key/value store for application-scoped
// workflow variables.
type AppVariables struct {
	vars map[string]any
	mu   sync.RWMutex
}

// NewAppVariables returns an empty, ready-to-use AppVariables store.
func NewAppVariables() *AppVariables {
	return &AppVariables{vars: map[string]any{}}
}

// Set stores value under key, overwriting any previous value.
func (av *AppVariables) Set(key string, value any) {
	av.mu.Lock()
	defer av.mu.Unlock()
	av.vars[key] = value
}

// Get returns the value stored under key and reports whether it was present.
func (av *AppVariables) Get(key string) (any, bool) {
	av.mu.RLock()
	defer av.mu.RUnlock()
	value, ok := av.vars[key]
	return value, ok
}
|
||||
|
||||
// AppVariableStore abstracts read/write access to application-scoped variable
// values held in workflow state.
type AppVariableStore interface {
	// GetAppVariableValue returns the value for key and whether it exists.
	GetAppVariableValue(key string) (any, bool)
	// SetAppVariableValue stores value under key.
	SetAppVariableValue(key string, value any)
}
|
||||
|
||||
// VariableAssigner writes node input values into writable global (app- or
// user-scoped) variables.
type VariableAssigner struct {
	config *Config
}

// Config holds the assignment pairs and the variable handler used for
// user-scoped writes.
type Config struct {
	Pairs   []*Pair
	Handler *variable.Handler
}

// Pair maps a destination variable reference (Left) to the input field path
// (Right) whose value is assigned to it.
type Pair struct {
	Left  vo.Reference
	Right compose.FieldPath
}
|
||||
|
||||
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
|
||||
for _, pair := range conf.Pairs {
|
||||
if pair.Left.VariableType == nil {
|
||||
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
|
||||
}
|
||||
|
||||
if *pair.Left.VariableType == vo.GlobalSystem {
|
||||
return nil, fmt.Errorf("cannot assign to global system variables in VariableAssigner because they are read-only, ref: %v", pair.Left)
|
||||
}
|
||||
|
||||
vType := *pair.Left.VariableType
|
||||
if vType != vo.GlobalAPP && vType != vo.GlobalUser {
|
||||
return nil, fmt.Errorf("cannot assign to variable type %s in VariableAssigner", vType)
|
||||
}
|
||||
}
|
||||
|
||||
return &VariableAssigner{
|
||||
config: conf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Assign writes each configured input field value into its destination
// variable: app-scoped variables go through the workflow state's
// AppVariableStore (top-level keys only), user-scoped variables go through
// the variable handler (with store info from the execution context when
// available). Returns {"isSuccess": true} when every assignment succeeds.
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
	for _, pair := range v.config.Pairs {
		right, ok := nodes.TakeMapValue(in, pair.Right)
		if !ok {
			return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
		}

		// VariableType is validated non-nil and APP/User in NewVariableAssigner.
		vType := *pair.Left.VariableType
		switch vType {
		case vo.GlobalAPP:
			err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error {
				if len(pair.Left.FromPath) != 1 {
					return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
				}
				appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right)
				return nil
			})
			if err != nil {
				return nil, err
			}
		case vo.GlobalUser:
			opts := make([]variable.OptionFn, 0, 1)
			if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
				exeCfg := exeCtx.RootCtx.ExeCfg
				opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
					AgentID:      exeCfg.AgentID,
					AppID:        exeCfg.AppID,
					ConnectorID:  exeCfg.ConnectorID,
					ConnectorUID: exeCfg.ConnectorUID,
				}))
			}
			err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
			if err != nil {
				return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
			}
		default:
			// Unreachable: constructor rejects every other variable type.
			panic("impossible")
		}
	}

	// TODO if not error considered successful
	return map[string]any{
		"isSuccess": true,
	}, nil
}
|
||||
@@ -0,0 +1,58 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package variableassigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
// InLoop assigns values to intermediate variables owned by the parent
// composite (loop/batch) node.
type InLoop struct {
	config               *Config
	intermediateVarStore variable.Store
}
|
||||
|
||||
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
|
||||
return &InLoop{
|
||||
config: conf,
|
||||
intermediateVarStore: &nodes.ParentIntermediateStore{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
for _, pair := range v.config.Pairs {
|
||||
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
|
||||
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, pair.Right)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to extract right value for path %s", pair.Right)
|
||||
}
|
||||
|
||||
err := v.intermediateVarStore.Set(ctx, pair.Left.FromPath, right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package variableassigner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
// TestVariableAssigner verifies that InLoop.Assign writes each input field
// into the corresponding parent-intermediate variable for int, string,
// object, and array values.
func TestVariableAssigner(t *testing.T) {
	// Intermediate variables live behind *any pointers so the store can
	// mutate them in place.
	intVar := any(1)
	strVar := any("str")
	objVar := any(map[string]any{
		"key": "value",
	})
	arrVar := any([]any{1, "2"})

	va := &InLoop{
		config: &Config{
			Pairs: []*Pair{
				{
					Left: vo.Reference{
						FromPath:     compose.FieldPath{"int_var_s"},
						VariableType: ptr.Of(vo.ParentIntermediate),
					},
					Right: compose.FieldPath{"int_var_t"},
				},
				{
					Left: vo.Reference{
						FromPath:     compose.FieldPath{"str_var_s"},
						VariableType: ptr.Of(vo.ParentIntermediate),
					},
					Right: compose.FieldPath{"str_var_t"},
				},
				{
					Left: vo.Reference{
						FromPath:     compose.FieldPath{"obj_var_s"},
						VariableType: ptr.Of(vo.ParentIntermediate),
					},
					Right: compose.FieldPath{"obj_var_t"},
				},
				{
					Left: vo.Reference{
						FromPath:     compose.FieldPath{"arr_var_s"},
						VariableType: ptr.Of(vo.ParentIntermediate),
					},
					Right: compose.FieldPath{"arr_var_t"},
				},
			},
		},
		intermediateVarStore: &nodes.ParentIntermediateStore{},
	}

	// Seed the ctx with the destination variables.
	ctx := nodes.InitIntermediateVars(context.Background(), map[string]*any{
		"int_var_s": &intVar,
		"str_var_s": &strVar,
		"obj_var_s": &objVar,
		"arr_var_s": &arrVar,
	}, nil)

	_, err := va.Assign(ctx, map[string]any{
		"int_var_t": 2,
		"str_var_t": "str2",
		"obj_var_t": map[string]any{
			"key2": "value2",
		},
		"arr_var_t": []any{3, "4"},
	})
	assert.NoError(t, err)

	// Each destination variable must now hold the assigned input value.
	assert.Equal(t, 2, intVar)
	assert.Equal(t, "str2", strVar)
	assert.Equal(t, map[string]any{
		"key2": "value2",
	}, objVar)
	assert.Equal(t, []any{3, "4"}, arrVar)
}
|
||||
Reference in New Issue
Block a user