feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,447 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package batch
import (
"context"
"errors"
"fmt"
"math"
"reflect"
"slices"
"sync"
"github.com/cloudwego/eino/compose"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type Batch struct {
config *Config
outputs map[string]*vo.FieldSource
}
type Config struct {
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
InputArrays []string `json:"input_arrays"`
Outputs []*vo.FieldInfo `json:"outputs"`
}
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
if config == nil {
return nil, errors.New("config is required")
}
if len(config.InputArrays) == 0 {
return nil, errors.New("need to have at least one incoming array for batch")
}
if len(config.Outputs) == 0 {
return nil, errors.New("need to have at least one output variable for batch")
}
b := &Batch{
config: config,
outputs: make(map[string]*vo.FieldSource),
}
for i := range config.Outputs {
source := config.Outputs[i]
path := source.Path
if len(path) != 1 {
return nil, fmt.Errorf("invalid path %q", path)
}
b.outputs[path[0]] = &source.Source
}
return b, nil
}
const (
MaxBatchSizeKey = "batchSize"
ConcurrentSizeKey = "concurrentSize"
)
func (b *Batch) initOutput(length int) map[string]any {
out := make(map[string]any, len(b.outputs))
for key := range b.outputs {
sliceType := reflect.TypeOf([]any{})
slice := reflect.New(sliceType).Elem()
slice.Set(reflect.MakeSlice(sliceType, length, length))
out[key] = slice.Interface()
}
return out
}
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
out map[string]any, err error) {
arrays := make(map[string]any, len(b.config.InputArrays))
minLen := math.MaxInt64
for _, arrayKey := range b.config.InputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
}
if reflect.TypeOf(a).Kind() != reflect.Slice {
return nil, fmt.Errorf("incoming array not a slice: %s. Actual type: %v",
arrayKey, reflect.TypeOf(a))
}
arrays[arrayKey] = a
oneLen := reflect.ValueOf(a).Len()
if oneLen < minLen {
minLen = oneLen
}
}
var maxIter, concurrency int64
maxIterAny, ok := nodes.TakeMapValue(in, compose.FieldPath{MaxBatchSizeKey})
if !ok {
return nil, fmt.Errorf("incoming max iteration not present in input: %s", in)
}
maxIter = maxIterAny.(int64)
if maxIter == 0 {
maxIter = 100
}
concurrencyAny, ok := nodes.TakeMapValue(in, compose.FieldPath{ConcurrentSizeKey})
if !ok {
return nil, fmt.Errorf("incoming concurrency not present in input: %s", in)
}
concurrency = concurrencyAny.(int64)
if concurrency == 0 {
concurrency = 10
}
if minLen > int(maxIter) {
minLen = int(maxIter)
}
output := b.initOutput(minLen)
if minLen == 0 {
return output, nil
}
getIthInput := func(i int) (map[string]any, map[string]any, error) {
input := make(map[string]any)
for k, v := range in { // carry over other values
if k != MaxBatchSizeKey && k != ConcurrentSizeKey {
input[k] = v
}
}
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey, array := range arrays {
ele := reflect.ValueOf(array).Index(i).Interface()
items[arrayKey] = []any{ele}
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
// Recursively expand map[string]any elements
if m, ok := ele.(map[string]any); ok {
var expand func(prefix string, val interface{})
expand = func(prefix string, val interface{}) {
if nestedMap, ok := val.(map[string]any); ok {
for k, v := range nestedMap {
expand(prefix+"#"+k, v)
}
} else {
input[prefix] = val
}
}
expand(currentKey, m)
} else {
input[currentKey] = ele
}
}
return input, items, nil
}
setIthOutput := func(i int, taskOutput map[string]any) error {
for k, source := range b.outputs {
fromValue, _ := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)},
source.Ref.FromPath...))
toArray, ok := nodes.TakeMapValue(output, compose.FieldPath{k})
if !ok {
return fmt.Errorf("key not present in outer workflow's output: %s", k)
}
toArray.([]any)[i] = fromValue
}
return nil
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
var existingCState *nodes.NestedWorkflowState
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
if e != nil {
return e
}
return nil
})
if err != nil {
return nil, err
}
if existingCState != nil {
output = existingCState.FullOutput
}
ctx, cancelFn := context.WithCancelCause(ctx)
var (
wg sync.WaitGroup
mu sync.Mutex
index2Done = map[int]bool{}
index2InterruptInfo = map[int]*compose.InterruptInfo{}
resumed = map[int]bool{}
)
ithTask := func(i int) {
defer wg.Done()
if existingCState != nil {
if existingCState.Index2Done[i] == true {
return
}
if existingCState.Index2InterruptInfo[i] != nil {
if len(options.GetResumeIndexes()) > 0 {
if _, ok := options.GetResumeIndexes()[i]; !ok {
// previously interrupted, but not resumed this time, skip
return
}
}
}
mu.Lock()
resumed[i] = true
mu.Unlock()
}
select {
case <-ctx.Done():
return // canceled by normal error, abort
default:
}
mu.Lock()
if len(index2InterruptInfo) > 0 { // already has interrupted index, abort
mu.Unlock()
return
}
mu.Unlock()
input, items, err := getIthInput(i)
if err != nil {
cancelFn(err)
return
}
subCtx, subCheckpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)
ithOpts := slices.Clone(options.GetOptsForNested())
mu.Lock()
ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
mu.Unlock()
if subCheckpointID != "" {
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
i, b.config.BatchNodeKey, subCheckpointID)
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
}
mu.Lock()
if len(options.GetResumeIndexes()) > 0 {
stateModifier, ok := options.GetResumeIndexes()[i]
mu.Unlock()
if ok {
fmt.Println("has state modifier for ith run: ", i, ", checkpointID: ", subCheckpointID)
ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
}
} else {
mu.Unlock()
}
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
// the output then needs to be concatenated.
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
cancelFn(err)
return
}
mu.Lock()
index2InterruptInfo[i] = info
mu.Unlock()
return
}
if err = setIthOutput(i, taskOutput); err != nil {
cancelFn(err)
return
}
mu.Lock()
index2Done[i] = true
mu.Unlock()
}
wg.Add(minLen)
if minLen < int(concurrency) {
for i := 1; i < minLen; i++ {
go ithTask(i)
}
ithTask(0)
} else {
taskChan := make(chan int, concurrency)
for i := 0; i < int(concurrency); i++ {
safego.Go(ctx, func() {
for i := range taskChan {
ithTask(i)
}
})
}
for i := 0; i < minLen; i++ {
taskChan <- i
}
close(taskChan)
}
wg.Wait()
if context.Cause(ctx) != nil {
if errors.Is(context.Cause(ctx), context.Canceled) {
return nil, context.Canceled // canceled by Eino workflow engine
}
return nil, context.Cause(ctx) // normal error, just throw it out
}
// delete the interruptions that have been resumed
for index := range resumed {
delete(existingCState.Index2InterruptInfo, index)
}
compState := existingCState
if compState == nil {
compState = &nodes.NestedWorkflowState{
Index2Done: index2Done,
Index2InterruptInfo: index2InterruptInfo,
FullOutput: output,
}
} else {
for i := range index2Done {
compState.Index2Done[i] = index2Done[i]
}
for i := range index2InterruptInfo {
compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
}
compState.FullOutput = output
}
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
return e
}
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
})
if err != nil {
return nil, err
}
fmt.Println("save interruptEvent in state within batch: ", iEvent)
fmt.Println("save composite info in state within batch: ", compState)
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
return e
}
if existingCState == nil {
return nil
}
// although this invocation does not have new interruptions,
// this batch node previously have interrupts yet to be resumed.
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: existingCState.Index2InterruptInfo,
})
})
if err != nil {
return nil, err
}
fmt.Println("save composite info in state within batch: ", compState)
}
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
}
return output, nil
}
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(b.config.InputArrays))
for _, arrayKey := range b.config.InputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}
}
return trimmed, nil
}

View File

@@ -0,0 +1,26 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
type StructuredCallbackOutput struct {
Output map[string]any
RawOutput map[string]any
Extra map[string]any // node specific extra info, will go into node execution's extra.ResponseExtra
Error vo.WorkflowError
}

View File

@@ -0,0 +1,272 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package code
import (
"context"
"errors"
"fmt"
"regexp"
"strings"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
const (
coderRunnerRawOutputCtxKey = "ctx_raw_output"
coderRunnerWarnErrorLevelCtxKey = "ctx_warn_error_level"
)
var (
importRegex = regexp.MustCompile(`^\s*import\s+([a-zA-Z0-9_.,\s]+)`)
fromImportRegex = regexp.MustCompile(`^\s*from\s+([a-zA-Z0-9_.]+)\s+import`)
)
// pythonBuiltinModules is the list of python built-in modules,
// see: https://docs.python.org/3.9/library/
var pythonBuiltinModules = map[string]struct{}{
"abc": {}, "aifc": {}, "antigravity": {}, "argparse": {}, "ast": {}, "asynchat": {}, "asyncio": {}, "asyncore": {}, "array": {},
"atexit": {}, "base64": {}, "bdb": {}, "binhex": {}, "bisect": {}, "builtins": {}, "bz2": {}, "cProfile": {}, "binascii": {},
"calendar": {}, "cgi": {}, "cgitb": {}, "chunk": {}, "cmd": {}, "code": {}, "codecs": {}, "codeop": {}, "cmath": {}, "audioop": {},
"collections": {}, "colorsys": {}, "compileall": {}, "concurrent": {}, "configparser": {}, "contextlib": {}, "contextvars": {}, "copy": {},
"copyreg": {}, "crypt": {}, "csv": {}, "ctypes": {}, "curses": {}, "dataclasses": {}, "datetime": {}, "dbm": {}, "fcntl": {},
"decimal": {}, "difflib": {}, "dis": {}, "distutils": {}, "doctest": {}, "email": {}, "encodings": {}, "ensurepip": {}, "ossaudiodev": {},
"enum": {}, "errno": {}, "faulthandler": {}, "filecmp": {}, "fileinput": {}, "fnmatch": {}, "formatter": {}, "fractions": {},
"ftplib": {}, "functools": {}, "gc": {}, "genericpath": {}, "getopt": {}, "getpass": {}, "gettext": {}, "glob": {}, "grp": {},
"graphlib": {}, "gzip": {}, "hashlib": {}, "heapq": {}, "hmac": {}, "html": {}, "http": {}, "imaplib": {}, "msvcrt": {},
"imghdr": {}, "imp": {}, "importlib": {}, "inspect": {}, "io": {}, "ipaddress": {}, "itertools": {}, "json": {}, "mmap": {},
"keyword": {}, "lib2to3": {}, "linecache": {}, "locale": {}, "logging": {}, "lzma": {}, "mailbox": {}, "mailcap": {}, "msilib": {},
"marshal": {}, "math": {}, "mimetypes": {}, "modulefinder": {}, "multiprocessing": {}, "netrc": {}, "nntplib": {}, "ntpath": {},
"nturl2path": {}, "numbers": {}, "opcode": {}, "operator": {}, "optparse": {}, "os": {}, "pathlib": {}, "pdb": {}, "readline": {},
"pickle": {}, "pickletools": {}, "pipes": {}, "pkgutil": {}, "platform": {}, "plistlib": {}, "poplib": {}, "posix": {}, "parser": {},
"posixpath": {}, "pprint": {}, "profile": {}, "pstats": {}, "pty": {}, "pwd": {}, "py_compile": {}, "pyclbr": {}, "spwd": {},
"pydoc": {}, "pydoc_data": {}, "queue": {}, "quopri": {}, "random": {}, "re": {}, "reprlib": {}, "rlcompleter": {}, "resource": {},
"runpy": {}, "sched": {}, "secrets": {}, "selectors": {}, "shelve": {}, "shlex": {}, "shutil": {}, "signal": {}, "select": {},
"site": {}, "smtpd": {}, "smtplib": {}, "sndhdr": {}, "socket": {}, "socketserver": {}, "sqlite3": {}, "sre_compile": {},
"sre_constants": {}, "sre_parse": {}, "ssl": {}, "stat": {}, "statistics": {}, "string": {}, "stringprep": {}, "struct": {},
"subprocess": {}, "sunau": {}, "symbol": {}, "symtable": {}, "sys": {}, "sysconfig": {}, "tabnanny": {}, "tarfile": {}, "nis": {},
"telnetlib": {}, "tempfile": {}, "textwrap": {}, "this": {}, "threading": {}, "time": {}, "timeit": {}, "tkinter": {}, "test": {},
"token": {}, "tokenize": {}, "trace": {}, "traceback": {}, "tracemalloc": {}, "tty": {}, "turtle": {}, "turtledemo": {},
"types": {}, "typing": {}, "unittest": {}, "urllib": {}, "uu": {}, "uuid": {}, "venv": {}, "warnings": {}, "termios": {},
"wave": {}, "weakref": {}, "webbrowser": {}, "wsgiref": {}, "xdrlib": {}, "xml": {}, "xmlrpc": {}, "xxsubtype": {}, "zlib": {},
"zipapp": {}, "zipfile": {}, "zipimport": {}, "zoneinfo": {}, "winreg": {}, "syslog": {}, "winsound": {}, "unicodedata": {},
}
// pythonBuiltinBlacklist is the blacklist of python built-in modules,
// see: https://www.coze.cn/open/docs/guides/code_node#7f41f073
var pythonBuiltinBlacklist = map[string]struct{}{
"curses": {},
"dbm": {},
"ensurepip": {},
"fcntl": {},
"grp": {},
"idlelib": {},
"lib2to3": {},
"msvcrt": {},
"pwd": {},
"resource": {},
"syslog": {},
"termios": {},
"tkinter": {},
"turtle": {},
"turtledemo": {},
"venv": {},
"winreg": {},
"winsound": {},
"multiprocessing": {},
"threading": {},
"socket": {},
"pty": {},
"tty": {},
}
// pythonThirdPartyWhitelist is the whitelist of python third-party modules,
// see: https://www.coze.cn/open/docs/guides/code_node#7f41f073
// If you want to use other third-party libraries, you can add them to this whitelist.
// And you also need to install them in `/scripts/setup/python.sh` and `/backend/Dockerfile` via `pip install`.
var pythonThirdPartyWhitelist = map[string]struct{}{
"requests_async": {},
"numpy": {},
}
type Config struct {
Code string
Language code.Language
OutputConfig map[string]*vo.TypeInfo
Runner code.Runner
}
type CodeRunner struct {
config *Config
importError error
}
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if cfg.Language == "" {
return nil, errors.New("language is required")
}
if cfg.Code == "" {
return nil, errors.New("code is required")
}
if cfg.Language != code.Python {
return nil, errors.New("only support python language")
}
if len(cfg.OutputConfig) == 0 {
return nil, errors.New("output config is required")
}
if cfg.Runner == nil {
return nil, errors.New("run coder is required")
}
importErr := validatePythonImports(cfg.Code)
return &CodeRunner{
config: cfg,
importError: importErr,
}, nil
}
func validatePythonImports(code string) error {
imports := parsePythonImports(code)
importErrors := make([]string, 0)
var blacklistedModules []string
var nonWhitelistedModules []string
for _, imp := range imports {
if _, ok := pythonBuiltinModules[imp]; ok {
if _, blacklisted := pythonBuiltinBlacklist[imp]; blacklisted {
blacklistedModules = append(blacklistedModules, imp)
}
} else {
if _, whitelisted := pythonThirdPartyWhitelist[imp]; !whitelisted {
nonWhitelistedModules = append(nonWhitelistedModules, imp)
}
}
}
if len(blacklistedModules) > 0 {
moduleNames := fmt.Sprintf("'%s'", strings.Join(blacklistedModules, "', '"))
importErrors = append(importErrors, fmt.Sprintf("ModuleNotFoundError: The module(s) %s are removed from the Python standard library for security reasons\n", moduleNames))
}
if len(nonWhitelistedModules) > 0 {
moduleNames := fmt.Sprintf("'%s'", strings.Join(nonWhitelistedModules, "', '"))
importErrors = append(importErrors, fmt.Sprintf("ModuleNotFoundError: No module named %s\n", moduleNames))
}
if len(importErrors) > 0 {
return errors.New(strings.Join(importErrors, ","))
}
return nil
}
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
if c.importError != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
}
response, err := c.config.Runner.Run(ctx, &code.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
if err != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
result := response.Result
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
if ws != nil && len(*ws) > 0 {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
ctxcache.Store(ctx, coderRunnerWarnErrorLevelCtxKey, *ws)
}
return output, nil
}
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
if !ok {
return nil, errors.New("raw output config is required")
}
var wfe vo.WorkflowError
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, coderRunnerWarnErrorLevelCtxKey); ok {
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
}
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: rawOutput,
Error: wfe,
},
nil
}
func parsePythonImports(code string) []string {
modules := make(map[string]struct{})
lines := strings.Split(code, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "#") {
continue
}
if matches := importRegex.FindStringSubmatch(line); len(matches) > 1 {
importedItemsStr := matches[1]
importedItems := strings.Split(importedItemsStr, ",")
for _, item := range importedItems {
item = strings.TrimSpace(item)
parts := strings.Split(item, " ")
if len(parts) > 0 {
moduleName := parts[0]
topLevelModule := strings.Split(moduleName, ".")[0]
modules[topLevelModule] = struct{}{}
}
}
continue
}
if matches := fromImportRegex.FindStringSubmatch(line); len(matches) > 1 {
fullModuleName := matches[1]
parts := strings.Split(fullModuleName, ".")
if len(parts) > 0 {
modules[parts[0]] = struct{}{}
}
}
}
return maps.Keys(modules)
}

View File

@@ -0,0 +1,262 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package code
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
mockcode "github.com/coze-dev/coze-studio/backend/internal/mock/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
)
var codeTpl string
func TestCode_RunCode(t *testing.T) {
ctrl := gomock.NewController(t)
mockRunner := mockcode.NewMockRunner(ctrl)
t.Run("normal", func(t *testing.T) {
var codeTpl = `
async def main(args:Args)->Output:
params = args.params
ret: Output = {
"key0": params['input'] + params['input'],
"key1": ["hello", "world"],
"key2": [123, "345"],
"key3": {
"key31": "hi",
"key32": "hello",
"key33": ["123","456"],
"key34": {
"key341":"123",
"key342":456,
}
},
}
return ret
`
ret := map[string]any{
"key0": int64(11231123),
"key1": []any{"hello", "world"},
"key2": []interface{}{int64(123), "345"},
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key33": []any{"123", "456"}, "key34": map[string]interface{}{"key341": "123", "key342": int64(456)}},
"key4": []any{
map[string]any{"key41": "41"},
map[string]any{"key42": "42"},
},
}
response := &code.RunResponse{
Result: ret,
}
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Language: code.Python,
Code: codeTpl,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
},
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
},
Runner: mockRunner,
},
}
ret, err := c.RunCode(ctx, map[string]any{
"input": "1123",
})
bs, _ := json.Marshal(ret)
fmt.Println(string(bs))
assert.NoError(t, err)
assert.Equal(t, int64(11231123), ret["key0"])
assert.Equal(t, []any{"hello", "world"}, ret["key1"])
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
assert.Equal(t, []any{float64(123), float64(456)}, ret["key3"].(map[string]any)["key33"])
assert.Equal(t, map[string]any{"key41": "41"}, ret["key4"].([]any)[0].(map[string]any))
})
t.Run("field not in return", func(t *testing.T) {
codeTpl = `
async def main(args:Args)->Output:
params = args.params
ret: Output = {
"key0": params['input'] + params['input'],
"key1": ["hello", "world"],
"key2": [123, "345"],
"key3": {
"key31": "hi",
"key32": "hello",
"key34": {
"key341":"123"
}
},
}
return ret
`
ret := map[string]any{
"key0": int64(11231123),
"key1": []any{"hello", "world"},
"key2": []interface{}{int64(123), "345"},
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key34": map[string]interface{}{"key341": "123"}},
}
response := &code.RunResponse{
Result: ret,
}
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: code.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
}},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
},
}},
},
},
Runner: mockRunner,
},
}
ret, err := c.RunCode(ctx, map[string]any{
"input": "1123",
})
assert.NoError(t, err)
assert.Equal(t, int64(11231123), ret["key0"])
assert.Equal(t, []any{"hello", "world"}, ret["key1"])
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
assert.Equal(t, nil, ret["key4"])
assert.Equal(t, nil, ret["key3"].(map[string]any)["key33"])
})
t.Run("field convert failed", func(t *testing.T) {
codeTpl = `
async def main(args:Args)->Output:
params = args.params
ret: Output = {
"key0": params['input'] + params['input'],
"key1": ["hello", "world"],
"key2": [123, "345"],
"key3": {
"key31": "hi",
"key32": "hello",
"key34": {
"key341":"123",
"key343": ["hello", "world"],
}
},
}
return ret
`
ctx := t.Context()
ctx = ctxcache.Init(ctx)
ret := map[string]any{
"key0": int64(11231123),
"key1": []any{"hello", "world"},
"key2": []interface{}{int64(123), "345"},
"key3": map[string]interface{}{"key31": "hi", "key32": "hello", "key34": map[string]interface{}{"key341": "123", "key343": []any{"hello", "world"}}},
}
response := &code.RunResponse{
Result: ret,
}
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: code.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
}},
},
},
},
Runner: mockRunner,
},
}
ret, err := c.RunCode(ctx, map[string]any{
"input": "1123",
})
assert.NoError(t, err)
assert.NoError(t, err)
assert.Equal(t, int64(11231123), ret["key0"])
assert.Equal(t, []any{float64(123), float64(345)}, ret["key2"])
warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, coderRunnerWarnErrorLevelCtxKey)
assert.True(t, ok)
s := warnings.Error()
assert.Contains(t, s, "field key3.key34.key343.0 is not number")
assert.Contains(t, s, "field key3.key34.key343.1 is not number")
assert.Contains(t, s, "field key1.0 is not number")
assert.Contains(t, s, "field key1.1 is not number")
})
}

View File

@@ -0,0 +1,64 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type ClearMessageConfig struct {
Clearer conversation.ConversationManager
}
type MessageClear struct {
config *ClearMessageConfig
}
func NewClearMessage(ctx context.Context, cfg *ClearMessageConfig) (*MessageClear, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Clearer == nil {
return nil, errors.New("clearer is required")
}
return &MessageClear{
config: cfg,
}, nil
}
func (c *MessageClear) Clear(ctx context.Context, input map[string]any) (map[string]any, error) {
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("input map should contains 'ConversationName' key ")
}
response, err := c.config.Clearer.ClearMessage(ctx, &conversation.ClearMessageRequest{
Name: name.(string),
})
if err != nil {
return nil, err
}
return map[string]any{
"isSuccess": response.IsSuccess,
}, nil
}

View File

@@ -0,0 +1,62 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type CreateConversationConfig struct {
Creator conversation.ConversationManager
}
type CreateConversation struct {
config *CreateConversationConfig
}
func NewCreateConversation(ctx context.Context, cfg *CreateConversationConfig) (*CreateConversation, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Creator == nil {
return nil, errors.New("creator is required")
}
return &CreateConversation{
config: cfg,
}, nil
}
func (c *CreateConversation) Create(ctx context.Context, input map[string]any) (map[string]any, error) {
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("input map should contains 'ConversationName' key ")
}
response, err := c.config.Creator.CreateConversation(ctx, &conversation.CreateConversationRequest{
Name: name.(string),
})
if err != nil {
return nil, err
}
return response.Result, nil
}

View File

@@ -0,0 +1,108 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"encoding/json"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type MessageListConfig struct {
Lister conversation.ConversationManager
}
type MessageList struct {
config *MessageListConfig
}
type Param struct {
ConversationName string
Limit *int
BeforeID *string
AfterID *string
}
func NewMessageList(ctx context.Context, cfg *MessageListConfig) (*MessageList, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.Lister == nil {
return nil, errors.New("lister is required")
}
return &MessageList{
config: cfg,
}, nil
}
func (m *MessageList) List(ctx context.Context, input map[string]any) (map[string]any, error) {
param := &Param{}
name, ok := nodes.TakeMapValue(input, compose.FieldPath{"ConversationName"})
if !ok {
return nil, errors.New("ConversationName is required")
}
param.ConversationName = name.(string)
limit, ok := nodes.TakeMapValue(input, compose.FieldPath{"Limit"})
if ok {
limit := limit.(int)
param.Limit = &limit
}
beforeID, ok := nodes.TakeMapValue(input, compose.FieldPath{"BeforeID"})
if ok {
beforeID := beforeID.(string)
param.BeforeID = &beforeID
}
afterID, ok := nodes.TakeMapValue(input, compose.FieldPath{"AfterID"})
if ok {
afterID := afterID.(string)
param.BeforeID = &afterID
}
r, err := m.config.Lister.MessageList(ctx, &conversation.ListMessageRequest{
ConversationName: param.ConversationName,
Limit: param.Limit,
BeforeID: param.BeforeID,
AfterID: param.AfterID,
})
if err != nil {
return nil, err
}
result := make(map[string]any)
objects := make([]any, 0, len(r.Messages))
for _, msg := range r.Messages {
object := make(map[string]any)
bs, _ := json.Marshal(msg)
err := json.Unmarshal(bs, &object)
if err != nil {
return nil, err
}
objects = append(objects, object)
}
result["messageList"] = objects
result["firstId"] = r.FirstID
result["hasMore"] = r.HasMore
return result, nil
}

View File

@@ -0,0 +1,405 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type ConversionWarning struct {
Path string
Type vo.DataType
Err error
}
func (e *ConversionWarning) Error() string {
return fmt.Sprintf("field %s is not %s", e.Path, e.Type)
}
type ConversionWarnings []*ConversionWarning
func (e ConversionWarnings) Merge(e1 ConversionWarnings) ConversionWarnings {
return append(e, e1...)
}
func (e ConversionWarnings) Error() string {
if len(e) == 0 {
return ""
}
var errs []string
for _, err := range e {
errs = append(errs, err.Error())
}
return strings.Join(errs, ", ")
}
func newWarnings(path string, t vo.DataType, err error) *ConversionWarnings {
return ptr.Of(ConversionWarnings{
{
Path: path,
Type: t,
Err: err,
},
})
}
func ConvertInputs(ctx context.Context, in map[string]any, tInfo map[string]*vo.TypeInfo, opts ...ConvertOption) (
map[string]any, *ConversionWarnings, error) {
options := &convertOptions{}
for _, opt := range opts {
opt(options)
}
if len(in) == 0 {
if !options.skipRequireCheck {
for n, t := range tInfo {
if t.Required {
return nil, nil, vo.NewError(errno.ErrMissingRequiredParam, errorx.KV("param", n))
}
}
}
return in, nil, nil
}
out := make(map[string]any)
var warnings ConversionWarnings
for k, v := range in {
t, ok := tInfo[k]
if !ok {
// for input fields not explicitly defined, just pass the string value through
logs.CtxWarnf(ctx, "input %s not found in type info", k)
if !options.skipUnknownFields {
out[k] = in[k]
}
continue
}
converted, ws, err := Convert(ctx, v, k, t, opts...)
if err != nil {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
}
if ws != nil {
warnings = append(warnings, *ws...)
}
out[k] = converted
}
if !options.skipRequireCheck {
for k, t := range tInfo {
if _, ok := out[k]; !ok {
if t.Required {
return nil, nil, vo.NewError(errno.ErrMissingRequiredParam, errorx.KV("param", k))
}
}
}
}
if len(warnings) > 0 {
return out, &warnings, nil
}
return out, nil, nil
}
type convertOptions struct {
skipUnknownFields bool
failFast bool
skipRequireCheck bool
}
type ConvertOption func(*convertOptions)
func SkipUnknownFields() ConvertOption {
return func(o *convertOptions) {
o.skipUnknownFields = true
}
}
func FailFast() ConvertOption {
return func(o *convertOptions) {
o.failFast = true
}
}
func SkipRequireCheck() ConvertOption {
return func(o *convertOptions) {
o.skipRequireCheck = true
}
}
func Convert(ctx context.Context, in any, path string, t *vo.TypeInfo, opts ...ConvertOption) (
any, *ConversionWarnings, error) {
options := &convertOptions{}
for _, opt := range opts {
opt(options)
}
return convert(ctx, in, path, t, options)
}
func convert(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
any, *ConversionWarnings, error) {
if in == nil { // nil is valid for ALL types
return nil, nil, nil
}
switch t.Type {
case vo.DataTypeString, vo.DataTypeFile, vo.DataTypeTime:
return convertToString(ctx, in, path, options)
case vo.DataTypeInteger:
return convertToInt64(ctx, in, path, options)
case vo.DataTypeNumber:
return convertToFloat64(ctx, in, path, options)
case vo.DataTypeBoolean:
return convertToBool(ctx, in, path, options)
case vo.DataTypeObject:
return convertToObject(ctx, in, path, t, options)
case vo.DataTypeArray:
return convertToArray(ctx, in, path, t, options)
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unknown input type %s for path %s", t.Type, path))
}
logs.CtxErrorf(ctx, "unknown input type %s for path %s", t.Type, path)
return in, newWarnings(path, t.Type, errors.New("unknown input type")), nil
}
}
func convertToString(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
switch in.(type) {
case string:
return in.(string), nil, nil
case int64:
return strconv.FormatInt(in.(int64), 10), nil, nil
case float64:
return strconv.FormatFloat(in.(float64), 'f', -1, 64), nil, nil
case bool:
return strconv.FormatBool(in.(bool)), nil, nil
case []any, map[string]any:
s, err := sonic.MarshalString(in)
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
return nil, newWarnings(path, vo.DataTypeString, err), nil
}
return s, nil, nil
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to string: %T", in))
}
return nil, newWarnings(path, vo.DataTypeString, fmt.Errorf("unsupported type to convert to string: %T", in)), nil
}
}
func convertToInt64(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
switch in.(type) {
case int64:
return in.(int64), nil, nil
case float64:
return int64(in.(float64)), nil, nil
case string:
i, err := strconv.ParseInt(in.(string), 10, 64)
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
}
return nil, newWarnings(path, vo.DataTypeInteger, err), nil
}
return i, nil, nil
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to int64: %T", in))
}
return nil, newWarnings(path, vo.DataTypeInteger, fmt.Errorf("unsupported type to convert to int64: %T", in)), nil
}
}
func convertToFloat64(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
switch in.(type) {
case int64:
return float64(in.(int64)), nil, nil
case float64:
return in.(float64), nil, nil
case string:
f, err := strconv.ParseFloat(in.(string), 64)
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
}
return nil, newWarnings(path, vo.DataTypeNumber, err), nil
}
return f, nil, nil
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to float64: %T", in))
}
return nil, newWarnings(path, vo.DataTypeNumber, fmt.Errorf("unsupported type to convert to float64: %T", in)), nil
}
}
func convertToBool(_ context.Context, in any, path string, options *convertOptions) (any, *ConversionWarnings, error) {
switch in.(type) {
case bool:
return in.(bool), nil, nil
case string:
b, err := strconv.ParseBool(in.(string))
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, err)
}
return nil, newWarnings(path, vo.DataTypeBoolean, err), nil
}
return b, nil, nil
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to bool: %T", in))
}
return nil, newWarnings(path, vo.DataTypeBoolean, fmt.Errorf("unsupported type to convert to bool: %T", in)), nil
}
}
func convertToObject(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
map[string]any, *ConversionWarnings, error) {
var m map[string]any
switch in.(type) {
case map[string]any:
m = in.(map[string]any)
case string:
err := sonic.UnmarshalString(in.(string), &m)
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
return nil, newWarnings(path, vo.DataTypeObject, err), nil
}
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to object: %T", in))
}
return nil, newWarnings(path, vo.DataTypeObject, fmt.Errorf("unsupported type to convert to object: %T", in)), nil
}
if len(m) == 0 {
if !options.skipRequireCheck {
for pn, pro := range t.Properties {
if pro.Required {
return nil, nil, vo.NewError(errno.ErrMissingRequiredParam,
errorx.KV("param", fmt.Sprintf("%s.%s", path, pn)))
}
}
}
return m, nil, nil
}
out := make(map[string]any, len(m))
var warnings ConversionWarnings
for k, v := range m {
propType, ok := t.Properties[k]
if !ok {
// for input fields not explicitly defined, just pass the value through
logs.CtxWarnf(ctx, "input %s.%s not found in type info", path, k)
if !options.skipUnknownFields {
out[k] = v
}
continue
}
propPath := fmt.Sprintf("%s.%s", path, k)
newV, ws, err := convert(ctx, v, propPath, propType, options)
if err != nil {
return nil, nil, err
} else if ws != nil {
warnings = append(warnings, *ws...)
}
out[k] = newV
}
if !options.skipRequireCheck {
for k, t := range t.Properties {
if _, ok := out[k]; !ok {
if t.Required {
return nil, nil, vo.NewError(errno.ErrMissingRequiredParam,
errorx.KV("param", fmt.Sprintf("%s.%s", path, k)))
}
}
}
}
if len(warnings) > 0 {
return out, ptr.Of(warnings), nil
}
return out, nil, nil
}
func convertToArray(ctx context.Context, in any, path string, t *vo.TypeInfo, options *convertOptions) (
[]any, *ConversionWarnings, error) {
var a []any
switch v := in.(type) {
case []any:
a = v
case string:
err := sonic.UnmarshalString(v, &a)
if err != nil {
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
return []any{}, newWarnings(path, vo.DataTypeArray, err), nil
}
default:
if options.failFast {
return nil, nil, vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("unsupported type to convert to array: %T", in))
}
return []any{}, newWarnings(path, vo.DataTypeArray, fmt.Errorf("unsupported type to convert to array: %T", in)), nil
}
if len(a) == 0 {
return a, nil, nil
}
out := make([]any, 0, len(a))
var warnings ConversionWarnings
elemType := t.ElemTypeInfo
for i, v := range a {
elemPath := fmt.Sprintf("%s.%d", path, i)
newV, ws, err := convert(ctx, v, elemPath, elemType, options)
if err != nil {
return nil, nil, err
} else if ws != nil {
warnings = append(warnings, *ws...)
} else { // only correctly converted elements go into the final array
out = append(out, newV)
}
}
if len(warnings) > 0 {
return out, ptr.Of(warnings), nil
}
return out, nil, nil
}

View File

@@ -0,0 +1,435 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
const rowNum = "rowNum"
const outputList = "outputList"
const TimeFormat = "2006-01-02 15:04:05 -0700 MST"
func toString(in any) (any, error) {
switch in := in.(type) {
case []byte:
return string(in), nil
case string:
return in, nil
case int64:
return strconv.FormatInt(in, 10), nil
case float64:
return strconv.FormatFloat(in, 'f', -1, 64), nil
case time.Time:
return in.Format(TimeFormat), nil
case bool:
return strconv.FormatBool(in), nil
case map[string]any, []any:
return sonic.MarshalString(in)
default:
return "", fmt.Errorf("unknown type: %T", in)
}
}
func toInteger(in any) (any, error) {
switch in := in.(type) {
case []byte:
return strconv.ParseInt(string(in), 10, 64)
case string:
return strconv.ParseInt(in, 10, 64)
case int64:
return in, nil
case float64:
return int64(in), nil
case time.Time, bool:
return nil, fmt.Errorf(`type '%T' can't convert to int64'`, in)
default:
return nil, fmt.Errorf("unknown type: %T", in)
}
}
func toNumber(in any) (any, error) {
switch in := in.(type) {
case []byte:
i, err := strconv.ParseFloat(string(in), 64)
return i, err
case string:
return strconv.ParseFloat(in, 64)
case int64:
return float64(in), nil
case float64:
return in, nil
case time.Time, bool:
return nil, fmt.Errorf(`type '%T' can't convert to float64'`, in)
default:
return nil, fmt.Errorf("unknown type: %T", in)
}
}
func toTime(in any) (any, error) {
switch in := in.(type) {
case []byte:
return string(in), nil
case string:
return in, nil
case int64:
return strconv.FormatInt(in, 10), nil
case float64:
return strconv.FormatFloat(in, 'f', -1, 64), nil
case time.Time:
return in.Format(TimeFormat), nil
case bool:
if in {
return "1", nil
}
return "0", nil
default:
return nil, fmt.Errorf("unknown type: %T", in)
}
}
func toBool(in any) (any, error) {
switch in := in.(type) {
case []byte:
return strconv.ParseBool(string(in))
case string:
return strconv.ParseBool(in)
case int64:
return strconv.ParseBool(strconv.FormatInt(in, 10))
case float64:
return strconv.ParseBool(strconv.FormatFloat(in, 'f', -1, 64))
case time.Time:
return strconv.ParseBool(in.Format(TimeFormat))
case bool:
return in, nil
default:
return nil, fmt.Errorf("unknown type: %T", in)
}
}
// formatted convert the interface type according to the datatype type.
// notice: object is currently not supported by database, and ignore it.
func formatted(in any, ty *vo.TypeInfo) any {
switch ty.Type {
case vo.DataTypeString:
r, err := toString(in)
if err != nil {
logs.Warnf("formatted string error: %v", err)
return nil
}
return r
case vo.DataTypeNumber:
r, err := toNumber(in)
if err != nil {
logs.Warnf("formatted number error: %v", err)
return nil
}
return r
case vo.DataTypeInteger:
r, err := toInteger(in)
if err != nil {
logs.Warnf("formatted integer error: %v", err)
return nil
}
return r
case vo.DataTypeBoolean:
r, err := toBool(in)
if err != nil {
logs.Warnf("formatted boolean error: %v", err)
}
return r
case vo.DataTypeTime:
r, err := toTime(in)
if err != nil {
logs.Warnf("formatted time error: %v", err)
return nil
}
return r
case vo.DataTypeArray:
arrayIn := make([]any, 0)
inStr, err := toString(in)
if err != nil {
logs.Warnf("formatted array error: %v", err)
return []any{}
}
err = sonic.UnmarshalString(inStr.(string), &arrayIn)
if err != nil {
logs.Warnf("formatted array unmarshal error: %v", err)
return []any{}
}
result := make([]any, 0)
switch ty.ElemTypeInfo.Type {
case vo.DataTypeTime:
for _, in := range arrayIn {
r, err := toTime(in)
if err != nil {
logs.Warnf("formatted time: %v", err)
continue
}
result = append(result, r)
}
return result
case vo.DataTypeString:
for _, in := range arrayIn {
r, err := toString(in)
if err != nil {
logs.Warnf("formatted string failed: %v", err)
continue
}
result = append(result, r)
}
return result
case vo.DataTypeInteger:
for _, in := range arrayIn {
r, err := toInteger(in)
if err != nil {
logs.Warnf("formatted interger failed: %v", err)
continue
}
result = append(result, r)
}
return result
case vo.DataTypeBoolean:
for _, in := range arrayIn {
r, err := toBool(in)
if err != nil {
logs.Warnf("formatted bool failed: %v", err)
continue
}
result = append(result, r)
}
return result
case vo.DataTypeNumber:
for _, in := range arrayIn {
r, err := toNumber(in)
if err != nil {
logs.Warnf("formatted number failed: %v", err)
continue
}
result = append(result, r)
}
return result
case vo.DataTypeObject:
properties := ty.ElemTypeInfo.Properties
if len(properties) == 0 {
for idx := range arrayIn {
in := arrayIn[idx]
if _, ok := in.(database.Object); ok {
result = append(result, in)
}
}
return result
}
for idx := range arrayIn {
in := arrayIn[idx]
object, ok := in.(database.Object)
if !ok {
object = make(database.Object)
for key := range properties {
object[key] = nil
}
result = append(result, object)
} else {
result = append(result, objectFormatted(ty.ElemTypeInfo.Properties, object))
}
}
return result
default:
return nil
}
default:
return nil
}
}
func objectFormatted(props map[string]*vo.TypeInfo, object database.Object) map[string]any {
ret := make(map[string]any)
// if config is nil, it agrees to convert to string type as the default value
if len(props) == 0 {
for k, v := range object {
val, err := toString(v)
if err != nil {
logs.Warnf("formatted string error: %v", err)
continue
}
ret[k] = val
}
return ret
}
for k, v := range props {
if r, ok := object[k]; ok && r != nil {
formattedValue := formatted(r, v)
ret[k] = formattedValue
} else {
// if key not existed, assign nil
ret[k] = nil
}
}
return ret
}
// responseFormatted convert the object list returned by "response" into the field mapping of the "config output" configuration,
// If the conversion fail, set the output list to null. If there are missing fields, set the missing fields to null.
func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.Response) (map[string]any, error) {
ret := make(map[string]any)
list := make([]any, 0, len(configOutput))
outputListTypeInfo, ok := configOutput["outputList"]
if !ok {
return ret, fmt.Errorf("outputList key is required")
}
if outputListTypeInfo.Type != vo.DataTypeArray {
return nil, fmt.Errorf("output list type info must array,but got %v", outputListTypeInfo.Type)
}
if outputListTypeInfo.ElemTypeInfo == nil {
return nil, fmt.Errorf("output list must be an array and the array must contain element type info")
}
if outputListTypeInfo.ElemTypeInfo.Type != vo.DataTypeObject {
return nil, fmt.Errorf("output list must be an array and element must object, but got %v", outputListTypeInfo.ElemTypeInfo.Type)
}
props := outputListTypeInfo.ElemTypeInfo.Properties
for _, object := range response.Objects {
list = append(list, objectFormatted(props, object))
}
ret[outputList] = list
if response.RowNumber != nil {
ret[rowNum] = *response.RowNumber
} else {
ret[rowNum] = nil
}
return ret, nil
}
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
var (
rightValue any
ok bool
)
conditionGroup := &database.ConditionGroup{
Conditions: make([]*database.Condition, 0),
Relation: database.ClauseRelationAND,
}
if clauseGroup.Single != nil {
clause := clauseGroup.Single
if !notNeedTakeMapValue(clause.Operator) {
rightValue, ok = nodes.TakeMapValue(input, compose.FieldPath{"__condition_right_0"})
if !ok {
return nil, fmt.Errorf("cannot take single clause from input")
}
}
conditionGroup.Conditions = append(conditionGroup.Conditions, &database.Condition{
Left: clause.Left,
Operator: clause.Operator,
Right: rightValue,
})
}
if clauseGroup.Multi != nil {
conditionGroup.Relation = clauseGroup.Multi.Relation
conditionGroup.Conditions = make([]*database.Condition, len(clauseGroup.Multi.Clauses))
multiSelect := clauseGroup.Multi
for idx, clause := range multiSelect.Clauses {
if !notNeedTakeMapValue(clause.Operator) {
rightValue, ok = nodes.TakeMapValue(input, compose.FieldPath{fmt.Sprintf("__condition_right_%d", idx)})
if !ok {
return nil, fmt.Errorf("cannot take multi clause from input")
}
}
conditionGroup.Conditions[idx] = &database.Condition{
Left: clause.Left,
Operator: clause.Operator,
Right: rightValue,
}
}
}
return conditionGroup, nil
}
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
if err != nil {
return nil, err
}
fields := parseToInput(input)
inventory := &UpdateInventory{
ConditionGroup: conditionGroup,
Fields: fields,
}
return inventory, nil
}
func isDebugExecute(ctx context.Context) bool {
execCtx := execute.GetExeCtx(ctx)
if execCtx == nil {
panic(fmt.Errorf("unable to get exe context"))
}
return execCtx.RootCtx.ExeCfg.Mode == vo.ExecuteModeDebug || execCtx.RootCtx.ExeCfg.Mode == vo.ExecuteModeNodeDebug
}
func getExecUserID(ctx context.Context) int64 {
execCtx := execute.GetExeCtx(ctx)
if execCtx == nil {
panic(fmt.Errorf("unable to get exe context"))
}
return execCtx.RootCtx.ExeCfg.Operator
}
func parseToInput(input map[string]any) map[string]any {
result := make(map[string]any, len(input))
for key, value := range input {
if strings.HasPrefix(key, "__setting_field_") {
key = strings.TrimPrefix(key, "__setting_field_")
result[key] = value
}
}
return result
}

View File

@@ -0,0 +1,127 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"errors"
"reflect"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type CustomSQLConfig struct {
DatabaseInfoID int64
SQLTemplate string
OutputConfig map[string]*vo.TypeInfo
CustomSQLExecutor database.DatabaseOperator
}
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.SQLTemplate == "" {
return nil, errors.New("sql template is required")
}
if cfg.CustomSQLExecutor == nil {
return nil, errors.New("custom sqler is required")
}
return &CustomSQL{
config: cfg,
}, nil
}
type CustomSQL struct {
config *CustomSQLConfig
}
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
req := &database.CustomSQLRequest{
DatabaseInfoID: c.config.DatabaseInfoID,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
inputBytes, err := sonic.Marshal(input)
if err != nil {
return nil, err
}
templateSQL := ""
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil")
for _, templatePart := range templateParts {
if !templatePart.IsVariable {
templateSQL += templatePart.Value
continue
}
templateSQL += "?"
val, err := templatePart.Render(inputBytes, nodes.WithNilRender(func() (string, error) {
return "", nilError
}),
nodes.WithCustomRender(reflect.TypeOf(false), func(val any) (string, error) {
b := val.(bool)
if b {
return "1", nil
}
return "0", nil
}))
if err != nil {
if !errors.Is(err, nilError) {
return nil, err
}
sqlParams = append(sqlParams, database.SQLParam{
IsNull: true,
})
} else {
sqlParams = append(sqlParams, database.SQLParam{
Value: val,
IsNull: false,
})
}
}
// replace sql template '?' to ?
templateSQL = strings.Replace(templateSQL, "'?'", "?", -1)
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL
req.Params = sqlParams
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(c.config.OutputConfig, response)
if err != nil {
return nil, err
}
return ret, nil
}

View File

@@ -0,0 +1,110 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"testing"
"github.com/bytedance/mockey"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
)
type mockCustomSQLer struct {
validate func(req *database.CustomSQLRequest)
}
func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
return func(ctx context.Context, request *database.CustomSQLRequest) (*database.Response, error) {
m.validate(request)
r := &database.Response{
Objects: []database.Object{
database.Object{
"v1": "v1_ret",
"v2": "v2_ret",
},
},
}
return r, nil
}
}
func TestCustomSQL_Execute(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSQLer := mockCustomSQLer{
validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{
database.SQLParam{Value: "v1_value"},
database.SQLParam{Value: "v2_value"},
database.SQLParam{Value: "v3_value"},
}
assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
},
}
defer mockey.Mock(execute.GetExeCtx).Return(&execute.Context{
RootCtx: execute.RootCtx{
ExeCfg: vo.ExecuteConfig{
Mode: vo.ExecuteModeDebug,
Operator: 123,
BizType: vo.BizTypeWorkflow,
},
},
}).Build().UnPatch()
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
cfg := &CustomSQLConfig{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
CustomSQLExecutor: mockDatabaseOperator,
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
}}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
cl := &CustomSQL{
config: cfg,
}
ret, err := cl.Execute(t.Context(), map[string]any{
"v1": "v1_value",
"v2": "v2_value",
"v3": "v3_value",
})
assert.Nil(t, err)
assert.Equal(t, "v1_ret", ret[outputList].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "v2_ret", ret[outputList].([]any)[0].(database.Object)["v2"])
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"errors"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type DeleteConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Deleter database.DatabaseOperator
}
type Delete struct {
config *DeleteConfig
}
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
return nil, errors.New("clauseGroup is required")
}
if cfg.Deleter == nil {
return nil, errors.New("deleter is required")
}
return &Delete{
config: cfg,
}, nil
}
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
if err != nil {
return nil, err
}
request := &database.DeleteRequest{
DatabaseInfoID: d.config.DatabaseInfoID,
ConditionGroup: conditionGroup,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := d.config.Deleter.Delete(ctx, request)
if err != nil {
return nil, err
}
ret, err := responseFormatted(d.config.OutputConfig, response)
if err != nil {
return nil, err
}
return ret, nil
}
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
if err != nil {
return nil, err
}
return d.toDatabaseDeleteCallbackInput(conditionGroup)
}
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
databaseID := d.config.DatabaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["deleteParam"] = map[string]any{}
condition, err := convertToCondition(conditionGroup)
if err != nil {
return nil, err
}
type Field struct {
FieldID string `json:"fieldId"`
IsDistinct bool `json:"isDistinct"`
}
result["deleteParam"] = map[string]any{
"condition": condition}
return result, nil
}

View File

@@ -0,0 +1,102 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"errors"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type InsertConfig struct {
DatabaseInfoID int64
OutputConfig map[string]*vo.TypeInfo
Inserter database.DatabaseOperator
}
type Insert struct {
config *InsertConfig
}
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Inserter == nil {
return nil, errors.New("inserter is required")
}
return &Insert{
config: cfg,
}, nil
}
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {
fields := parseToInput(input)
req := &database.InsertRequest{
DatabaseInfoID: is.config.DatabaseInfoID,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := is.config.Inserter.Insert(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(is.config.OutputConfig, response)
if err != nil {
return nil, err
}
return ret, nil
}
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
databaseID := is.config.DatabaseInfoID
fs := parseToInput(input)
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
type FieldInfo struct {
FieldID string `json:"fieldId"`
FieldValue any `json:"fieldValue"`
}
fieldInfo := make([]*FieldInfo, 0)
for k, v := range fs {
fieldInfo = append(fieldInfo, &FieldInfo{
FieldID: k,
FieldValue: v,
})
}
result["insertParam"] = map[string]any{
"fieldInfo": fieldInfo,
}
return result, nil
}

View File

@@ -0,0 +1,221 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"errors"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type QueryConfig struct {
DatabaseInfoID int64
QueryFields []string
OrderClauses []*database.OrderClause
OutputConfig map[string]*vo.TypeInfo
ClauseGroup *database.ClauseGroup
Limit int64
Op database.DatabaseOperator
}
type Query struct {
config *QueryConfig
}
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Limit == 0 {
return nil, errors.New("limit is required and greater than 0")
}
if cfg.Op == nil {
return nil, errors.New("op is required")
}
return &Query{config: cfg}, nil
}
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
if err != nil {
return nil, err
}
req := &database.QueryRequest{
DatabaseInfoID: ds.config.DatabaseInfoID,
OrderClauses: ds.config.OrderClauses,
SelectFields: ds.config.QueryFields,
Limit: ds.config.Limit,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
req.ConditionGroup = conditionGroup
response, err := ds.config.Op.Query(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(ds.config.OutputConfig, response)
if err != nil {
return nil, err
}
return ret, nil
}
func notNeedTakeMapValue(op database.Operator) bool {
return op == database.OperatorIsNull || op == database.OperatorIsNotNull
}
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
if err != nil {
return nil, err
}
return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
}
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
result := make(map[string]any)
databaseID := config.DatabaseInfoID
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["selectParam"] = map[string]any{}
condition, err := convertToCondition(conditionGroup)
if err != nil {
return nil, err
}
type Field struct {
FieldID string `json:"fieldId"`
IsDistinct bool `json:"isDistinct"`
}
fieldList := make([]Field, 0, len(config.QueryFields))
for _, f := range config.QueryFields {
fieldList = append(fieldList, Field{FieldID: f})
}
type Order struct {
FieldID string `json:"fieldId"`
IsAsc bool `json:"isAsc"`
}
OrderList := make([]Order, 0)
for _, c := range config.OrderClauses {
OrderList = append(OrderList, Order{
FieldID: c.FieldID,
IsAsc: c.IsAsc,
})
}
result["selectParam"] = map[string]any{
"condition": condition,
"fieldList": fieldList,
"limit": config.Limit,
"orderByList": OrderList,
}
return result, nil
}
type ConditionItem struct {
Left string `json:"left"`
Operation string `json:"operation"`
Right any `json:"right"`
}
type Condition struct {
ConditionList []ConditionItem `json:"conditionList"`
Logic string `json:"logic"`
}
func convertToCondition(conditionGroup *database.ConditionGroup) (*Condition, error) {
logic, err := convertToLogic(conditionGroup.Relation)
if err != nil {
return nil, err
}
condition := &Condition{
ConditionList: make([]ConditionItem, 0),
Logic: logic,
}
for _, c := range conditionGroup.Conditions {
op, err := convertToOperation(c.Operator)
if err != nil {
return nil, fmt.Errorf("invalid operator: %s", c.Operator)
}
condition.ConditionList = append(condition.ConditionList, ConditionItem{
Left: c.Left,
Operation: op,
Right: c.Right,
})
}
return condition, nil
}
func convertToOperation(Op database.Operator) (string, error) {
switch Op {
case database.OperatorEqual:
return "EQUAL", nil
case database.OperatorNotEqual:
return "NOT_EQUAL", nil
case database.OperatorGreater:
return "GREATER_THAN", nil
case database.OperatorLesser:
return "LESS_THAN", nil
case database.OperatorGreaterOrEqual:
return "GREATER_EQUAL", nil
case database.OperatorLesserOrEqual:
return "LESS_EQUAL", nil
case database.OperatorIn:
return "IN", nil
case database.OperatorNotIn:
return "NOT_IN", nil
case database.OperatorIsNull:
return "IS_NULL", nil
case database.OperatorIsNotNull:
return "IS_NOT_NULL", nil
case database.OperatorLike:
return "LIKE", nil
case database.OperatorNotLike:
return "NOT LIKE", nil
}
return "", fmt.Errorf("not a valid database Operator")
}
func convertToLogic(rel database.ClauseRelation) (string, error) {
switch rel {
case database.ClauseRelationOR:
return "OR", nil
case database.ClauseRelationAND:
return "AND", nil
default:
return "", fmt.Errorf("unknown clause relation %v", rel)
}
}

View File

@@ -0,0 +1,456 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"fmt"
"testing"
"github.com/bytedance/mockey"
"go.uber.org/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
)
type mockDsSelect struct {
t *testing.T
objects []database.Object
validate func(request *database.QueryRequest)
}
func (m *mockDsSelect) Query() func(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
return func(ctx context.Context, request *database.QueryRequest) (*database.Response, error) {
n := int64(1)
m.validate(request)
return &database.Response{
RowNumber: &n,
Objects: m.objects,
}, nil
}
}
func TestDataset_Query(t *testing.T) {
defer mockey.Mock(execute.GetExeCtx).Return(&execute.Context{
RootCtx: execute.RootCtx{
ExeCfg: vo.ExecuteConfig{
Mode: vo.ExecuteModeDebug,
Operator: 123,
BizType: vo.BizTypeWorkflow,
},
},
}).Build().UnPatch()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
t.Run("string case", func(t *testing.T) {
t.Run("single", func(t *testing.T) {
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": "1",
"v2": int64(2),
})
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Single: &database.Clause{
Left: "v1",
Operator: database.OperatorLike,
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]interface{}{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
})
t.Run("multi", func(t *testing.T) {
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Multi: &database.MultiClause{
Relation: database.ClauseRelationOR,
Clauses: []*database.Clause{
{Left: "v1", Operator: database.OperatorLike},
{Left: "v2", Operator: database.OperatorLike},
},
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": "1",
"v2": int64(2),
})
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Right, 1)
assert.Equal(t, cGroup.Conditions[1].Right, 2)
assert.Equal(t, cGroup.Relation, cfg.ClauseGroup.Multi.Relation)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]any{
"__condition_right_0": 1,
"__condition_right_1": 2,
}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
})
t.Run("formated error", func(t *testing.T) {
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Single: &database.Clause{
Left: "v1",
Operator: database.OperatorLike,
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": "abc",
"v2": int64(2),
})
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, map[string]any{
"v1": nil,
"v2": int64(2),
}, result["outputList"].([]any)[0])
})
t.Run("redundancy return field", func(t *testing.T) {
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Single: &database.Clause{
Left: "v1",
Operator: database.OperatorLike,
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": "1",
"v2": int64(2),
})
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]any{"__condition_right_0": 1}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, int64(2), result["outputList"].([]any)[0].(database.Object)["v2"])
assert.Equal(t, nil, result["outputList"].([]any)[0].(database.Object)["v3"])
})
})
t.Run("other case", func(t *testing.T) {
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Single: &database.Clause{
Left: "v1",
Operator: database.OperatorLike,
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": "1",
"v2": "2.1",
"v3": int64(0),
"v4": "true",
"v5": "2020-02-20T10:10:10",
"v6": `["1","2","3"]`,
"v7": `[false,true,"true"]`,
"v8": `["1.2",2.1, 3.9]`,
})
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
object := result["outputList"].([]any)[0].(database.Object)
assert.Equal(t, int64(1), object["v1"])
assert.Equal(t, 2.1, object["v2"])
assert.Equal(t, false, object["v3"])
assert.Equal(t, true, object["v4"])
assert.Equal(t, "2020-02-20T10:10:10", object["v5"])
assert.Equal(t, []any{int64(1), int64(2), int64(3)}, object["v6"])
assert.Equal(t, []any{false, true, true}, object["v7"])
assert.Equal(t, []any{1.2, 2.1, 3.9}, object["v8"])
})
t.Run("config output list is nil", func(t *testing.T) {
cfg := &QueryConfig{
DatabaseInfoID: 111,
ClauseGroup: &database.ClauseGroup{
Single: &database.Clause{
Left: "v1",
Operator: database.OperatorLike,
},
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
"v1": int64(1),
"v2": "2.1",
"v3": int64(0),
"v4": "true",
"v5": "2020-02-20T10:10:10",
"v6": `["1","2","3"]`,
"v7": `[false,true,"true"]`,
"v8": `["1.2",2.1, 3.9]`,
})
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
if request.DatabaseInfoID != cfg.DatabaseInfoID {
t.Fatal("database id should be equal")
}
cGroup := request.ConditionGroup
assert.Equal(t, cGroup.Conditions[0].Left, cfg.ClauseGroup.Single.Left)
assert.Equal(t, cGroup.Conditions[0].Operator, cfg.ClauseGroup.Single.Operator)
}}
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
"v1": "1",
"v2": "2.1",
"v3": "0",
"v4": "true",
"v5": "2020-02-20T10:10:10",
"v6": `["1","2","3"]`,
"v7": `[false,true,"true"]`,
"v8": `["1.2",2.1, 3.9]`,
})
})
}

View File

@@ -0,0 +1,133 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"errors"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type UpdateConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Updater database.DatabaseOperator
}
type Update struct {
config *UpdateConfig
}
type UpdateInventory struct {
ConditionGroup *database.ConditionGroup
Fields map[string]any
}
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
if cfg.Updater == nil {
return nil, errors.New("updater is required")
}
return &Update{config: cfg}, nil
}
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
if err != nil {
return nil, err
}
fields := make(map[string]any)
for key, value := range inventory.Fields {
fields[key] = value
}
req := &database.UpdateRequest{
DatabaseInfoID: u.config.DatabaseInfoID,
ConditionGroup: inventory.ConditionGroup,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := u.config.Updater.Update(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(u.config.OutputConfig, response)
if err != nil {
return nil, err
}
return ret, nil
}
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
if err != nil {
return nil, err
}
return u.toDatabaseUpdateCallbackInput(inventory)
}
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
databaseID := u.config.DatabaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["updateParam"] = map[string]any{}
condition, err := convertToCondition(inventory.ConditionGroup)
if err != nil {
return nil, err
}
type FieldInfo struct {
fieldID string
fieldValue any
}
fieldInfo := make([]FieldInfo, 0)
for k, v := range inventory.Fields {
fieldInfo = append(fieldInfo, FieldInfo{
fieldID: k,
fieldValue: v,
})
}
result["updateParam"] = map[string]any{
"condition": condition,
"fieldInfo": fieldInfo,
}
return result, nil
}

View File

@@ -0,0 +1,555 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package emitter
import (
"context"
"errors"
"fmt"
"io"
"strings"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type OutputEmitter struct {
cfg *Config
}
type Config struct {
Template string
FullSources map[string]*nodes.SourceInfo
}
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
return &OutputEmitter{
cfg: cfg,
}, nil
}
type cachedVal struct {
val any
finished bool
subCaches *cacheStore
}
type cacheStore struct {
store map[string]*cachedVal
infos map[string]*nodes.SourceInfo
}
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
return &cacheStore{
store: make(map[string]*cachedVal),
infos: infos,
}
}
func (c *cacheStore) put(k string, v any) (any, error) {
sInfo, ok := c.infos[k]
if !ok {
return nil, fmt.Errorf("no such key found from SourceInfos: %s", k)
}
if !sInfo.IsIntermediate { // this is not an intermediate object container
isStream := sInfo.FieldType == nodes.FieldIsStream
if !isStream {
_, ok := c.store[k]
if !ok {
out := &cachedVal{
val: v,
finished: true,
}
c.store[k] = out
return v, nil
} else {
return nil, fmt.Errorf("source %s not intermediate, not stream, appears multiple times", k)
}
} else { // this is an actual stream
vStr, ok := v.(string) // stream value should be string
if !ok {
return nil, fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v)
}
isFinished := strings.HasSuffix(vStr, nodes.KeyIsFinished)
if isFinished {
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
}
var existingStr string
existing, ok := c.store[k]
if ok {
existingStr = existing.val.(string)
}
existingStr = existingStr + vStr
out := &cachedVal{
val: existingStr,
finished: isFinished,
}
c.store[k] = out
return vStr, nil
}
}
if len(sInfo.SubSources) == 0 {
// k is intermediate container, needs to check its sub sources
return nil, fmt.Errorf("source %s is intermediate, but does not have sub sources", k)
}
vMap, ok := v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("source %s is intermediate, but value type not map: %T", k, v)
}
currentCache, existed := c.store[k]
if !existed {
currentCache = &cachedVal{
val: v,
subCaches: newCacheStore(sInfo.SubSources),
}
c.store[k] = currentCache
} else {
// already cached k before, merge cached value with new value
currentCache.val = merge(currentCache.val, v)
}
subCacheStore := currentCache.subCaches
for subK := range subCacheStore.infos {
subV, ok := vMap[subK]
if !ok { // subK not present in this chunk
continue
}
_, err := subCacheStore.put(subK, subV)
if err != nil {
return nil, err
}
}
_ = c.finished(k)
return vMap, nil
}
func (c *cacheStore) finished(k string) bool {
cached, ok := c.store[k]
if !ok {
return c.infos[k].FieldType == nodes.FieldSkipped
}
if cached.finished {
return true
}
sInfo := c.infos[k]
if !sInfo.IsIntermediate {
return cached.finished
}
for subK := range sInfo.SubSources {
subFinished := cached.subCaches.finished(subK)
if !subFinished {
return false
}
}
cached.finished = true
return true
}
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
actualPath []string,
) {
rootCached, ok := c.store[part.Root]
if !ok {
return nil, nil, nil, nil
}
// now try to find the nearest match within the cached tree
subPaths := part.SubPathsBeforeSlice
currentCache := rootCached
currentSource := c.infos[part.Root]
for i := range subPaths {
if !currentSource.IsIntermediate {
// currentSource is already the leaf, no need to look further
break
}
subPath := subPaths[i]
subInfo, ok := currentSource.SubSources[subPath]
if !ok {
// this sub path is not in the source info tree
// it's just a user defined variable field in the template
break
}
actualPath = append(actualPath, subPath)
subCache, ok = currentCache.subCaches.store[subPath]
if !ok {
// subPath corresponds to a real Field Source,
// if it's not cached, then it hasn't appeared in the stream yet
return rootCached.val, nil, subInfo, actualPath
}
if !subCache.finished {
return rootCached.val, subCache, subInfo, actualPath
}
currentCache = subCache
currentSource = subInfo
}
return rootCached.val, currentCache, currentSource, actualPath
}
func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWriter[map[string]any]) (
hasErr bool, partFinished bool) {
cachedRoot, subCache, sourceInfo, _ := c.find(part)
if cachedRoot != nil && subCache != nil {
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
if hasErr {
return true, false
}
if subCache.finished { // move on to next part in template
return false, true
}
}
}
return false, false
}
func (c *cacheStore) fillZero(nodeKey vo.NodeKey) map[string]any {
filled := make(map[string]any)
for field, sInfo := range c.infos {
if !sInfo.FromNode(nodeKey) {
continue
}
cacheV, ok := c.store[field]
if !sInfo.IsIntermediate {
if !ok {
c.store[field] = &cachedVal{
val: sInfo.TypeInfo.Zero(),
finished: true,
subCaches: nil,
}
filled[field] = true
}
continue
}
if !ok {
cacheV = &cachedVal{
val: make(map[string]any),
subCaches: newCacheStore(sInfo.SubSources),
}
c.store[field] = cacheV
}
subFilled := cacheV.subCaches.fillZero(nodeKey)
if len(subFilled) > 0 {
filled[field] = subFilled
}
}
return filled
}
func merge(a, b any) any {
aStr, ok1 := a.(string)
bStr, ok2 := b.(string)
if ok1 && ok2 {
if strings.HasSuffix(bStr, nodes.KeyIsFinished) {
bStr = strings.TrimSuffix(bStr, nodes.KeyIsFinished)
}
return aStr + bStr
}
aMap, ok1 := a.(map[string]interface{})
bMap, ok2 := b.(map[string]interface{})
if ok1 && ok2 {
merged := make(map[string]any)
for k, v := range aMap {
merged[k] = v
}
for k, v := range bMap {
if _, ok := merged[k]; !ok { // only bMap has this field, just set it
merged[k] = v
continue
}
merged[k] = merge(merged[k], v)
}
return merged
}
panic(fmt.Errorf("can only merge two maps or two strings, a type: %T, b type: %T", a, b))
}
const outputKey = "output"
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
if err != nil {
return nil, err
}
sr, sw := schema.Pipe[map[string]any](0)
parts := nodes.ParseTemplate(e.cfg.Template)
safego.Go(ctx, func() {
hasErr := false
defer func() {
if !hasErr {
sw.Send(map[string]any{outputKey: nodes.KeyIsFinished}, nil)
}
sw.Close()
in.Close()
}()
caches := newCacheStore(resolvedSources)
partsLoop:
for _, part := range parts {
select {
case <-ctx.Done(): // canceled by Eino workflow engine
sw.Send(nil, ctx.Err())
hasErr = true
return
default:
}
if !part.IsVariable { // literal string within template, just emit it
sw.Send(map[string]any{outputKey: part.Value}, nil)
continue
}
// now this 'part' is a variable, first check if the source(s) for it are skipped (the nodes are not selected)
// if skipped, just move on to the next 'part'
skipped, invalid := part.Skipped(resolvedSources)
if skipped {
continue
}
if invalid {
sw.Send(map[string]any{outputKey: "{{" + part.Value + "}}"}, nil)
continue
}
// now this 'part' definitely should have a match, look for a hit within cache store
// if found in cache store, emit the root only if the match is finished or the match is stream
// the rule for a hit: the nearest match within the source tree
// if hit, and the cachedVal is also finished, continue to next template part
var partFinished bool
hasErr, partFinished = caches.readyForPart(part, sw)
if hasErr {
return
}
if partFinished {
continue
}
for {
select {
case <-ctx.Done(): // canceled by Eino workflow engine or timeout
sw.Send(nil, ctx.Err())
hasErr = true
return
default:
}
shouldChangePart := false
chunk, err := in.Recv()
if err != nil {
if err == io.EOF {
// current part is not fulfilled, emit the literal part content and move on to next part
sw.Send(map[string]any{outputKey: part.Value}, nil)
break
}
if sn, ok := schema.GetSourceName(err); ok {
// received end signal for a particular predecessor nodeID, do the following:
// - obtain the field sources mapped from this predecessor node
// - check which fields are still missing in the cache store
// - fill zero value for these missing fields
// - check if the current template part should be rendered and sent immediately
// - check if we should move on to next part in template
filled := caches.fillZero(vo.NodeKey(sn))
if _, okk := filled[part.Root]; okk {
// current part is influenced by the 'fill zero' operation
hasErr, shouldChangePart = caches.readyForPart(part, sw)
if hasErr {
return
}
if shouldChangePart {
continue partsLoop
}
}
continue
}
hasErr = true
sw.Send(nil, err) // real error
return
}
chunkLoop:
for k := range chunk {
v := chunk[k]
v, err = caches.put(k, v) // always update the cache
if err != nil {
hasErr = true
sw.Send(nil, err)
return
}
// needs to check if this 'k' is the current part's root
// if it is, do the case analysis:
// - the source is a leaf (not intermediate):
// - the source is stream, emit the formatted stream content immediately
// - the source is not stream, emit the formatted one-time content immediately
// - the source is intermediate:
// - the source is not finished, do not emit it
// - the source is finished, emit the full content (cached + new) immediately
if k == part.Root {
cachedRoot, subCache, sourceInfo, actualPath := caches.find(part)
if sourceInfo == nil {
panic("impossible, k is part.root, but sourceInfo is nil")
}
if subCache != nil {
if sourceInfo.IsIntermediate {
if subCache.finished {
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
if hasErr {
return
}
shouldChangePart = true
}
} else {
if sourceInfo.FieldType == nodes.FieldIsStream {
currentV := v
for i := 0; i < len(actualPath)-1; i++ {
currentM, ok := currentV.(map[string]any)
if !ok {
panic("emit item not map[string]any")
}
currentV, ok = currentM[actualPath[i]]
if !ok {
continue chunkLoop
}
}
if len(actualPath) > 0 {
finalV, ok := currentV.(map[string]any)[actualPath[len(actualPath)-1]]
if !ok {
continue chunkLoop
}
currentV = finalV
}
vStr, ok := currentV.(string)
if !ok {
panic(fmt.Errorf("source %s is not intermediate, is stream, but value type not str: %T", k, v))
}
if strings.HasSuffix(vStr, nodes.KeyIsFinished) {
vStr = strings.TrimSuffix(vStr, nodes.KeyIsFinished)
}
var delta any
delta = vStr
for j := len(actualPath) - 1; j >= 0; j-- {
delta = map[string]any{
actualPath[j]: delta,
}
}
hasErr = renderAndSend(part, part.Root, delta, sw)
} else {
hasErr = renderAndSend(part, part.Root, v, sw)
}
if hasErr {
return
}
if subCache.finished {
shouldChangePart = true
}
}
}
}
}
if shouldChangePart {
continue partsLoop
}
}
}
})
return sr, nil
}
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
if err != nil {
return nil, err
}
output = map[string]any{
outputKey: s,
}
return output, nil
}
func renderAndSend(tp nodes.TemplatePart, k string, v any, sw *schema.StreamWriter[map[string]any]) bool /*hasError*/ {
m, err := sonic.Marshal(map[string]any{k: v})
if err != nil {
sw.Send(nil, err)
return true
}
r, err := tp.Render(m)
if err != nil {
sw.Send(nil, err)
return true
}
if len(r) == 0 { // won't send if formatted result is empty string
return false
}
logs.Infof("send: %v", r)
sw.Send(map[string]any{outputKey: r}, nil)
return false
}

View File

@@ -0,0 +1,62 @@
package entry
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type Config struct {
DefaultValues map[string]any
OutputTypes map[string]*vo.TypeInfo
}
type Entry struct {
cfg *Config
defaultValues map[string]any
}
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
}
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
return &Entry{
cfg: cfg,
defaultValues: defaultValues,
}, nil
}
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
for k, v := range e.defaultValues {
if val, ok := in[k]; ok {
tInfo := e.cfg.OutputTypes[k]
switch tInfo.Type {
case vo.DataTypeString:
if len(val.(string)) == 0 {
in[k] = v
}
case vo.DataTypeArray:
if len(val.([]any)) == 0 {
in[k] = v
}
case vo.DataTypeObject:
if len(val.(map[string]any)) == 0 {
in[k] = v
}
}
} else {
in[k] = v
}
}
return in, err
}

View File

@@ -0,0 +1,684 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httprequester
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"maps"
"mime/multipart"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
const defaultGetFileTimeout = 20 // second
const maxSize int64 = 20 * 1024 * 1024 // 20MB
const (
HeaderAuthorization = "Authorization"
HeaderBearerPrefix = "Bearer "
HeaderContentType = "Content-Type"
)
type AuthType uint
const (
BearToken AuthType = 1
Custom AuthType = 2
)
const (
ContentTypeJSON = "application/json"
ContentTypePlainText = "text/plain"
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
ContentTypeBinary = "application/octet-stream"
)
type Location uint8
const (
Header Location = 1
QueryParam Location = 2
)
type BodyType string
const (
BodyTypeNone BodyType = "EMPTY"
BodyTypeJSON BodyType = "JSON"
BodyTypeRawText BodyType = "RAW_TEXT"
BodyTypeFormData BodyType = "FORM_DATA"
BodyTypeFormURLEncoded BodyType = "FORM_URLENCODED"
BodyTypeBinary BodyType = "BINARY"
)
type URLConfig struct {
Tpl string `json:"tpl"`
}
type IgnoreExceptionSetting struct {
IgnoreException bool `json:"ignore_exception"`
DefaultOutput map[string]any `json:"default_output,omitempty"`
}
type BodyConfig struct {
BodyType BodyType `json:"body_type"`
FormDataConfig *FormDataConfig `json:"form_data_config,omitempty"`
TextPlainConfig *TextPlainConfig `json:"text_plain_config,omitempty"`
TextJsonConfig *TextJsonConfig `json:"text_json_config,omitempty"`
}
type FormDataConfig struct {
FileTypeMapping map[string]bool `json:"file_type_mapping"`
}
type TextPlainConfig struct {
Tpl string `json:"tpl"`
}
type TextJsonConfig struct {
Tpl string
}
type AuthenticationConfig struct {
Type AuthType `json:"type"`
Location Location `json:"location"`
}
type Authentication struct {
Key string
Value string
Token string
}
type Request struct {
URLVars map[string]any
Headers map[string]string
Params map[string]string
Authentication *Authentication
FormDataVars map[string]string
FormURLEncodedVars map[string]string
JsonVars map[string]any
TextPlainVars map[string]any
FileURL *string
}
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)
type MD5FieldMapping struct {
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
ParamMD5Mapping map[string]string `json:"param_md_5_mapping,omitempty"`
URLMD5Mapping map[string]string `json:"url_md_5_mapping,omitempty"`
BodyMD5Mapping map[string]string `json:"body_md_5_mapping,omitempty"`
}
func (fm *MD5FieldMapping) SetHeaderFields(fields ...string) {
if fm.HeaderMD5Mapping == nil && len(fields) > 0 {
fm.HeaderMD5Mapping = make(map[string]string)
}
for _, field := range fields {
fm.HeaderMD5Mapping[crypto.MD5HexValue(field)] = field
}
}
func (fm *MD5FieldMapping) SetParamFields(fields ...string) {
if fm.ParamMD5Mapping == nil && len(fields) > 0 {
fm.ParamMD5Mapping = make(map[string]string)
}
for _, field := range fields {
fm.ParamMD5Mapping[crypto.MD5HexValue(field)] = field
}
}
func (fm *MD5FieldMapping) SetURLFields(fields ...string) {
if fm.URLMD5Mapping == nil && len(fields) > 0 {
fm.URLMD5Mapping = make(map[string]string)
}
for _, field := range fields {
fm.URLMD5Mapping[crypto.MD5HexValue(field)] = field
}
}
func (fm *MD5FieldMapping) SetBodyFields(fields ...string) {
if fm.BodyMD5Mapping == nil && len(fields) > 0 {
fm.BodyMD5Mapping = make(map[string]string)
}
for _, field := range fields {
fm.BodyMD5Mapping[crypto.MD5HexValue(field)] = field
}
}
type Config struct {
URLConfig URLConfig
AuthConfig *AuthenticationConfig
BodyConfig BodyConfig
Method string
Timeout time.Duration
RetryTimes uint64
IgnoreException bool
DefaultOutput map[string]any
MD5FieldMapping
}
type HTTPRequester struct {
client *http.Client
config *Config
}
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
}
if len(cfg.Method) == 0 {
return nil, fmt.Errorf("method is requried")
}
hg := &HTTPRequester{}
client := http.DefaultClient
if cfg.Timeout > 0 {
client.Timeout = cfg.Timeout
}
hg.client = client
hg.config = cfg
return hg, nil
}
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
var (
req = &Request{}
method = hg.config.Method
retryTimes = hg.config.RetryTimes
body io.ReadCloser
contentType string
response *http.Response
)
req, err = hg.config.parserToRequest(input)
if err != nil {
return nil, err
}
httpRequest := &http.Request{
Method: method,
Header: http.Header{},
}
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
if err != nil {
return nil, err
}
for key, value := range req.Headers {
httpRequest.Header.Set(key, value)
}
u, err := url.Parse(httpURL)
if err != nil {
return nil, err
}
params := u.Query()
for key, value := range req.Params {
params.Set(key, value)
}
if hg.config.AuthConfig != nil {
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if err != nil {
return nil, err
}
}
u.RawQuery = params.Encode()
httpRequest.URL = u
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
if err != nil {
return nil, err
}
if body != nil {
httpRequest.Body = body
}
if contentType != "" {
httpRequest.Header.Add(HeaderContentType, contentType)
}
for i := uint64(0); i <= retryTimes; i++ {
response, err = hg.client.Do(httpRequest)
if err == nil {
break
}
}
if err != nil {
return nil, err
}
result := make(map[string]any)
headers := func() string {
// The structure of httpResp.Header is map[string][]string
// If there are multiple header values, the last one will be selected by default
hds := make(map[string]string, len(response.Header))
for key, values := range response.Header {
if len(values) == 0 {
hds[key] = ""
} else {
hds[key] = values[len(values)-1]
}
}
bs, _ := json.Marshal(hds)
return string(bs)
}()
result["headers"] = headers
var bodyBytes []byte
if response.Body != nil {
defer func() {
_ = response.Body.Close()
}()
bodyBytes, err = io.ReadAll(response.Body)
if err != nil {
return nil, err
}
}
if response.StatusCode >= http.StatusBadRequest {
return nil, fmt.Errorf("request %v failed, response status code=%d, status=%v, headers=%v, body=%v",
httpURL, response.StatusCode, response.Status, headers, string(bodyBytes))
}
result["body"] = string(bodyBytes)
result["statusCode"] = int64(response.StatusCode)
return result, nil
}
// decodeUnicode parses the Unicode escape sequence in the string
func decodeUnicode(s string) string {
var result strings.Builder
for i := 0; i < len(s); {
if i+1 < len(s) && s[i] == '\\' && s[i+1] == 'u' {
if i+6 <= len(s) {
hexStr := s[i+2 : i+6]
if code, err := strconv.ParseInt(hexStr, 16, 32); err == nil {
result.WriteRune(rune(code))
i += 6
continue
}
}
}
result.WriteByte(s[i])
i++
}
return result.String()
}
func (authCfg *AuthenticationConfig) addAuthentication(_ context.Context, auth *Authentication, header http.Header, params url.Values) (
http.Header, url.Values, error) {
if authCfg.Type == BearToken {
header.Set(HeaderAuthorization, HeaderBearerPrefix+auth.Token)
return header, params, nil
}
if authCfg.Type == Custom && authCfg.Location == Header {
header.Set(auth.Key, auth.Value)
return header, params, nil
}
if authCfg.Type == Custom && authCfg.Location == QueryParam {
params.Set(auth.Key, auth.Value)
return header, params, nil
}
return header, params, nil
}
func (b *BodyConfig) getBodyAndContentType(ctx context.Context, req *Request) (io.ReadCloser, string, error) {
var (
body io.Reader
contentType string
)
// body none return body nil
if b.BodyType == BodyTypeNone {
return nil, "", nil
}
switch b.BodyType {
case BodyTypeJSON:
jsonString, err := nodes.TemplateRender(b.TextJsonConfig.Tpl, req.JsonVars)
if err != nil {
return nil, contentType, err
}
body = strings.NewReader(jsonString)
contentType = ContentTypeJSON
case BodyTypeFormURLEncoded:
form := url.Values{}
for key, value := range req.FormURLEncodedVars {
form.Add(key, value)
}
body = strings.NewReader(form.Encode())
contentType = ContentTypeFormURLEncoded
case BodyTypeRawText:
textString, err := nodes.TemplateRender(b.TextPlainConfig.Tpl, req.TextPlainVars)
if err != nil {
return nil, contentType, err
}
body = strings.NewReader(textString)
contentType = ContentTypePlainText
case BodyTypeBinary:
if req.FileURL == nil {
return nil, contentType, fmt.Errorf("file url is required")
}
fileURL := *req.FileURL
response, err := httpGet(ctx, fileURL)
if err != nil {
return nil, contentType, err
}
body = response.Body
contentType = ContentTypeBinary
case BodyTypeFormData:
var buffer = &bytes.Buffer{}
formDataConfig := b.FormDataConfig
writer := multipart.NewWriter(buffer)
total := int64(0)
for key, value := range req.FormDataVars {
if ok := formDataConfig.FileTypeMapping[key]; ok {
fileWrite, err := writer.CreateFormFile(key, key)
if err != nil {
return nil, contentType, err
}
response, err := httpGet(ctx, value)
if err != nil {
return nil, contentType, err
}
if response.StatusCode != http.StatusOK {
return nil, contentType, fmt.Errorf("failed to download file: %s, status code %v", value, response.StatusCode)
}
size, err := io.Copy(fileWrite, response.Body)
if err != nil {
return nil, contentType, err
}
total += size
if total > maxSize {
return nil, contentType, fmt.Errorf("too large body, total size: %d", total)
}
} else {
err := writer.WriteField(key, value)
if err != nil {
return nil, contentType, err
}
}
}
_ = writer.Close()
contentType = writer.FormDataContentType()
body = buffer
default:
return nil, contentType, fmt.Errorf("unknown content type %s", b.BodyType)
}
if _, ok := body.(io.ReadCloser); ok {
return body.(io.ReadCloser), contentType, nil
}
return io.NopCloser(body), contentType, nil
}
func httpGet(ctx context.Context, url string) (*http.Response, error) {
request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
http.DefaultClient.Timeout = time.Second * defaultGetFileTimeout
return http.DefaultClient.Do(request)
}
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
var (
request = &Request{}
config = hg.config
)
request, err := hg.config.parserToRequest(input)
if err != nil {
return nil, err
}
result := make(map[string]any)
result["method"] = config.Method
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
if err != nil {
return nil, err
}
result["url"] = u
params := make(map[string]any, len(request.Params))
for k, v := range request.Params {
params[k] = v
}
result["param"] = params
headers := make(map[string]any, len(request.Headers))
for k, v := range request.Headers {
headers[k] = v
}
result["header"] = headers
result["auth"] = nil
if config.AuthConfig != nil {
if config.AuthConfig.Type == Custom {
result["auth"] = map[string]interface{}{
"Key": request.Authentication.Key,
"Value": request.Authentication.Value,
}
} else if config.AuthConfig.Type == BearToken {
result["auth"] = map[string]interface{}{
"token": request.Authentication.Token,
}
}
}
result["body"] = nil
switch config.BodyConfig.BodyType {
case BodyTypeJSON:
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
if err != nil {
return nil, err
}
ret := make(map[string]any)
err = sonic.Unmarshal([]byte(js), &ret)
if err != nil {
return nil, err
}
result["body"] = ret
case BodyTypeRawText:
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
if err != nil {
return nil, err
}
result["body"] = tx
case BodyTypeFormData:
result["body"] = request.FormDataVars
case BodyTypeFormURLEncoded:
result["body"] = request.FormURLEncodedVars
case BodyTypeBinary:
result["body"] = request.FileURL
}
return result, nil
}
const (
apiInfoURLPrefix = "__apiInfo_url_"
headersPrefix = "__headers_"
paramsPrefix = "__params_"
authDataPrefix = "__auth_authData_"
authBearerTokenDataPrefix = "bearerTokenData_token"
authCustomDataPrefix = "customData_data"
bodyDataPrefix = "__body_bodyData_"
bodyJsonPrefix = "json_"
bodyFormDataPrefix = "formData_"
bodyFormURLEncodedPrefix = "formURLEncoded_"
bodyRawTextPrefix = "rawText_"
bodyBinaryFileURLPrefix = "binary_fileURL"
)
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
request := &Request{
URLVars: make(map[string]any),
Headers: make(map[string]string),
Params: make(map[string]string),
Authentication: &Authentication{},
FormURLEncodedVars: make(map[string]string),
JsonVars: make(map[string]any),
TextPlainVars: make(map[string]any),
FormDataVars: map[string]string{},
}
for key, value := range input {
if strings.HasPrefix(key, apiInfoURLPrefix) {
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
if strings.HasPrefix(urlKey, "global_variable_") {
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
}
nodes.SetMapValue(request.URLVars, strings.Split(urlKey, "."), value.(string))
}
}
if strings.HasPrefix(key, headersPrefix) {
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
request.Headers[headerKey] = value.(string)
}
}
if strings.HasPrefix(key, paramsPrefix) {
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
request.Params[paramKey] = value.(string)
}
}
if strings.HasPrefix(key, authDataPrefix) {
authKey := strings.TrimPrefix(key, authDataPrefix)
if strings.HasPrefix(authKey, authBearerTokenDataPrefix) {
request.Authentication.Token = value.(string) // bear
}
if strings.HasPrefix(authKey, authCustomDataPrefix) {
if key == "__auth_authData_customData_data_Key" {
request.Authentication.Key = value.(string)
}
if key == "__auth_authData_customData_data_Value" {
request.Authentication.Value = value.(string)
}
}
}
if strings.HasPrefix(key, bodyDataPrefix) {
bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
if strings.HasPrefix(jsonKey, "global_variable_") {
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
}
nodes.SetMapValue(request.JsonVars, strings.Split(jsonKey, "."), value)
}
}
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
request.FormDataVars[formDataKey] = value.(string)
}
}
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
}
}
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
if strings.HasPrefix(rawTextKey, "global_variable_") {
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
}
nodes.SetMapValue(request.TextPlainVars, strings.Split(rawTextKey, "."), value)
}
}
if strings.HasPrefix(bodyKey, bodyBinaryFileURLPrefix) {
request.FileURL = ptr.Of(value.(string))
}
}
}
return request, nil
}
func (hg *HTTPRequester) ToCallbackOutput(_ context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
if body, ok := out["body"]; !ok {
return &nodes.StructuredCallbackOutput{
RawOutput: out,
Output: out,
}, nil
} else {
output := maps.Clone(out)
output["body"] = decodeUnicode(body.(string))
return &nodes.StructuredCallbackOutput{
RawOutput: out,
Output: output,
}, nil
}
}

View File

@@ -0,0 +1,396 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httprequester
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
func TestInvoke(t *testing.T) {
t.Run("get method", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
response := map[string]string{
"message": "success",
}
bs, _ := json.Marshal(response)
_, _ = w.Write(bs)
}))
defer ts.Close()
urlTpl := ts.URL + "/{{url_v1}}"
cfg := &Config{
URLConfig: URLConfig{
Tpl: urlTpl,
},
BodyConfig: BodyConfig{
BodyType: BodyTypeNone,
},
Method: http.MethodGet,
RetryTimes: 1,
Timeout: 2 * time.Second,
MD5FieldMapping: MD5FieldMapping{
URLMD5Mapping: map[string]string{
crypto.MD5HexValue("url_v1"): "url_v1",
},
HeaderMD5Mapping: map[string]string{
crypto.MD5HexValue("h1"): "h1",
crypto.MD5HexValue("h2"): "h2",
},
ParamMD5Mapping: map[string]string{
crypto.MD5HexValue("p1"): "p1",
crypto.MD5HexValue("p2"): "p2",
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
"__headers_" + crypto.MD5HexValue("h1"): "1",
"__headers_" + crypto.MD5HexValue("h2"): "2",
"__params_" + crypto.MD5HexValue("p1"): "v1",
"__params_" + crypto.MD5HexValue("p2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
})
t.Run("post method multipart/form-data", func(t *testing.T) {
fileServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fileContent := "fileV1"
_, _ = w.Write([]byte(fileContent))
}))
defer fileServer.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20) // 10 MB
if err != nil {
return
}
f1 := r.MultipartForm.Value["f1"][0]
assert.Equal(t, "fv1", f1)
f2 := r.MultipartForm.Value["f2"][0]
assert.Equal(t, "fv2", f2)
file, _, err := r.FormFile("fileURL")
if err != nil {
t.Error(err)
}
fileBs, err := io.ReadAll(file)
if err != nil {
t.Error(err)
}
assert.Equal(t, "fileV1", string(fileBs))
w.WriteHeader(http.StatusOK)
response := map[string]string{
"message": "success",
}
bs, _ := json.Marshal(response)
_, _ = w.Write(bs)
}))
defer ts.Close()
urlTpl := ts.URL + "/{{post_v1}}"
cfg := &Config{
URLConfig: URLConfig{
Tpl: urlTpl,
},
BodyConfig: BodyConfig{
BodyType: BodyTypeFormData,
FormDataConfig: &FormDataConfig{
map[string]bool{
"fileURL": true,
},
},
},
Method: http.MethodPost,
RetryTimes: 1,
Timeout: 2 * time.Second,
MD5FieldMapping: MD5FieldMapping{
URLMD5Mapping: map[string]string{
crypto.MD5HexValue("post_v1"): "post_v1",
},
HeaderMD5Mapping: map[string]string{
crypto.MD5HexValue("h1"): "h1",
crypto.MD5HexValue("h2"): "h2",
},
ParamMD5Mapping: map[string]string{
crypto.MD5HexValue("p1"): "p1",
crypto.MD5HexValue("p2"): "p2",
},
BodyMD5Mapping: map[string]string{
crypto.MD5HexValue("f1"): "f1",
crypto.MD5HexValue("f2"): "f2",
crypto.MD5HexValue("fileURL"): "fileURL",
},
},
}
// 创建 HTTPRequest 实例
hg, err := NewHTTPRequester(context.Background(), cfg)
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("post_v1"): "post_v1",
"__headers_" + crypto.MD5HexValue("h1"): "1",
"__headers_" + crypto.MD5HexValue("h2"): "2",
"__params_" + crypto.MD5HexValue("p1"): "v1",
"__params_" + crypto.MD5HexValue("p2"): "v2",
"__body_bodyData_formData_" + crypto.MD5HexValue("f1"): "fv1",
"__body_bodyData_formData_" + crypto.MD5HexValue("f2"): "fv2",
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
})
t.Run("post method text/plain", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
return
}
defer func() {
_ = r.Body.Close()
}()
assert.Equal(t, "text v1 v2", string(body))
w.WriteHeader(http.StatusOK)
response := map[string]string{
"message": "success",
}
bs, _ := json.Marshal(response)
_, _ = w.Write(bs)
}))
defer ts.Close()
urlTpl := ts.URL + "/{{post_text_plain}}"
cfg := &Config{
URLConfig: URLConfig{
Tpl: urlTpl,
},
BodyConfig: BodyConfig{
BodyType: BodyTypeRawText,
TextPlainConfig: &TextPlainConfig{
Tpl: "text {{v1}} {{v2}}",
},
},
Method: http.MethodPost,
RetryTimes: 1,
Timeout: 2 * time.Second,
MD5FieldMapping: MD5FieldMapping{
URLMD5Mapping: map[string]string{
crypto.MD5HexValue("post_text_plain"): "post_text_plain",
},
HeaderMD5Mapping: map[string]string{
crypto.MD5HexValue("h1"): "h1",
crypto.MD5HexValue("h2"): "h2",
},
ParamMD5Mapping: map[string]string{
crypto.MD5HexValue("p1"): "p1",
crypto.MD5HexValue("p2"): "p2",
},
BodyMD5Mapping: map[string]string{
crypto.MD5HexValue("v1"): "v1",
crypto.MD5HexValue("v2"): "v2",
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("post_text_plain"): "post_text_plain",
"__headers_" + crypto.MD5HexValue("h1"): "1",
"__headers_" + crypto.MD5HexValue("h2"): "2",
"__params_" + crypto.MD5HexValue("p1"): "v1",
"__params_" + crypto.MD5HexValue("p2"): "v2",
"__body_bodyData_rawText_" + crypto.MD5HexValue("v1"): "v1",
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
})
t.Run("post method application/json", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
return
}
defer func() {
_ = r.Body.Close()
}()
assert.Equal(t, `{"v1":v1,"v2":v2}`, string(body))
w.WriteHeader(http.StatusOK)
response := map[string]string{
"message": "success",
}
bs, _ := json.Marshal(response)
_, _ = w.Write(bs)
}))
defer ts.Close()
urlTpl := ts.URL + "/{{application_json}}"
cfg := &Config{
URLConfig: URLConfig{
Tpl: urlTpl,
},
BodyConfig: BodyConfig{
BodyType: BodyTypeJSON,
TextJsonConfig: &TextJsonConfig{
Tpl: `{"v1":{{v1}},"v2":{{v2}}}`,
},
},
Method: http.MethodPost,
RetryTimes: 1,
Timeout: 2 * time.Second,
MD5FieldMapping: MD5FieldMapping{
URLMD5Mapping: map[string]string{
crypto.MD5HexValue("application_json"): "application_json",
},
HeaderMD5Mapping: map[string]string{
crypto.MD5HexValue("h1"): "h1",
crypto.MD5HexValue("h2"): "h2",
},
ParamMD5Mapping: map[string]string{
crypto.MD5HexValue("p1"): "p1",
crypto.MD5HexValue("p2"): "p2",
},
BodyMD5Mapping: map[string]string{
crypto.MD5HexValue("v1"): "v1",
crypto.MD5HexValue("v2"): "v2",
},
},
}
// 创建 HTTPRequest 实例
hg, err := NewHTTPRequester(context.Background(), cfg)
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("application_json"): "application_json",
"__headers_" + crypto.MD5HexValue("h1"): "1",
"__headers_" + crypto.MD5HexValue("h2"): "2",
"__params_" + crypto.MD5HexValue("p1"): "v1",
"__params_" + crypto.MD5HexValue("p2"): "v2",
"__body_bodyData_json_" + crypto.MD5HexValue("v1"): "v1",
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
})
t.Run("post method application/octet-stream", func(t *testing.T) {
fileServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fileContent := strings.Repeat("fileV1", 100)
_, _ = w.Write([]byte(fileContent))
}))
defer fileServer.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
return
}
defer func() {
_ = r.Body.Close()
}()
fileContent := strings.Repeat("fileV1", 100)
assert.Equal(t, fileContent, string(body))
w.WriteHeader(http.StatusOK)
response := map[string]string{
"message": "success",
}
bs, _ := json.Marshal(response)
_, _ = w.Write(bs)
}))
defer ts.Close()
urlTpl := ts.URL + "/{{binary}}"
cfg := &Config{
URLConfig: URLConfig{
Tpl: urlTpl,
},
BodyConfig: BodyConfig{
BodyType: BodyTypeBinary,
},
Method: http.MethodPost,
RetryTimes: 1,
Timeout: 2 * time.Second,
MD5FieldMapping: MD5FieldMapping{
URLMD5Mapping: map[string]string{
crypto.MD5HexValue("binary"): "binary",
},
HeaderMD5Mapping: map[string]string{
crypto.MD5HexValue("h1"): "h1",
crypto.MD5HexValue("h2"): "h2",
},
ParamMD5Mapping: map[string]string{
crypto.MD5HexValue("p1"): "p1",
crypto.MD5HexValue("p2"): "p2",
},
},
}
// 创建 HTTPRequest 实例
hg, err := NewHTTPRequester(context.Background(), cfg)
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("application_json"): "application_json",
"__headers_" + crypto.MD5HexValue("h1"): "1",
"__headers_" + crypto.MD5HexValue("h2"): "2",
"__params_" + crypto.MD5HexValue("p1"): "v1",
"__params_" + crypto.MD5HexValue("p2"): "v2",
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
})
}

View File

@@ -0,0 +1,212 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package intentdetector
import (
"context"
"encoding/json"
"errors"
"strconv"
"strings"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
type Config struct {
Intents []string
SystemPrompt string
IsFastMode bool
ChatModel model.BaseChatModel
}
const SystemIntentPrompt = `
# Role
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.
## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
[
{"classificationId": 0, "content": "Other intentions"},
{{intents}}
]
Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.
{{advance}}
## Reply requirements
- The answer must be returned in JSON format.
- Strictly ensure that the output is in a valid JSON format.
- Do not add prefix "json or suffix""
- The answer needs to include the following fields such as:
{
"classificationId": 0,
"reason": "Unclear intentions"
}
##Limit
- Please do not reply in text.
`
const FastModeSystemIntentPrompt = `
# Role
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.
## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
[
{"classificationId": 0, "content": "Other intentions"},
{{intents}}
]
Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.
## Reply requirements
- The answer must be a number indicated classificationId.
- if not match, please just output an number 0.
- do not output json format data, just output an number.
##Limit
- Please do not reply in text.`
type IntentDetector struct {
config *Config
runner compose.Runnable[map[string]any, *schema.Message]
}
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if !cfg.IsFastMode && cfg.ChatModel == nil {
return nil, errors.New("config chat model is required")
}
if len(cfg.Intents) == 0 {
return nil, errors.New("config intents is required")
}
chain := compose.NewChain[map[string]any, *schema.Message]()
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": toIntentString(cfg.Intents),
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
config: cfg,
runner: r,
}, nil
}
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
nodeOutput := make(map[string]any)
nodeOutput["classificationId"] = 0
if content == "" {
return nodeOutput, errors.New("content is empty")
}
if id.config.IsFastMode {
cid, err := strconv.ParseInt(content, 10, 64)
if err != nil {
return nodeOutput, err
}
nodeOutput["classificationId"] = cid
return nodeOutput, nil
}
leftIndex := strings.Index(content, "{")
rightIndex := strings.Index(content, "}")
if leftIndex == -1 || rightIndex == -1 {
return nodeOutput, errors.New("content is invalid")
}
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
if err != nil {
return nodeOutput, err
}
return nodeOutput, nil
}
func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
query, ok := input["query"]
if !ok {
return nil, errors.New("input query field required")
}
queryStr, ok := query.(string)
if !ok {
queryStr = cast.ToString(query)
}
vars := make(map[string]any)
vars["query"] = queryStr
if !id.config.IsFastMode {
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
if err != nil {
return nil, err
}
vars["advance"] = ad
}
o, err := id.runner.Invoke(ctx, vars)
if err != nil {
return nil, err
}
return id.parseToNodeOut(o.Content)
}
func toIntentString(its []string) string {
type IntentVariableItem struct {
ClassificationID int64 `json:"classificationId"`
Content string `json:"content"`
}
vs := make([]*IntentVariableItem, 0, len(its))
for idx, it := range its {
vs = append(vs, &IntentVariableItem{
ClassificationID: int64(idx + 1),
Content: it,
})
}
itsBytes, _ := json.Marshal(vs)
return string(itsBytes)
}

View File

@@ -0,0 +1,88 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package intentdetector
import (
"context"
"fmt"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
)
type mockChatModel struct {
topSeed bool
}
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
if m.topSeed {
return &schema.Message{
Content: "1",
}, nil
}
return &schema.Message{
Content: `{"classificationId":1,"reason":"高兴"}`,
}, nil
}
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return nil, nil
}
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
return nil
}
func TestNewIntentDetector(t *testing.T) {
ctx := context.Background()
t.Run("fast mode", func(t *testing.T) {
dt, err := NewIntentDetector(ctx, &Config{
Intents: []string{"高兴", "悲伤"},
IsFastMode: true,
ChatModel: &mockChatModel{topSeed: true},
})
assert.Nil(t, err)
ret, err := dt.Invoke(ctx, map[string]any{
"query": "我考了100分",
})
assert.Nil(t, err)
assert.Equal(t, ret["classificationId"], int64(1))
})
t.Run("full mode", func(t *testing.T) {
dt, err := NewIntentDetector(ctx, &Config{
Intents: []string{"高兴", "悲伤"},
IsFastMode: false,
ChatModel: &mockChatModel{},
})
assert.Nil(t, err)
ret, err := dt.Invoke(ctx, map[string]any{
"query": "我考了100分",
})
fmt.Println(err)
assert.Nil(t, err)
fmt.Println(ret)
assert.Equal(t, ret["classificationId"], float64(1))
assert.Equal(t, ret["reason"], "高兴")
})
}

View File

@@ -0,0 +1,28 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type InterruptEventStore interface {
GetInterruptEvent(nodeKey vo.NodeKey) (*entity.InterruptEvent, bool, error)
SetInterruptEvent(nodeKey vo.NodeKey, value *entity.InterruptEvent) error
DeleteInterruptEvent(nodeKey vo.NodeKey) error
}

View File

@@ -0,0 +1,121 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
const (
InputKeyDeserialization = "input"
OutputKeyDeserialization = "output"
warningsKey = "deserialization_warnings"
)
type DeserializationConfig struct {
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
}
type Deserializer struct {
config *DeserializationConfig
typeInfo *vo.TypeInfo
}
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.OutputFields == nil {
return nil, fmt.Errorf("OutputFields is required for deserialization")
}
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
if typeInfo == nil {
return nil, fmt.Errorf("no output field specified in deserialization config")
}
return &Deserializer{
config: cfg,
typeInfo: typeInfo,
}, nil
}
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
jsonStrValue := input[InputKeyDeserialization]
jsonStr, ok := jsonStrValue.(string)
if !ok {
return nil, fmt.Errorf("input is not a string, got %T", jsonStrValue)
}
typeInfo := jd.typeInfo
var rawValue any
var err error
// Unmarshal based on the root type
switch typeInfo.Type {
case vo.DataTypeString, vo.DataTypeInteger, vo.DataTypeNumber, vo.DataTypeBoolean, vo.DataTypeTime, vo.DataTypeFile:
// Scalar types - unmarshal to generic any
err = sonic.Unmarshal([]byte(jsonStr), &rawValue)
case vo.DataTypeArray:
// Array type - unmarshal to []any
var arr []any
err = sonic.Unmarshal([]byte(jsonStr), &arr)
rawValue = arr
case vo.DataTypeObject:
// Object type - unmarshal to map[string]any
var obj map[string]any
err = sonic.Unmarshal([]byte(jsonStr), &obj)
rawValue = obj
default:
return nil, fmt.Errorf("unsupported root data type: %s", typeInfo.Type)
}
if err != nil {
return nil, fmt.Errorf("JSON unmarshaling failed: %w", err)
}
convertedValue, ws, err := nodes.Convert(ctx, rawValue, OutputKeyDeserialization, typeInfo)
if err != nil {
return nil, err
}
if ws != nil && len(*ws) > 0 {
ctxcache.Store(ctx, warningsKey, *ws)
}
return map[string]any{OutputKeyDeserialization: convertedValue}, nil
}
func (jd *Deserializer) ToCallbackOutput(ctx context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
var wfe vo.WorkflowError
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, warningsKey); ok {
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
}
return &nodes.StructuredCallbackOutput{
Output: out,
RawOutput: out,
Error: wfe,
}, nil
}

View File

@@ -0,0 +1,360 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
func TestNewJsonDeserializer(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonDeserializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing OutputFields config
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "OutputFields is required")
// Test with missing output key in OutputFields
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"testKey": {Type: vo.DataTypeString},
},
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "no output field specified in deserialization config")
// Test with valid config
validConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString},
},
}
processor, err := NewJsonDeserializer(ctx, validConfig)
assert.NoError(t, err)
assert.NotNil(t, processor)
}
func TestJsonDeserializer_Invoke(t *testing.T) {
ctx := context.Background()
// Base type test config
baseTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeString},
},
}
// Object type test config
objectTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"name": {Type: vo.DataTypeString, Required: true},
"age": {Type: vo.DataTypeInteger},
},
},
},
}
// Array type test config
arrayTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
},
}
// Nested array object test config
nestedArrayConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"id": {Type: vo.DataTypeInteger},
"name": {Type: vo.DataTypeString},
},
},
},
},
}
// Test cases
tests := []struct {
name string
config *DeserializationConfig
inputJSON string
expectedOutput any
expectErr bool
expectWarnings int
}{{
name: "Test string deserialization",
config: baseTypeConfig,
inputJSON: `"test string"`,
expectedOutput: "test string",
expectErr: false,
expectWarnings: 0,
}, {
name: "Test integer deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `123`,
expectedOutput: 123,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test boolean deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
},
},
inputJSON: `true`,
expectedOutput: true,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test object deserialization",
config: objectTypeConfig,
inputJSON: `{"name":"test","age":20}`,
expectedOutput: map[string]any{"name": "test", "age": 20.0},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test array deserialization",
config: arrayTypeConfig,
inputJSON: `[1,2,3]`,
expectedOutput: []any{1.0, 2.0, 3.0},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test nested array object deserialization",
config: nestedArrayConfig,
inputJSON: `[{"id":1,"name":"a"},{"id":2,"name":"b"}]`,
expectedOutput: []any{
map[string]any{"id": 1.0, "name": "a"},
map[string]any{"id": 2.0, "name": "b"},
},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test invalid JSON format",
config: baseTypeConfig,
inputJSON: `{invalid json}`,
expectedOutput: nil,
expectErr: true,
expectWarnings: 0,
}, {
name: "Test type mismatch warning",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `"not a number"`,
expectedOutput: nil,
expectErr: false,
expectWarnings: 1,
}, {
name: "Test null JSON input",
config: baseTypeConfig,
inputJSON: `null`,
expectedOutput: nil,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123"`,
expectedOutput: 123,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test float to integer conversion (integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.0`,
expectedOutput: 123,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test float to integer conversion (non-integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.5`,
expectedOutput: 123,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test boolean to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `true`,
expectedOutput: nil,
expectErr: false,
expectWarnings: 1,
}, {
name: "Test string to boolean conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
},
},
inputJSON: `"true"`,
expectedOutput: true,
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion in nested object",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"age": {Type: vo.DataTypeInteger},
},
},
},
},
inputJSON: `{"age":"456"}`,
expectedOutput: map[string]any{"age": 456},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion for array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
},
},
inputJSON: `["1", "2", "3"]`,
expectedOutput: []any{1, 2, 3},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string with non-numeric characters to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123abc"`,
expectedOutput: nil,
expectErr: false,
expectWarnings: 1,
}, {
name: "Test type mismatch in nested object field",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"score": {Type: vo.DataTypeInteger},
},
},
},
},
inputJSON: `{"score":"invalid"}`,
expectedOutput: map[string]any{"score": nil},
expectErr: false,
expectWarnings: 1,
}, {
name: "Test partial conversion failure in array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
},
},
inputJSON: `["1", "two", 3]`,
expectedOutput: []any{1, 3},
expectErr: false,
expectWarnings: 1,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
processor, err := NewJsonDeserializer(ctx, tt.config)
assert.NoError(t, err)
ctxWithCache := ctxcache.Init(ctx)
input := map[string]any{"input": tt.inputJSON}
result, err := processor.Invoke(ctxWithCache, input)
if tt.expectErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Contains(t, result, OutputKeyDeserialization)
// Verify the output
output := result[OutputKeyDeserialization]
if tt.expectedOutput == nil {
assert.Nil(t, output)
} else {
// Serialize expected and actual output to JSON for comparison, ignoring type differences (e.g., float64 vs. int)
actualJSON, _ := sonic.Marshal(output)
expectedJSON, _ := sonic.Marshal(tt.expectedOutput)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
}
// Verify the number of warnings
warnings, _ := ctxcache.Get[nodes.ConversionWarnings](ctxWithCache, warningsKey)
assert.Equal(t, tt.expectWarnings, len(warnings))
})
}
}

View File

@@ -0,0 +1,65 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
const (
InputKeySerialization = "input"
OutputKeySerialization = "output"
)
type SerializationConfig struct {
InputTypes map[string]*vo.TypeInfo
}
type JsonSerializer struct {
config *SerializationConfig
}
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.InputTypes == nil {
return nil, fmt.Errorf("InputTypes is required for serialization")
}
return &JsonSerializer{
config: cfg,
}, nil
}
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Directly use the input map for serialization
if input == nil {
return nil, fmt.Errorf("input data for serialization cannot be nil")
}
originData := input[InputKeySerialization]
serializedData, err := sonic.Marshal(originData) // Serialize the entire input map
if err != nil {
return nil, fmt.Errorf("serialization error: %w", err)
}
return map[string]any{OutputKeySerialization: string(serializedData)}, nil
}

View File

@@ -0,0 +1,134 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
func TestNewJsonSerialize(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonSerializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing InputTypes config
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "InputTypes is required")
// Test with valid config
validConfig := &SerializationConfig{
InputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: "string"},
},
}
processor, err := NewJsonSerializer(ctx, validConfig)
assert.NoError(t, err)
assert.NotNil(t, processor)
}
func TestJsonSerialize_Invoke(t *testing.T) {
ctx := context.Background()
config := &SerializationConfig{
InputTypes: map[string]*vo.TypeInfo{
"stringKey": {Type: "string"},
"intKey": {Type: "integer"},
"boolKey": {Type: "boolean"},
"objKey": {Type: "object"},
},
}
processor, err := NewJsonSerializer(ctx, config)
assert.NoError(t, err)
// Test cases
tests := []struct {
name string
input map[string]any
expected string
expectErr bool
}{{
name: "Test string serialization",
input: map[string]any{
"input": "test",
},
expected: `"test"`,
expectErr: false,
}, {
name: "Test integer serialization",
input: map[string]any{
"input": 123,
},
expected: `123`,
expectErr: false,
}, {
name: "Test boolean serialization",
input: map[string]any{
"input": true,
},
expected: `true`,
expectErr: false,
}, {
name: "Test object serialization",
input: map[string]any{
"input": map[string]any{
"nestedKey": "nestedValue",
},
},
expected: `{"nestedKey":"nestedValue"}`,
expectErr: false,
}, {
name: "Test nil input",
input: nil,
expected: "",
expectErr: true,
}, {
name: "Test special character handling",
input: map[string]any{
"input": "\"test\"\nwith\twhitespace",
},
expected: `"\"test\"\nwith\twhitespace"`,
expectErr: false,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processor.Invoke(ctx, tt.input)
if tt.expectErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Contains(t, result, OutputKeySerialization)
jsonStr, ok := result[OutputKeySerialization].(string)
assert.True(t, ok, "The output should be of type string")
assert.JSONEq(t, tt.expected, jsonStr)
})
}
}

View File

@@ -0,0 +1,63 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"errors"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
)
type DeleterConfig struct {
KnowledgeID int64
KnowledgeDeleter knowledge.KnowledgeOperator
}
type KnowledgeDeleter struct {
config *DeleterConfig
}
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
if cfg.KnowledgeDeleter == nil {
return nil, errors.New("knowledge deleter is required")
}
return &KnowledgeDeleter{
config: cfg,
}, nil
}
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
documentID, ok := input["documentID"].(string)
if !ok {
return nil, errors.New("documentID is required and must be a string")
}
req := &knowledge.DeleteDocumentRequest{
DocumentID: documentID,
}
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
if err != nil {
return nil, err
}
result := make(map[string]any)
result["isSuccess"] = response.IsSuccess
return result, nil
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"errors"
"fmt"
"net/url"
"path/filepath"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
type IndexerConfig struct {
KnowledgeID int64
ParsingStrategy *knowledge.ParsingStrategy
ChunkingStrategy *knowledge.ChunkingStrategy
KnowledgeIndexer knowledge.KnowledgeOperator
}
type KnowledgeIndexer struct {
config *IndexerConfig
}
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
if cfg.ParsingStrategy == nil {
return nil, errors.New("parsing strategy is required")
}
if cfg.ChunkingStrategy == nil {
return nil, errors.New("chunking strategy is required")
}
if cfg.KnowledgeIndexer == nil {
return nil, errors.New("knowledge indexer is required")
}
return &KnowledgeIndexer{
config: cfg,
}, nil
}
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
fileURL, ok := input["knowledge"].(string)
if !ok {
return nil, errors.New("knowledge is required")
}
fileName, ext, err := parseToFileNameAndFileExtension(fileURL)
if err != nil {
return nil, err
}
req := &knowledge.CreateDocumentRequest{
KnowledgeID: k.config.KnowledgeID,
ParsingStrategy: k.config.ParsingStrategy,
ChunkingStrategy: k.config.ChunkingStrategy,
FileURL: fileURL,
FileName: fileName,
FileExtension: ext,
}
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
if err != nil {
return nil, err
}
result := make(map[string]any)
result["documentId"] = response.DocumentID
result["fileName"] = response.FileName
result["fileUrl"] = response.FileURL
return result, nil
}
func parseToFileNameAndFileExtension(fileURL string) (string, parser.FileExtension, error) {
u, err := url.Parse(fileURL)
if err != nil {
return "", "", err
}
fileName := u.Query().Get("x-wf-file_name")
if len(fileName) == 0 {
return "", "", errors.New("file name is required")
}
fileExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(fileName), "."))
ext, support := parser.ValidateFileExtension(fileExt)
if !support {
return "", "", fmt.Errorf("unsupported file type: %s", fileExt)
}
return fileName, ext, nil
}

View File

@@ -0,0 +1,87 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"context"
"errors"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
const outputList = "outputList"
type RetrieveConfig struct {
KnowledgeIDs []int64
RetrievalStrategy *knowledge.RetrievalStrategy
Retriever knowledge.KnowledgeOperator
}
type KnowledgeRetrieve struct {
config *RetrieveConfig
}
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if cfg.Retriever == nil {
return nil, errors.New("retriever is required")
}
if len(cfg.KnowledgeIDs) == 0 {
return nil, errors.New("knowledgeI ids is required")
}
if cfg.RetrievalStrategy == nil {
return nil, errors.New("retrieval strategy is required")
}
return &KnowledgeRetrieve{
config: cfg,
}, nil
}
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
query, ok := input["Query"].(string)
if !ok {
return nil, errors.New("capital query key is required")
}
req := &knowledge.RetrieveRequest{
Query: query,
KnowledgeIDs: kr.config.KnowledgeIDs,
RetrievalStrategy: kr.config.RetrievalStrategy,
}
response, err := kr.config.Retriever.Retrieve(ctx, req)
if err != nil {
return nil, err
}
result := make(map[string]any)
result[outputList] = slices.Transform(response.Slices, func(m *knowledge.Slice) any {
return map[string]any{
"documentId": m.DocumentID,
"output": m.Output,
}
})
return result, nil
}

View File

@@ -0,0 +1,847 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package llm
import (
"context"
"errors"
"fmt"
"io"
"strconv"
"strings"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/cloudwego/eino/schema"
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Format int
const (
FormatText Format = iota
FormatMarkdown
FormatJSON
)
const (
jsonPromptFormat = `
Strictly reply in valid JSON format.
- Ensure the output strictly conforms to the JSON schema below
- Do not include explanations, comments, or any text outside the JSON.
Here is the output JSON schema:
'''
%s
'''
`
markdownPrompt = `
Strictly reply in valid Markdown format.
- For headings, use number signs (#).
- For list items, start with dashes (-).
- To emphasize text, wrap it with asterisks (*).
- For code or commands, surround them with backticks (` + "`" + `).
- For quoted text, use greater than signs (>).
- For links, wrap the text in square brackets [], followed by the URL in parentheses ().
- For images, use square brackets [] for the alt text, followed by the image URL in parentheses ().
`
)
const (
ReasoningOutputKey = "reasoning_content"
)
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)" 。
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
例如:
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:![一只小猫](https://example.com/image.jpg)。
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为:![一只小猫](https://example.com/image1.jpg) 和 ![一只小狗](https://example.com/image2.jpg) 和 ![一只小牛](https://example.com/image3.jpg)
you can refer to the following content and do relevant searches to improve:
---
%s
question is:
`
const knowledgeIntentPrompt = `
# 角色:
你是一个知识库意图识别AI Agent。
## 目标:
- 按照「系统提示词」、用户需求、最新的聊天记录选择应该使用的知识库。
## 工作流程:
1. 分析「系统提示词」以确定用户的具体需求。
2. 如果「系统提示词」明确指明了要使用的知识库则直接返回这些知识库只输出它们的knowledge_id不需要再判断用户的输入
3. 检查每个知识库的knowledge_name和knowledge_description以了解它们各自的功能。
4. 根据用户需求,选择最符合的知识库。
5. 如果找到一个或多个合适的知识库输出它们的knowledge_id。如果没有合适的知识库输出0。
## 约束:
- 严格按照「系统提示词」和用户的需求选择知识库。「系统提示词」的优先级大于用户的需求
- 如果有多个合适的知识库将它们的knowledge_id用英文逗号连接后输出。
- 输出必须仅为knowledge_id或0不得包括任何其他内容或解释不要在id后面输出知识库名称。
## 输出示例
123,456
## 输出格式:
输出应该是一个纯数字或者由英文逗号连接的数字序列,具体取决于选择的知识库数量。不应包含任何其他文本或格式。
## 知识库列表如下
%s
## 「系统提示词」如下
%s
`
const (
knowledgeTemplateKey = "knowledge_template"
knowledgeChatModelKey = "knowledge_chat_model"
knowledgeLambdaKey = "knowledge_lambda"
knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
templateNodeKey = "template"
llmNodeKey = "llm"
outputConvertNodeKey = "output_convert"
)
type NoReCallReplyMode int64
const (
NoReCallReplyModeOfDefault NoReCallReplyMode = 0
NoReCallReplyModeOfCustomize NoReCallReplyMode = 1
)
type RetrievalStrategy struct {
RetrievalStrategy *crossknowledge.RetrievalStrategy
NoReCallReplyMode NoReCallReplyMode
NoReCallReplyCustomizePrompt string
}
type KnowledgeRecallConfig struct {
ChatModel model.BaseChatModel
Retriever crossknowledge.KnowledgeOperator
RetrievalStrategy *RetrievalStrategy
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
}
type Config struct {
ChatModel ModelWithInfo
Tools []tool.BaseTool
SystemPrompt string
UserPrompt string
OutputFormat Format
InputFields map[string]*vo.TypeInfo
OutputFields map[string]*vo.TypeInfo
ToolsReturnDirectly map[string]bool
KnowledgeRecallConfig *KnowledgeRecallConfig
FullSources map[string]*nodes.SourceInfo
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
outputFields map[string]*vo.TypeInfo
canStream bool
requireCheckpoint bool
fullSources map[string]*nodes.SourceInfo
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
return r, nil
}
func getReasoningContent(message *schema.Message) string {
c, ok := deepseek.GetReasoningContent(message)
if ok {
return c
}
c, ok = ark.GetReasoningContent(message)
if ok {
return c
}
return ""
}
type Options struct {
nested []nodes.NestedWorkflowOption
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
type Option func(o *Options)
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
return func(o *Options) {
o.nested = append(o.nested, nested...)
}
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
return func(o *Options) {
o.toolWorkflowSW = sw
}
}
type llmState = map[string]any
const agentModelName = "agent_model"
func New(ctx context.Context, cfg *Config) (*LLM, error) {
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
return llmState{}
}))
var (
hasReasoning bool
canStream = true
)
format := cfg.OutputFormat
if format == FormatJSON {
if len(cfg.OutputFields) == 1 {
for _, v := range cfg.OutputFields {
if v.Type == vo.DataTypeString {
format = FormatText
break
}
}
} else if len(cfg.OutputFields) == 2 {
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
for k, v := range cfg.OutputFields {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
format = FormatText
break
}
}
}
}
}
userPrompt := cfg.UserPrompt
switch format {
case FormatJSON:
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
if err != nil {
return nil, err
}
jsonPrompt := fmt.Sprintf(jsonPromptFormat, jsonSchema)
userPrompt = userPrompt + jsonPrompt
case FormatMarkdown:
userPrompt = userPrompt + markdownPrompt
case FormatText:
}
if cfg.KnowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
if err != nil {
return nil, err
}
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
inputs := maps.Clone(cfg.InputFields)
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString,
}
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
template := newPrompts(sp, up, cfg.ChatModel)
_ = g.AddChatTemplateNode(templateNodeKey, template,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
for k, v := range state {
in[k] = v
}
return in, nil
}))
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else {
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
template := newPrompts(sp, up, cfg.ChatModel)
_ = g.AddChatTemplateNode(templateNodeKey, template)
_ = g.AddEdge(compose.START, templateNodeKey)
}
if len(cfg.Tools) > 0 {
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
reactConfig := react.AgentConfig{
ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
ModelNodeName: agentModelName,
}
if len(cfg.ToolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
for k := range cfg.ToolsReturnDirectly {
reactConfig.ToolReturnDirectly[k] = struct{}{}
}
}
reactAgent, err := react.NewAgent(ctx, &reactConfig)
if err != nil {
return nil, err
}
agentNode, opts := reactAgent.ExportGraph()
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else {
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
}
_ = g.AddEdge(templateNodeKey, llmNodeKey)
if format == FormatJSON {
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
return jsonParse(ctx, msg.Content, cfg.OutputFields)
}
convertNode := compose.InvokableLambda(iConvert)
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
canStream = false
} else {
var outputKey string
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
panic("impossible")
}
for k, v := range cfg.OutputFields {
if v.Type != vo.DataTypeString {
panic("impossible")
}
if k == ReasoningOutputKey {
hasReasoning = true
} else {
outputKey = k
}
}
iConvert := func(_ context.Context, msg *schema.Message, _ ...struct{}) (map[string]any, error) {
out := map[string]any{outputKey: msg.Content}
if hasReasoning {
out[ReasoningOutputKey] = getReasoningContent(msg)
}
return out, nil
}
tConvert := func(_ context.Context, s *schema.StreamReader[*schema.Message], _ ...struct{}) (*schema.StreamReader[map[string]any], error) {
sr, sw := schema.Pipe[map[string]any](0)
safego.Go(ctx, func() {
reasoningDone := false
for {
msg, err := s.Recv()
if err != nil {
if err == io.EOF {
sw.Send(map[string]any{
outputKey: nodes.KeyIsFinished,
}, nil)
sw.Close()
return
}
sw.Send(nil, err)
sw.Close()
return
}
if hasReasoning {
reasoning := getReasoningContent(msg)
if len(reasoning) > 0 {
sw.Send(map[string]any{ReasoningOutputKey: reasoning}, nil)
}
}
if len(msg.Content) > 0 {
if !reasoningDone && hasReasoning {
reasoningDone = true
sw.Send(map[string]any{
ReasoningOutputKey: nodes.KeyIsFinished,
}, nil)
}
sw.Send(map[string]any{outputKey: msg.Content}, nil)
}
}
})
return sr, nil
}
convertNode, err := compose.AnyLambda(iConvert, nil, nil, tConvert)
if err != nil {
return nil, err
}
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
}
_ = g.AddEdge(llmNodeKey, outputConvertNodeKey)
_ = g.AddEdge(outputConvertNodeKey, compose.END)
requireCheckpoint := false
if len(cfg.Tools) > 0 {
requireCheckpoint = true
}
var opts []compose.GraphCompileOption
if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
}
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
r, err := g.Compile(ctx, opts...)
if err != nil {
return nil, err
}
llm := &LLM{
r: r,
outputFormat: format,
canStream: canStream,
requireCheckpoint: requireCheckpoint,
fullSources: cfg.FullSources,
}
return llm, nil
}
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx)
if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent
}
var previousToolES map[string]*entity.ToolInterruptEvent
if c != nil && c.RootCtx.ResumeEvent != nil {
// check if we are not resuming, but previously interrupted. Interrupt immediately.
if resumingEvent == nil {
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
var e error
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
if e != nil {
return e
}
return nil
})
if err != nil {
return nil, nil, err
}
if len(previousToolES) > 0 {
return nil, nil, compose.InterruptAndRerun
}
}
}
if l.requireCheckpoint && c != nil {
checkpointID := fmt.Sprintf("%d_%s", c.RootCtx.RootExecuteID, c.NodeCtx.NodeKey)
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
}
llmOpts := &Options{}
for _, opt := range opts {
opt(llmOpts)
}
nestedOpts := &nodes.NestedWorkflowOptions{}
for _, opt := range llmOpts.nested {
opt(nestedOpts)
}
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
if resumingEvent != nil {
var (
resumeData string
e error
allIEs = make(map[string]*entity.ToolInterruptEvent)
)
err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
allIEs, e = state.GetToolInterruptEvents(c.NodeKey)
if e != nil {
return e
}
allIEs = maps.Clone(allIEs)
resumeData, e = state.ResumeToolInterruptEvent(c.NodeKey, resumingEvent.ToolInterruptEvent.ToolCallID)
return e
})
if err != nil {
return nil, nil, err
}
composeOpts = append(composeOpts, compose.WithToolsNodeOption(
compose.WithToolOption(
execute.WithResume(&entity.ResumeRequest{
ExecuteID: resumingEvent.ToolInterruptEvent.ExecuteID,
EventID: resumingEvent.ToolInterruptEvent.ID,
ResumeData: resumeData,
}, allIEs))))
chatModelHandler := callbacks2.NewHandlerHelper().ChatModel(&callbacks2.ModelCallbackHandler{
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
if runInfo.Name != agentModelName {
return ctx
}
// react agent loops back to chat model after resuming,
// pop the previous interrupt event immediately
ie, deleted, e := workflow.GetRepository().PopFirstInterruptEvent(ctx, c.RootExecuteID)
if e != nil {
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: %v", err)
return ctx
}
if !deleted {
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: not deleted")
return ctx
}
if ie.ID != resumingEvent.ID {
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start, "+
"deleted ID: %d, resumingEvent ID: %d", ie.ID, resumingEvent.ID)
return ctx
}
return ctx
},
}).Handler()
composeOpts = append(composeOpts, compose.WithCallbacks(chatModelHandler))
}
if c != nil {
exeCfg := c.ExeCfg
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
}
if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt)
safego.Go(ctx, func() {
defer toolMsgSR.Close()
for {
msg, err := toolMsgSR.Recv()
if err != nil {
if err == io.EOF {
return
}
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
return
}
logs.Infof("received message from tool workflow: %+v", msg)
llmOpts.toolWorkflowSW.Send(msg, nil)
}
})
}
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)
if err != nil {
return nil, nil, err
}
var nodeKey vo.NodeKey
if c != nil && c.NodeCtx != nil {
nodeKey = c.NodeCtx.NodeKey
}
ctxcache.Store(ctx, fmt.Sprintf(sourceKey, nodeKey), resolvedSources)
return composeOpts, resumingEvent, nil
}
func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.InterruptEvent) error {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
return err
}
info = info.SubGraphs["llm"] // 'llm' is the node key of the react agent
var extra any
for i := range info.RerunNodesExtra {
extra = info.RerunNodesExtra[i]
break
}
toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
if !ok {
return fmt.Errorf("llm rerun node extra type expected to be ToolsInterruptAndRerunExtra, actual: %T", extra)
}
id, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return err
}
var (
previousInterruptedCallID string
highPriorityEvent *entity.ToolInterruptEvent
)
if resumingEvent != nil {
previousInterruptedCallID = resumingEvent.ToolInterruptEvent.ToolCallID
}
c := execute.GetExeCtx(ctx)
toolIEs := make([]*entity.ToolInterruptEvent, 0, len(toolsNodeExtra.RerunExtraMap))
for callID := range toolsNodeExtra.RerunExtraMap {
subIE, ok := toolsNodeExtra.RerunExtraMap[callID].(*entity.ToolInterruptEvent)
if !ok {
return fmt.Errorf("llm rerun node extra type expected to be ToolInterruptEvent, actual: %T", extra)
}
if subIE.ExecuteID == 0 {
subIE.ExecuteID = c.RootExecuteID
}
toolIEs = append(toolIEs, subIE)
if subIE.ToolCallID == previousInterruptedCallID {
highPriorityEvent = subIE
}
}
ie := &entity.InterruptEvent{
ID: id,
NodeKey: c.NodeKey,
NodeType: entity.NodeTypeLLM,
NodeTitle: c.NodeName,
NodeIcon: entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURL,
EventType: entity.InterruptEventLLM,
}
if highPriorityEvent != nil {
ie.ToolInterruptEvent = highPriorityEvent
} else {
ie.ToolInterruptEvent = toolIEs[0]
}
err = compose.ProcessState(ctx, func(ctx context.Context, ieStore ToolInterruptEventStore) error {
for i := range toolIEs {
e := ieStore.SetToolInterruptEvent(c.NodeKey, toolIEs[i].ToolCallID, toolIEs[i])
if e != nil {
return e
}
}
return nil
})
if err != nil {
return err
}
return compose.NewInterruptAndRerunErr(ie)
}
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
}
out, err = l.r.Invoke(ctx, in, composeOpts...)
if err != nil {
err = handleInterrupt(ctx, err, resumingEvent)
return nil, err
}
return out, nil
}
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
}
out, err = l.r.Stream(ctx, in, composeOpts...)
if err != nil {
err = handleInterrupt(ctx, err, resumingEvent)
return nil, err
}
return out, nil
}
func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map[string]any], userPrompt string, cfg *KnowledgeRecallConfig) error {
selectedKwDetails, err := sonic.MarshalString(cfg.SelectedKnowledgeDetails)
if err != nil {
return err
}
_ = g.AddChatTemplateNode(knowledgeTemplateKey,
prompt.FromMessages(schema.Jinja2,
schema.SystemMessage(fmt.Sprintf(knowledgeIntentPrompt, selectedKwDetails, userPrompt)),
), compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
for k, v := range in {
state[k] = v
}
return in, nil
}))
_ = g.AddChatModelNode(knowledgeChatModelKey, cfg.ChatModel)
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
modelPredictionIDs := strings.Split(input.Content, ",")
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
return strconv.Itoa(int(e.ID)), e.ID
})
recallKnowledgeIDs := make([]int64, 0)
for _, id := range modelPredictionIDs {
if kid, ok := selectKwIDs[id]; ok {
recallKnowledgeIDs = append(recallKnowledgeIDs, kid)
}
}
if len(recallKnowledgeIDs) == 0 {
return make(map[string]any), nil
}
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
Query: userPrompt,
KnowledgeIDs: recallKnowledgeIDs,
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
})
if err != nil {
return nil, err
}
if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
return make(map[string]any), nil
}
sb := strings.Builder{}
if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
sb.WriteString("recall slice 1: \n")
sb.WriteString(cfg.RetrievalStrategy.NoReCallReplyCustomizePrompt + "\n")
}
for idx, msg := range docs.Slices {
sb.WriteString(fmt.Sprintf("recall slice %d:\n", idx+1))
sb.WriteString(fmt.Sprintf("%s\n", msg.Output))
}
output = map[string]any{
knowledgeUserPromptTemplateKey: fmt.Sprintf(knowledgeUserPromptTemplate, sb.String()),
}
return output, nil
}))
_ = g.AddEdge(compose.START, knowledgeTemplateKey)
_ = g.AddEdge(knowledgeTemplateKey, knowledgeChatModelKey)
_ = g.AddEdge(knowledgeChatModelKey, knowledgeLambdaKey)
return nil
}
type ToolInterruptEventStore interface {
SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error
GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error)
ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error)
}
func (l *LLM) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
c := execute.GetExeCtx(ctx)
if c == nil {
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: output,
}, nil
}
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeKey)
rawOutput, found := ctxcache.Get[string](ctx, rawOutputK)
if !found {
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: output,
}, nil
}
warning, found := ctxcache.Get[vo.WorkflowError](ctx, warningK)
if !found {
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: map[string]any{"output": rawOutput},
}, nil
}
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: map[string]any{"output": rawOutput},
Error: warning,
}, nil
}

View File

@@ -0,0 +1,184 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package llm
import (
"context"
"errors"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
)
type ModelWithInfo interface {
model.BaseChatModel
Info(ctx context.Context) *crossmodelmgr.Model
}
type ModelForLLM struct {
Model model.BaseChatModel
MInfo *crossmodelmgr.Model
FallbackModel model.BaseChatModel
FallbackInfo *crossmodelmgr.Model
UseFallback func(ctx context.Context) bool
modelEnableCallback bool
fallbackEnableCallback bool
}
func NewModel(m model.BaseChatModel, info *crossmodelmgr.Model) *ModelForLLM {
return &ModelForLLM{
Model: m,
MInfo: info,
UseFallback: func(ctx context.Context) bool {
return false
},
modelEnableCallback: components.IsCallbacksEnabled(m),
}
}
func NewModelWithFallback(m, f model.BaseChatModel, info, fInfo *crossmodelmgr.Model) *ModelForLLM {
return &ModelForLLM{
Model: m,
MInfo: info,
FallbackModel: f,
FallbackInfo: fInfo,
UseFallback: func(ctx context.Context) bool {
exeCtx := execute.GetExeCtx(ctx)
if exeCtx == nil || exeCtx.NodeCtx == nil {
return false
}
return exeCtx.CurrentRetryCount > 0
},
modelEnableCallback: components.IsCallbacksEnabled(m),
fallbackEnableCallback: components.IsCallbacksEnabled(f),
}
}
func (m *ModelForLLM) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (
output *schema.Message, err error) {
if m.UseFallback(ctx) {
if !m.fallbackEnableCallback {
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
} else {
_ = callbacks.OnEnd(ctx, output)
}
}()
ctx = callbacks.OnStart(ctx, input)
}
return m.FallbackModel.Generate(ctx, input, opts...)
}
if !m.modelEnableCallback {
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
} else {
_ = callbacks.OnEnd(ctx, output)
}
}()
ctx = callbacks.OnStart(ctx, input)
}
return m.Model.Generate(ctx, input, opts...)
}
func (m *ModelForLLM) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (
output *schema.StreamReader[*schema.Message], err error) {
if m.UseFallback(ctx) {
if !m.fallbackEnableCallback {
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
} else {
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
}
}()
ctx = callbacks.OnStart(ctx, input)
}
return m.FallbackModel.Stream(ctx, input, opts...)
}
if !m.modelEnableCallback {
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
} else {
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
}
}()
ctx = callbacks.OnStart(ctx, input)
}
return m.Model.Stream(ctx, input, opts...)
}
func (m *ModelForLLM) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
toolModel, ok := m.Model.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
var err error
toolModel, err = toolModel.WithTools(tools)
if err != nil {
return nil, err
}
var fallbackToolModel model.ToolCallingChatModel
if m.FallbackModel != nil {
fallbackToolModel, ok = m.FallbackModel.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
fallbackToolModel, err = fallbackToolModel.WithTools(tools)
if err != nil {
return nil, err
}
}
return &ModelForLLM{
Model: toolModel,
MInfo: m.MInfo,
FallbackModel: fallbackToolModel,
FallbackInfo: m.FallbackInfo,
UseFallback: m.UseFallback,
modelEnableCallback: m.modelEnableCallback,
fallbackEnableCallback: m.fallbackEnableCallback,
}, nil
}
func (m *ModelForLLM) IsCallbacksEnabled() bool {
return true
}
func (m *ModelForLLM) Info(ctx context.Context) *crossmodelmgr.Model {
if m.UseFallback(ctx) {
return m.FallbackInfo
}
return m.MInfo
}

View File

@@ -0,0 +1,287 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package llm
import (
"context"
"fmt"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type prompts struct {
sp *promptTpl
up *promptTpl
mwi ModelWithInfo
}
type promptTpl struct {
role schema.RoleType
tpl string
parts []promptPart
hasMultiModal bool
reservedKeys []string
}
type promptPart struct {
part nodes.TemplatePart
fileType *vo.FileSubType
}
func newPromptTpl(role schema.RoleType,
tpl string,
inputTypes map[string]*vo.TypeInfo,
reservedKeys []string,
) *promptTpl {
if len(tpl) == 0 {
return nil
}
parts := nodes.ParseTemplate(tpl)
promptParts := make([]promptPart, 0, len(parts))
hasMultiModal := false
for _, part := range parts {
if !part.IsVariable {
promptParts = append(promptParts, promptPart{
part: part,
})
continue
}
tInfo := part.TypeInfo(inputTypes)
if tInfo == nil || tInfo.Type != vo.DataTypeFile {
promptParts = append(promptParts, promptPart{
part: part,
})
continue
}
promptParts = append(promptParts, promptPart{
part: part,
fileType: tInfo.FileType,
})
hasMultiModal = true
}
return &promptTpl{
role: role,
tpl: tpl,
parts: promptParts,
hasMultiModal: hasMultiModal,
reservedKeys: reservedKeys,
}
}
const sourceKey = "sources_%s"
func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
return &prompts{
sp: sp,
up: up,
mwi: model,
}
}
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*nodes.SourceInfo,
supportedModals map[modelmgr.Modal]bool) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 {
var opts []nodes.RenderOption
if len(pl.reservedKeys) > 0 {
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
}
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
if err != nil {
return nil, err
}
return &schema.Message{
Role: pl.role,
Content: r,
}, nil
}
multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
m, err := sonic.Marshal(vs)
if err != nil {
return nil, err
}
for _, part := range pl.parts {
if !part.part.IsVariable {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: part.part.Value,
})
continue
}
skipped, invalid := part.part.Skipped(sources)
if invalid {
var reserved bool
for _, k := range pl.reservedKeys {
if k == part.part.Root {
reserved = true
break
}
}
if !reserved {
continue
}
}
if skipped {
continue
}
r, err := part.part.Render(m)
if err != nil {
return nil, err
}
if part.fileType == nil {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
continue
}
switch *part.fileType {
case vo.FileTypeImage, vo.FileTypeSVG:
if _, ok := supportedModals[modelmgr.ModalImage]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: r,
},
})
}
case vo.FileTypeAudio, vo.FileTypeVoice:
if _, ok := supportedModals[modelmgr.ModalAudio]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeAudioURL,
AudioURL: &schema.ChatMessageAudioURL{
URL: r,
},
})
}
case vo.FileTypeVideo:
if _, ok := supportedModals[modelmgr.ModalVideo]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeVideoURL,
VideoURL: &schema.ChatMessageVideoURL{
URL: r,
},
})
}
default:
if _, ok := supportedModals[modelmgr.ModalFile]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeFileURL,
FileURL: &schema.ChatMessageFileURL{
URL: r,
},
})
}
}
}
return &schema.Message{
Role: pl.role,
MultiContent: multiParts,
}, nil
}
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
_ []*schema.Message, err error) {
exeCtx := execute.GetExeCtx(ctx)
var nodeKey vo.NodeKey
if exeCtx != nil && exeCtx.NodeCtx != nil {
nodeKey = exeCtx.NodeCtx.NodeKey
}
sk := fmt.Sprintf(sourceKey, nodeKey)
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
if !ok {
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
}
supportedModal := map[modelmgr.Modal]bool{}
mInfo := p.mwi.Info(ctx)
if mInfo != nil {
for i := range mInfo.Meta.Capability.InputModal {
supportedModal[mInfo.Meta.Capability.InputModal[i]] = true
}
}
var systemMsg, userMsg *schema.Message
if p.sp != nil {
systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal)
if err != nil {
return nil, err
}
}
if p.up != nil {
userMsg, err = p.up.render(ctx, vs, sources, supportedModal)
if err != nil {
return nil, err
}
}
if userMsg == nil {
// give it a default empty message.
// Some model may fail on empty message such as this one.
userMsg = schema.UserMessage("")
}
if systemMsg == nil {
return []*schema.Message{userMsg}, nil
}
return []*schema.Message{systemMsg, userMsg}, nil
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package loop
import (
"context"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
)
type Break struct {
parentIntermediateStore variable.Store
}
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
return &Break{
parentIntermediateStore: store,
}, nil
}
const BreakKey = "$break"
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
if err != nil {
return nil, err
}
return map[string]any{}, nil
}

View File

@@ -0,0 +1,424 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package loop
import (
"context"
"errors"
"fmt"
"math"
"reflect"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Loop struct {
config *Config
outputs map[string]*vo.FieldSource
outputVars map[string]string
}
type Config struct {
LoopNodeKey vo.NodeKey
LoopType Type
InputArrays []string
IntermediateVars map[string]*vo.TypeInfo
Outputs []*vo.FieldInfo
Inner compose.Runnable[map[string]any, map[string]any]
}
type Type string
const (
ByArray Type = "by_array"
ByIteration Type = "by_iteration"
Infinite Type = "infinite"
)
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
if conf == nil {
return nil, errors.New("config is nil")
}
if conf.LoopType == ByArray {
if len(conf.InputArrays) == 0 {
return nil, errors.New("input arrays is empty when loop type is ByArray")
}
}
loop := &Loop{
config: conf,
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
}
for _, info := range conf.Outputs {
if len(info.Path) != 1 {
return nil, fmt.Errorf("invalid output path: %s", info.Path)
}
k := info.Path[0]
fromPath := info.Source.Ref.FromPath
if info.Source.Ref != nil && info.Source.Ref.VariableType != nil &&
*info.Source.Ref.VariableType == vo.ParentIntermediate {
if len(fromPath) > 1 {
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
}
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
}
loop.outputVars[k] = fromPath[0]
continue
}
loop.outputs[k] = &info.Source
}
return loop, nil
}
const (
Count = "loopCount"
)
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
maxIter, err := l.getMaxIter(in)
if err != nil {
return nil, err
}
arrays := make(map[string][]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
}
arrays[arrayKey] = a.([]any)
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
var (
existingCState *nodes.NestedWorkflowState
intermediateVars map[string]*any
output map[string]any
hasBreak = any(false)
)
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
if e != nil {
return e
}
return nil
})
if err != nil {
return nil, err
}
if existingCState != nil {
output = existingCState.FullOutput
intermediateVars = make(map[string]*any, len(existingCState.IntermediateVars))
for k := range existingCState.IntermediateVars {
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
}
intermediateVars[BreakKey] = &hasBreak
} else {
output = make(map[string]any, len(l.outputs))
for k := range l.outputs {
output[k] = make([]any, 0)
}
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
for varKey := range l.config.IntermediateVars {
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
if !ok {
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
}
intermediateVars[varKey] = &v
}
intermediateVars[BreakKey] = &hasBreak
}
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)
getIthInput := func(i int) (map[string]any, map[string]any, error) {
input := make(map[string]any)
for k, v := range in { // carry over other values
if k == Count {
continue
}
if _, ok := arrays[k]; ok {
continue
}
if _, ok := intermediateVars[k]; ok {
continue
}
input[k] = v
}
input[string(l.config.LoopNodeKey)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey := range arrays {
ele := arrays[arrayKey][i]
items[arrayKey] = ele
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey
// Recursively expand map[string]any elements
if m, ok := ele.(map[string]any); ok {
var expand func(prefix string, val interface{})
expand = func(prefix string, val interface{}) {
if nestedMap, ok := val.(map[string]any); ok {
for k, v := range nestedMap {
expand(prefix+"#"+k, v)
}
} else {
input[prefix] = val
}
}
expand(currentKey, m)
} else {
input[currentKey] = ele
}
}
return input, items, nil
}
setIthOutput := func(i int, taskOutput map[string]any) {
for arrayKey := range l.outputs {
source := l.outputs[arrayKey]
fromValue, ok := nodes.TakeMapValue(taskOutput, append(compose.FieldPath{string(source.Ref.FromNodeKey)}, source.Ref.FromPath...))
if ok {
output[arrayKey] = append(output[arrayKey].([]any), fromValue)
}
}
}
var (
index2Done = map[int]bool{}
index2InterruptInfo = map[int]*compose.InterruptInfo{}
resumed = map[int]bool{}
)
for i := 0; i < maxIter; i++ {
select {
case <-ctx.Done():
return nil, ctx.Err() // canceled by Eino workflow engine
default:
}
if existingCState != nil {
if existingCState.Index2Done[i] == true {
continue
}
if existingCState.Index2InterruptInfo[i] != nil {
if len(options.GetResumeIndexes()) > 0 {
if _, ok := options.GetResumeIndexes()[i]; !ok {
// previously interrupted, but not resumed this time, should not happen
panic("impossible")
}
}
}
resumed[i] = true
}
input, items, err := getIthInput(i)
if err != nil {
return nil, err
}
subCtx, checkpointID := execute.InheritExeCtxWithBatchInfo(ctx, i, items)
ithOpts := options.GetOptsForNested()
ithOpts = append(ithOpts, options.GetOptsForIndexed(i)...)
if checkpointID != "" {
ithOpts = append(ithOpts, compose.WithCheckPointID(checkpointID))
}
if len(options.GetResumeIndexes()) > 0 {
stateModifier, ok := options.GetResumeIndexes()[i]
if ok {
fmt.Println("has state modifier for ith run: ", i, ", checkpointID: ", checkpointID)
ithOpts = append(ithOpts, compose.WithStateModifier(stateModifier))
}
}
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
return nil, err
}
index2InterruptInfo[i] = info
break
}
setIthOutput(i, taskOutput)
index2Done[i] = true
if hasBreak.(bool) {
break
}
}
// delete the interruptions that have been resumed
for index := range resumed {
delete(existingCState.Index2InterruptInfo, index)
}
compState := existingCState
if compState == nil {
compState = &nodes.NestedWorkflowState{
Index2Done: index2Done,
Index2InterruptInfo: index2InterruptInfo,
FullOutput: output,
IntermediateVars: convertIntermediateVars(intermediateVars),
}
} else {
for i := range index2Done {
compState.Index2Done[i] = index2Done[i]
}
for i := range index2InterruptInfo {
compState.Index2InterruptInfo[i] = index2InterruptInfo[i]
}
compState.FullOutput = output
compState.IntermediateVars = convertIntermediateVars(intermediateVars)
}
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: l.config.LoopNodeKey,
NodeType: entity.NodeTypeLoop,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
return e
}
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
})
if err != nil {
return nil, err
}
fmt.Println("save interruptEvent in state within loop: ", iEvent)
fmt.Println("save composite info in state within loop: ", compState)
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
})
if err != nil {
return nil, err
}
fmt.Println("save composite info in state within loop: ", compState)
}
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
panic("impossible")
}
for outputVarKey, intermediateVarKey := range l.outputVars {
output[outputVarKey] = *(intermediateVars[intermediateVarKey])
}
return output, nil
}
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter := math.MaxInt
switch l.config.LoopType {
case ByArray:
for _, arrayKey := range l.config.InputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
}
if reflect.TypeOf(a).Kind() != reflect.Slice {
return 0, fmt.Errorf("incoming array not a slice: %s. Actual type: %v", arrayKey, reflect.TypeOf(a))
}
oneLen := reflect.ValueOf(a).Len()
if oneLen < maxIter {
maxIter = oneLen
}
}
case ByIteration:
iter, ok := nodes.TakeMapValue(in, compose.FieldPath{Count})
if !ok {
return 0, errors.New("incoming LoopCount not present in input when loop type is ByIteration")
}
maxIter = int(iter.(int64))
case Infinite:
default:
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
}
return maxIter, nil
}
func convertIntermediateVars(vars map[string]*any) map[string]any {
ret := make(map[string]any, len(vars))
for k, v := range vars {
ret[k] = *v
}
return ret
}
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}
}
return trimmed, nil
}

View File

@@ -0,0 +1,90 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type NestedWorkflowOptions struct {
optsForNested []compose.Option
toResumeIndexes map[int]compose.StateModifier
optsForIndexed map[int][]compose.Option
}
type NestedWorkflowOption func(*NestedWorkflowOptions)
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
o.optsForNested = append(o.optsForNested, opts...)
}
}
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
return c.optsForNested
}
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.toResumeIndexes == nil {
o.toResumeIndexes = map[int]compose.StateModifier{}
}
o.toResumeIndexes[i] = m
}
}
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
return c.toResumeIndexes
}
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.optsForIndexed == nil {
o.optsForIndexed = map[int][]compose.Option{}
}
o.optsForIndexed[index] = opts
}
}
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
if c.optsForIndexed == nil {
return nil
}
return c.optsForIndexed[index]
}
type NestedWorkflowState struct {
Index2Done map[int]bool `json:"index_2_done,omitempty"`
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
FullOutput map[string]any `json:"full_output,omitempty"`
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
func (c *NestedWorkflowState) String() string {
s, _ := sonic.MarshalIndent(c, "", " ")
return string(s)
}
type NestedWorkflowAware interface {
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
InterruptEventStore
}

View File

@@ -0,0 +1,99 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"fmt"
"sync"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type ParentIntermediateStore struct {
mu sync.RWMutex
}
type IntermediateVarKey struct{}
func (p *ParentIntermediateStore) Init(_ context.Context) {
return
}
func (p *ParentIntermediateStore) Get(ctx context.Context, path compose.FieldPath, opts ...variable.OptionFn) (any, error) {
defer p.mu.RUnlock()
p.mu.RLock()
if len(path) != 1 {
return nil, fmt.Errorf("invalid path: %v", path)
}
ivs := getIntermediateVars(ctx)
v, ok := ivs.vars[path[0]]
if !ok {
return nil, fmt.Errorf("variable not found: %s", path[0])
}
if *v == nil {
return ivs.types[path[0]].Zero(), nil
}
return *v, nil
}
func (p *ParentIntermediateStore) Set(ctx context.Context, path compose.FieldPath, value any, opts ...variable.OptionFn) error {
defer p.mu.Unlock()
p.mu.Lock()
if len(path) != 1 {
return fmt.Errorf("invalid path: %v", path)
}
ivs := getIntermediateVars(ctx)
v, ok := ivs.vars[path[0]]
if !ok {
return fmt.Errorf("variable not found: %s", path[0])
}
if value == nil {
*v = ivs.types[path[0]].Zero()
} else {
*v = value
}
return nil
}
type intermediateVar struct {
vars map[string]*any
types map[string]*vo.TypeInfo
}
func InitIntermediateVars(ctx context.Context, vars map[string]*any, typeInfos map[string]*vo.TypeInfo) context.Context {
return context.WithValue(ctx, IntermediateVarKey{}, &intermediateVar{
vars: vars,
types: typeInfos,
})
}
func getIntermediateVars(ctx context.Context) *intermediateVar {
return ctx.Value(IntermediateVarKey{}).(*intermediateVar)
}

View File

@@ -0,0 +1,83 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package plugin
import (
"context"
"errors"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Config struct {
PluginID int64
ToolID int64
PluginVersion string
PluginService plugin.Service
}
type Plugin struct {
config *Config
}
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.PluginID == 0 {
return nil, errors.New("plugin id is required")
}
if cfg.ToolID == 0 {
return nil, errors.New("tool id is required")
}
if cfg.PluginService == nil {
return nil, errors.New("tool service is required")
}
return &Plugin{config: cfg}, nil
}
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
var exeCfg vo.ExecuteConfig
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
exeCfg = ctxExeCfg.ExeCfg
}
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.config.PluginID,
PluginVersion: ptr.Of(p.config.PluginVersion),
}, p.config.ToolID, exeCfg)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now
interruptData := extra.(*entity.InterruptEvent).InterruptData
return nil, vo.NewError(errno.ErrAuthorizationRequired, errorx.KV("extra", interruptData))
}
return nil, err
}
return result, nil
}

View File

@@ -0,0 +1,633 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package qa
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type QuestionAnswer struct {
config *Config
nodeMeta entity.NodeTypeMeta
}
type Config struct {
QuestionTpl string
AnswerType AnswerType
ChoiceType ChoiceType
FixedChoices []string
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
Model model.BaseChatModel
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
ExtractFromAnswer bool
AdditionalSystemPromptTpl string
MaxAnswerCount int
OutputFields map[string]*vo.TypeInfo
NodeKey vo.NodeKey
}
type AnswerType string
const (
AnswerDirectly AnswerType = "directly"
AnswerByChoices AnswerType = "by_choices"
)
type ChoiceType string
const (
FixedChoices ChoiceType = "fixed"
DynamicChoices ChoiceType = "dynamic"
)
const (
DynamicChoicesKey = "dynamic_option"
QuestionsKey = "$questions"
AnswersKey = "$answers"
UserResponseKey = "USER_RESPONSE"
OptionIDKey = "optionId"
OptionContentKey = "optionContent"
)
const (
extractSystemPrompt = `# 角色
你是一个参数提取 agent你的工作是从用户的回答中提取出多个字段的值每个字段遵循以下规则
# 字段说明
%s
## 输出要求
- 严格以 json 格式返回答案。
- 严格确保答案采用有效的 JSON 格式。
- 按照字段说明提取出字段的值,将已经提取到的字段放在 fields 字段
- 对于未提取到的<必填字段>生成一个新的追问问题question
- 确保在追问问题中只包含所有未提取的<必填字段>
- 不要重复问之前问过的问题
- 问题的语种请和用户的输入保持一致,如英文、中文等
- 输出按照下面结构体格式返回,包含提取到的字段或者追问的问题
- 不要回复和提取无关的问题
type Output struct {
fields FieldInfo // 根据字段说明已经提取到的字段
question string // 新一轮追问的问题
}`
extractUserPromptSuffix = `
- 严格以 json 格式返回答案。
- 严格确保答案采用有效的 JSON 格式。
- - 必填字段没有获取全则继续追问
- 必填字段: %s
%s
`
additionalPersona = `
追问人设设定: %s
`
choiceIntentDetectPrompt = `# Role
You are a semantic matching expert, good at analyzing the option that the user wants to choose based on the current context.
##Skill
Skill 1: Clearly identify which of the following options the user's reply is semantically closest to:
%s
##Restrictions
Strictly identify the intention and select the most suitable option. You can only reply with the option_id and no other content. If you think there is no suitable option, output -1
##Output format
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
)
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
if conf == nil {
return nil, errors.New("config is nil")
}
if conf.AnswerType == AnswerDirectly {
if conf.ExtractFromAnswer {
if conf.Model == nil {
return nil, errors.New("model is required when extract from answer")
}
if len(conf.OutputFields) == 0 {
return nil, errors.New("output fields is required when extract from answer")
}
}
} else if conf.AnswerType == AnswerByChoices {
if conf.ChoiceType == FixedChoices {
if len(conf.FixedChoices) == 0 {
return nil, errors.New("fixed choices is required when extract from answer")
}
}
} else {
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
}
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
if nodeMeta == nil {
return nil, errors.New("node meta not found for question answer")
}
return &QuestionAnswer{
config: conf,
nodeMeta: *nodeMeta,
}, nil
}
type Question struct {
Question string
Choices []string
}
type namedOpt struct {
Name string `json:"name"`
}
type optionContent struct {
Options []namedOpt `json:"options"`
Question string `json:"question"`
}
type message struct {
Type string `json:"type"`
ContentType string `json:"content_type"`
Content any `json:"content"` // either optionContent or string
ID string `json:"id,omitempty"`
}
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
var (
questions []*Question
answers []string
isFirst bool
notResumed bool
)
questions, answers, isFirst, notResumed, err = q.extractCurrentState(in)
if err != nil {
return nil, err
}
if notResumed { // previously interrupted but not resumed this time, interrupt immediately
return nil, compose.InterruptAndRerun
}
out = make(map[string]any)
out[QuestionsKey] = questions
out[AnswersKey] = answers
switch q.config.AnswerType {
case AnswerDirectly:
if isFirst { // first execution, ask the question
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
if err != nil {
return nil, err
}
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
}
if q.config.ExtractFromAnswer {
return q.extractFromAnswer(ctx, in, questions, answers)
}
out[UserResponseKey] = answers[0]
return out, nil
case AnswerByChoices:
if !isFirst {
lastAnswer := answers[len(answers)-1]
lastQuestion := questions[len(questions)-1]
for i, choice := range lastQuestion.Choices {
if lastAnswer == choice {
out[OptionIDKey] = intToAlphabet(i)
out[OptionContentKey] = choice
return out, nil
}
}
index, err := q.intentDetect(ctx, lastAnswer, lastQuestion.Choices)
if err != nil {
return nil, err
}
if index >= 0 {
out[OptionIDKey] = intToAlphabet(index)
out[OptionContentKey] = lastQuestion.Choices[index]
return out, nil
}
out[OptionIDKey] = "other"
out[OptionContentKey] = lastAnswer
return out, nil
}
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
if err != nil {
return nil, err
}
var formattedChoices []string
switch q.config.ChoiceType {
case FixedChoices:
for _, choice := range q.config.FixedChoices {
formattedChoice, err := nodes.TemplateRender(choice, in)
if err != nil {
return nil, err
}
formattedChoices = append(formattedChoices, formattedChoice)
}
case DynamicChoices:
dynamicChoices, ok := nodes.TakeMapValue(in, compose.FieldPath{DynamicChoicesKey})
if !ok || len(dynamicChoices.([]any)) == 0 {
return nil, vo.NewError(errno.ErrQuestionOptionsEmpty)
}
const maxDynamicChoices = 26
for i, choice := range dynamicChoices.([]any) {
if i >= maxDynamicChoices {
break // take first 26 choices, discard the others
}
c := choice.(string)
formattedChoices = append(formattedChoices, c)
}
default:
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
}
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
default:
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
}
}
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
fieldInfo := "FieldInfo"
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
if err != nil {
return nil, err
}
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
var requiredFields []string
for fName, tInfo := range q.config.OutputFields {
if tInfo.Required {
requiredFields = append(requiredFields, fName)
}
}
var formattedAdditionalPrompt string
if len(q.config.AdditionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
if err != nil {
return nil, err
}
formattedAdditionalPrompt = fmt.Sprintf(additionalPersona, additionalPrompt)
}
userPromptSuffix := fmt.Sprintf(extractUserPromptSuffix, requiredFields, formattedAdditionalPrompt)
var (
messages = make([]*schema.Message, 0, len(questions)*2+1)
userResponse string
)
messages = append(messages, schema.SystemMessage(sysPrompt))
for i := range questions {
messages = append(messages, schema.AssistantMessage(questions[i].Question, nil))
answer := answers[i]
if i == len(questions)-1 {
userResponse = answer
answer = answer + userPromptSuffix
}
messages = append(messages, schema.UserMessage(answer))
}
out, err := q.config.Model.Generate(ctx, messages)
if err != nil {
return nil, err
}
content := nodes.ExtractJSONString(out.Content)
var outMap = make(map[string]any)
err = sonic.UnmarshalString(content, &outMap)
if err != nil {
return nil, err
}
nextQuestion, ok := outMap["question"]
if ok {
nextQuestionStr, ok := nextQuestion.(string)
if ok && len(nextQuestionStr) > 0 {
if len(answers) >= q.config.MaxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
}
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
}
}
fields, ok := outMap["fields"]
if !ok {
return nil, fmt.Errorf("field %s not found", fieldInfo)
}
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
realOutput[UserResponseKey] = userResponse
realOutput[QuestionsKey] = questions
realOutput[AnswersKey] = answers
return realOutput, nil
}
func (q *QuestionAnswer) extractCurrentState(in map[string]any) (
qResult []*Question,
aResult []string,
isFirst bool, // whether this execution if the first ever execution for this node
notResumed bool, // whether this node is previously interrupted, but not resumed this time, because another node is resumed
err error) {
questions, ok := in[QuestionsKey]
if ok {
qResult = questions.([]*Question)
}
answers, ok := in[AnswersKey]
if ok {
aResult = answers.([]string)
}
if len(qResult) == 0 && len(aResult) == 0 {
return nil, nil, true, false, nil
}
if len(qResult) != len(aResult) && len(qResult) != len(aResult)+1 {
return nil, nil, false, false,
fmt.Errorf("invalid state, question count is expected to be equal to answer count or 1 more than answer count: %v", in)
}
return qResult, aResult, false, len(qResult) == len(aResult)+1, nil
}
func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choices []string) (int, error) {
type option struct {
Option string `json:"option"`
OptionID int `json:"option_id"`
}
options := make([]option, 0, len(choices))
for i := range choices {
options = append(options, option{Option: choices[i], OptionID: i})
}
optionsStr, err := sonic.MarshalString(options)
if err != nil {
return -1, err
}
sysPrompt := fmt.Sprintf(choiceIntentDetectPrompt, optionsStr)
messages := []*schema.Message{
schema.SystemMessage(sysPrompt),
schema.UserMessage(answer),
}
out, err := q.config.Model.Generate(ctx, messages)
if err != nil {
return -1, err
}
index, err := strconv.Atoi(out.Content)
if err != nil {
return -1, err
}
return index, nil
}
type QuestionAnswerAware interface {
AddQuestion(nodeKey vo.NodeKey, question *Question)
AddAnswer(nodeKey vo.NodeKey, answer string)
GetQuestionsAndAnswers(nodeKey vo.NodeKey) ([]*Question, []string)
}
func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choices []string, oldQuestions []*Question, oldAnswers []string) error {
history := q.generateHistory(oldQuestions, oldAnswers, &newQuestion, choices)
historyList := map[string][]*message{
"messages": history,
}
interruptData, err := sonic.MarshalString(historyList)
if err != nil {
return err
}
eventID, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return err
}
event := &entity.InterruptEvent{
ID: eventID,
NodeKey: q.config.NodeKey,
NodeType: entity.NodeTypeQuestionAnswer,
NodeTitle: q.nodeMeta.Name,
NodeIcon: q.nodeMeta.IconURL,
InterruptData: interruptData,
EventType: entity.InterruptEventQuestion,
}
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
setter.AddQuestion(q.config.NodeKey, &Question{
Question: newQuestion,
Choices: choices,
})
return nil
})
return compose.NewInterruptAndRerunErr(event)
}
func intToAlphabet(num int) string {
if num >= 0 && num <= 25 {
char := rune('A' + num)
return string(char)
}
return ""
}
func AlphabetToInt(str string) (int, bool) {
if len(str) != 1 {
return 0, false
}
char := rune(str[0])
char = unicode.ToUpper(char)
if char >= 'A' && char <= 'Z' {
return int(char - 'A'), true
}
return 0, false
}
func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []string, newQuestion *string, choices []string) []*message {
conv := func(opts []string) (namedOpts []namedOpt) {
for _, opt := range opts {
namedOpts = append(namedOpts, namedOpt{
Name: opt,
})
}
return namedOpts
}
history := make([]*message, 0, len(oldQuestions)+len(oldAnswers)+1)
for i := 0; i < len(oldQuestions); i++ {
oldQuestion := oldQuestions[i]
oldAnswer := oldAnswers[i]
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
questionMsg := &message{
Type: "question",
ContentType: contentType,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
}
if q.config.AnswerType == AnswerByChoices {
questionMsg.Content = optionContent{
Options: conv(oldQuestion.Choices),
Question: oldQuestion.Question,
}
} else {
questionMsg.Content = oldQuestion.Question
}
answerMsg := &message{
Type: "answer",
ContentType: contentType,
Content: oldAnswer,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
}
history = append(history, questionMsg, answerMsg)
}
if newQuestion != nil {
if q.config.AnswerType == AnswerByChoices {
history = append(history, &message{
Type: "question",
ContentType: "option",
Content: optionContent{
Options: conv(choices),
Question: *newQuestion,
},
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
})
} else {
history = append(history, &message{
Type: "question",
ContentType: "text",
Content: *newQuestion,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
})
}
}
return history
}
func (q *QuestionAnswer) ToCallbackOutput(_ context.Context, out map[string]any) (*nodes.StructuredCallbackOutput, error) {
questions := out[QuestionsKey].([]*Question)
answers := out[AnswersKey].([]string)
selected, hasSelected := out[OptionContentKey]
history := q.generateHistory(questions, answers, nil, nil)
for _, msg := range history {
optionC, ok := msg.Content.(optionContent)
if ok {
msg.Content = optionC.Question
}
msg.ID = ""
}
delete(out, QuestionsKey)
delete(out, AnswersKey)
sOut := &nodes.StructuredCallbackOutput{
Output: out,
RawOutput: map[string]any{
"messages": history,
},
}
if hasSelected {
sOut.RawOutput["selected"] = selected
}
return sOut, nil
}
func AppendInterruptData(interruptData string, resumeData string) (string, error) {
var historyList = make(map[string][]*message)
err := sonic.UnmarshalString(interruptData, &historyList)
if err != nil {
return "", err
}
lastQuestion := historyList["messages"][len(historyList["messages"])-1]
segments := strings.Split(lastQuestion.ID, "_")
nodeKey := segments[0]
i, err := strconv.Atoi(segments[1])
if err != nil {
return "", err
}
answerMsg := &message{
Type: "answer",
ContentType: lastQuestion.ContentType,
Content: resumeData,
ID: fmt.Sprintf("%s_%d", nodeKey, i+1),
}
historyList["messages"] = append(historyList["messages"], answerMsg)
m, err := sonic.MarshalString(historyList)
if err != nil {
return "", err
}
return m, nil
}

View File

@@ -0,0 +1,181 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package receiver
import (
"context"
"errors"
"fmt"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
sonic2 "github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Config struct {
OutputTypes map[string]*vo.TypeInfo
NodeKey vo.NodeKey
OutputSchema string
}
type InputReceiver struct {
outputTypes map[string]*vo.TypeInfo
interruptData string
nodeKey vo.NodeKey
nodeMeta entity.NodeTypeMeta
}
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
if nodeMeta == nil {
return nil, errors.New("node meta not found for input receiver")
}
interruptData := map[string]string{
"content_type": "form_schema",
"content": cfg.OutputSchema,
}
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
if err != nil {
return nil, err
}
return &InputReceiver{
outputTypes: cfg.OutputTypes,
nodeMeta: *nodeMeta,
nodeKey: cfg.NodeKey,
interruptData: interruptDataStr,
}, nil
}
const (
ReceivedDataKey = "$received_data"
receiverWarningKey = "receiver_warning_%d_%s"
)
func (i *InputReceiver) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
var input string
if in != nil {
receivedData, ok := in[ReceivedDataKey]
if ok {
input = receivedData.(string)
}
}
if len(input) == 0 {
err := compose.ProcessState(ctx, func(ctx context.Context, ieStore nodes.InterruptEventStore) error {
_, found, e := ieStore.GetInterruptEvent(i.nodeKey) // TODO: try not use InterruptEventStore or state in general
if e != nil {
return e
}
if !found { // only generate a new event if it doesn't exist
eventID, err := workflow.GetRepository().GenID(ctx)
if err != nil {
return err
}
return ieStore.SetInterruptEvent(i.nodeKey, &entity.InterruptEvent{
ID: eventID,
NodeKey: i.nodeKey,
NodeType: entity.NodeTypeInputReceiver,
NodeTitle: i.nodeMeta.Name,
NodeIcon: i.nodeMeta.IconURL,
InterruptData: i.interruptData,
EventType: entity.InterruptEventInput,
})
}
return nil
})
if err != nil {
return nil, err
}
return nil, compose.InterruptAndRerun
}
out, err := jsonParseRelaxed(ctx, input, i.outputTypes)
if err != nil {
return nil, err
}
return out, nil
}
func jsonParseRelaxed(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
var result map[string]any
err := sonic2.UnmarshalString(data, &result)
if err != nil {
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_, nodes.SkipUnknownFields())
if err != nil {
return nil, err
}
if ws != nil && len(*ws) > 0 {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
var (
executeID int64
nodeKey vo.NodeKey
)
if c := execute.GetExeCtx(ctx); c != nil {
executeID = c.RootExecuteID
nodeKey = c.NodeKey
}
warningKey := fmt.Sprintf(receiverWarningKey, executeID, nodeKey)
ctxcache.Store(ctx, warningKey, *ws)
}
return r, nil
}
func (i *InputReceiver) ToCallbackOutput(ctx context.Context, output map[string]any) (
*nodes.StructuredCallbackOutput, error) {
var (
executeID int64
nodeKey vo.NodeKey
)
if c := execute.GetExeCtx(ctx); c != nil {
executeID = c.RootExecuteID
nodeKey = c.NodeKey
}
warningKey := fmt.Sprintf(receiverWarningKey, executeID, nodeKey)
var wfe vo.WorkflowError
if warnings, ok := ctxcache.Get[nodes.ConversionWarnings](ctx, warningKey); ok {
wfe = vo.WrapWarn(errno.ErrNodeOutputParseFail, warnings, errorx.KV("warnings", warnings.Error()))
}
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: output,
Error: wfe,
}, nil
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package receiver
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
func Test_jsonParseRelaxed(t *testing.T) {
tInfos := map[string]*vo.TypeInfo{
"str_key": {
Type: vo.DataTypeString,
},
"obj_key": {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"field1": {
Type: vo.DataTypeString,
},
},
},
}
data := `{"str_key": "val"}`
result, err := jsonParseRelaxed(context.Background(), data, tInfos)
assert.NoError(t, err)
assert.Equal(t, map[string]any{"str_key": "val"}, result)
}

View File

@@ -0,0 +1,342 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"fmt"
"reflect"
"strings"
)
type Predicate interface {
Resolve() (bool, error)
}
type Clause struct {
LeftOperant any
Op Operator
RightOperant any
}
type MultiClause struct {
Clauses []*Clause
Relation ClauseRelation
}
func (c *Clause) Resolve() (bool, error) {
leftV := c.LeftOperant
rightV := c.RightOperant
leftT := reflect.TypeOf(leftV)
rightT := reflect.TypeOf(rightV)
if err := c.Op.WillAccept(leftT, rightT); err != nil {
return false, err
}
switch c.Op {
case OperatorEqual:
if leftV == nil && rightV == nil {
return true, nil
}
if leftV == nil || rightV == nil {
return false, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
return leftV == rightV, nil
case OperatorNotEqual:
if leftV == nil && rightV == nil {
return false, nil
}
if leftV == nil || rightV == nil {
return true, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
return leftV != rightV, nil
case OperatorEmpty:
if leftV == nil {
return true, nil
}
if leftArray, ok := leftV.([]any); ok {
return len(leftArray) == 0, nil
}
if leftObj, ok := leftV.(map[string]any); ok {
return len(leftObj) == 0, nil
}
if leftStr, ok := leftV.(string); ok {
return len(leftStr) == 0 || leftStr == "None", nil
}
if leftInt, ok := leftV.(int64); ok {
return leftInt == 0, nil
}
if leftFloat, ok := leftV.(float64); ok {
return leftFloat == 0, nil
}
if leftBool, ok := leftV.(bool); ok {
return !leftBool, nil
}
return false, nil
case OperatorNotEmpty:
empty, err := (&Clause{LeftOperant: leftV, Op: OperatorEmpty}).Resolve()
return !empty, err
case OperatorGreater:
if leftV == nil {
return false, nil
}
if rightV == nil {
return true, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
return leftV.(float64) > rightV.(float64), nil
}
return leftV.(int64) > rightV.(int64), nil
case OperatorGreaterOrEqual:
if leftV == nil {
if rightV == nil {
return true, nil
}
return false, nil
}
if rightV == nil {
return true, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
return leftV.(float64) >= rightV.(float64), nil
}
return leftV.(int64) >= rightV.(int64), nil
case OperatorLesser:
if leftV == nil {
if rightV == nil {
return false, nil
}
return true, nil
}
if rightV == nil {
return false, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
return leftV.(float64) < rightV.(float64), nil
}
return leftV.(int64) < rightV.(int64), nil
case OperatorLesserOrEqual:
if leftV == nil {
return true, nil
}
if rightV == nil {
return false, nil
}
leftV, rightV = alignNumberTypes(leftV, rightV, leftT, rightT)
if reflect.TypeOf(leftV).Kind() == reflect.Float64 {
return leftV.(float64) <= rightV.(float64), nil
}
return leftV.(int64) <= rightV.(int64), nil
case OperatorIsTrue:
if leftV == nil {
return false, nil
}
return leftV.(bool), nil
case OperatorIsFalse:
if leftV == nil {
return true, nil
}
return !leftV.(bool), nil
case OperatorLengthGreater:
if leftV == nil {
return false, nil
}
return int64(reflect.ValueOf(leftV).Len()) > rightV.(int64), nil
case OperatorLengthGreaterOrEqual:
if leftV == nil {
if rightV.(int64) == 0 {
return true, nil
}
return false, nil
}
return int64(reflect.ValueOf(leftV).Len()) >= rightV.(int64), nil
case OperatorLengthLesser:
if leftV == nil {
if rightV.(int64) == 0 {
return false, nil
}
return true, nil
}
return int64(reflect.ValueOf(leftV).Len()) < rightV.(int64), nil
case OperatorLengthLesserOrEqual:
if leftV == nil {
return true, nil
}
return int64(reflect.ValueOf(leftV).Len()) <= rightV.(int64), nil
case OperatorContain:
if leftV == nil { // treat it as empty slice
return false, nil
}
if leftT.Kind() == reflect.String {
return strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
}
leftValue := reflect.ValueOf(leftV)
for i := 0; i < leftValue.Len(); i++ {
elem := leftValue.Index(i).Interface()
if elem == rightV {
return true, nil
}
}
return false, nil
case OperatorNotContain:
if leftV == nil { // treat it as empty slice
return false, nil
}
if leftT.Kind() == reflect.String {
return !strings.Contains(fmt.Sprintf("%v", leftV), rightV.(string)), nil
}
leftValue := reflect.ValueOf(leftV)
for i := 0; i < leftValue.Len(); i++ {
elem := leftValue.Index(i).Interface()
if elem == rightV {
return false, nil
}
}
return true, nil
case OperatorContainKey:
if leftV == nil { // treat it as empty map
return false, nil
}
if leftT.Kind() == reflect.Map {
leftValue := reflect.ValueOf(leftV)
for _, key := range leftValue.MapKeys() {
if key.Interface() == rightV {
return true, nil
}
}
} else { // struct, unreachable now
for i := 0; i < leftT.NumField(); i++ {
field := leftT.Field(i)
if field.IsExported() {
tag := field.Tag.Get("json")
if tag == rightV {
return true, nil
}
}
}
}
return false, nil
case OperatorNotContainKey:
if leftV == nil { // treat it as empty map
return false, nil
}
if leftT.Kind() == reflect.Map {
leftValue := reflect.ValueOf(leftV)
for _, key := range leftValue.MapKeys() {
if key.Interface() == rightV {
return false, nil
}
}
} else { // struct, unreachable now
for i := 0; i < leftT.NumField(); i++ {
field := leftT.Field(i)
if field.IsExported() {
tag := field.Tag.Get("json")
if tag == rightV {
return false, nil
}
}
}
}
return true, nil
default:
return false, fmt.Errorf("unknown operator: %v", c.Op)
}
}
func (mc *MultiClause) Resolve() (bool, error) {
if mc.Relation == ClauseRelationAND {
for _, clause := range mc.Clauses {
isTrue, err := clause.Resolve()
if err != nil {
return false, err
}
if !isTrue {
return false, nil
}
}
return true, nil
} else if mc.Relation == ClauseRelationOR {
for _, clause := range mc.Clauses {
isTrue, err := clause.Resolve()
if err != nil {
return false, err
}
if isTrue {
return true, nil
}
}
return false, nil
} else {
return false, fmt.Errorf("unknown relation: %v", mc.Relation)
}
}
func alignNumberTypes(leftV, rightV any, leftT, rightT reflect.Type) (any, any) {
if leftT == reflect.TypeOf(int64(0)) {
if rightT == reflect.TypeOf(float64(0)) {
leftV = float64(leftV.(int64))
}
} else if leftT == reflect.TypeOf(float64(0)) {
if rightT == reflect.TypeOf(int64(0)) {
rightV = float64(rightV.(int64))
}
}
return leftV, rightV
}

View File

@@ -0,0 +1,310 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestClauseResolve tests the Resolve method of the Clause struct.
func TestClauseResolve(t *testing.T) {
// Test cases for different operators, considering acceptable operand types
testCases := []struct {
name string
clause Clause
want bool
wantErr bool
}{
// OperatorEqual
{
name: "OperatorEqual_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorEqual,
RightOperant: int64(10),
},
want: true,
wantErr: false,
},
{
name: "OperatorEqual_IntMismatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorEqual,
RightOperant: int64(20),
},
want: false,
wantErr: false,
},
{
name: "OperatorEqual_FloatMatch",
clause: Clause{
LeftOperant: 10.5,
Op: OperatorEqual,
RightOperant: 10.5,
},
want: true,
wantErr: false,
},
{
name: "OperatorEqual_StringMatch",
clause: Clause{
LeftOperant: "test",
Op: OperatorEqual,
RightOperant: "test",
},
want: true,
wantErr: false,
},
// OperatorNotEqual
{
name: "OperatorNotEqual_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorNotEqual,
RightOperant: int64(20),
},
want: true,
wantErr: false,
},
{
name: "OperatorNotEqual_StringMatch",
clause: Clause{
LeftOperant: "test",
Op: OperatorNotEqual,
RightOperant: "xyz",
},
want: true,
wantErr: false,
},
// OperatorEmpty
{
name: "OperatorEmpty_NilValue",
clause: Clause{
LeftOperant: nil,
Op: OperatorEmpty,
},
want: true,
wantErr: false,
},
// OperatorNotEmpty
{
name: "OperatorNotEmpty_MapValue",
clause: Clause{
LeftOperant: map[string]any{"key1": "value1"},
Op: OperatorNotEmpty,
},
want: true,
wantErr: false,
},
// OperatorGreater
{
name: "OperatorGreater_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorGreater,
RightOperant: int64(5),
},
want: true,
wantErr: false,
},
{
name: "OperatorGreater_FloatMatch",
clause: Clause{
LeftOperant: 10.5,
Op: OperatorGreater,
RightOperant: 5.0,
},
want: true,
wantErr: false,
},
// OperatorGreaterOrEqual
{
name: "OperatorGreaterOrEqual_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorGreaterOrEqual,
RightOperant: int64(10),
},
want: true,
wantErr: false,
},
// OperatorLesser
{
name: "OperatorLesser_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorLesser,
RightOperant: int64(15),
},
want: true,
wantErr: false,
},
// OperatorLesserOrEqual
{
name: "OperatorLesserOrEqual_IntMatch",
clause: Clause{
LeftOperant: int64(10),
Op: OperatorLesserOrEqual,
RightOperant: int64(10),
},
want: true,
wantErr: false,
},
// OperatorIsTrue
{
name: "OperatorIsTrue_BoolTrue",
clause: Clause{
LeftOperant: true,
Op: OperatorIsTrue,
},
want: true,
wantErr: false,
},
// OperatorIsFalse
{
name: "OperatorIsFalse_BoolFalse",
clause: Clause{
LeftOperant: true,
Op: OperatorIsFalse,
},
want: false,
wantErr: false,
},
// OperatorLengthGreater
{
name: "OperatorLengthGreater_Slice",
clause: Clause{
LeftOperant: []int{1, 2, 3},
Op: OperatorLengthGreater,
RightOperant: int64(2),
},
want: true,
wantErr: false,
},
{
name: "OperatorLengthGreater_String",
clause: Clause{
LeftOperant: "test",
Op: OperatorLengthGreater,
RightOperant: int64(2),
},
want: true,
wantErr: false,
},
// OperatorLengthGreaterOrEqual
{
name: "OperatorLengthGreaterOrEqual_Slice",
clause: Clause{
LeftOperant: []int{1, 2, 3},
Op: OperatorLengthGreaterOrEqual,
RightOperant: int64(3),
},
want: true,
wantErr: false,
},
// OperatorLengthLesser
{
name: "OperatorLengthLesser_Slice",
clause: Clause{
LeftOperant: []int{1, 2, 3},
Op: OperatorLengthLesser,
RightOperant: int64(4),
},
want: true,
wantErr: false,
},
// OperatorLengthLesserOrEqual
{
name: "OperatorLengthLesserOrEqual_Slice",
clause: Clause{
LeftOperant: []int{1, 2, 3},
Op: OperatorLengthLesserOrEqual,
RightOperant: int64(3),
},
want: true,
wantErr: false,
},
// OperatorContain
{
name: "OperatorContain_String",
clause: Clause{
LeftOperant: "test",
Op: OperatorContain,
RightOperant: "es",
},
want: true,
wantErr: false,
},
{
name: "OperatorContain_Slice",
clause: Clause{
LeftOperant: []int{1, 2, 3},
Op: OperatorContain,
RightOperant: 2,
},
want: true,
wantErr: false,
},
// OperatorNotContain
{
name: "OperatorNotContain_String",
clause: Clause{
LeftOperant: "test2",
Op: OperatorNotContain,
RightOperant: "xyz",
},
want: true,
wantErr: false,
},
// OperatorContainKey
{
name: "OperatorContainKey_Map",
clause: Clause{
LeftOperant: map[string]any{"key1": "value1"},
Op: OperatorContainKey,
RightOperant: "key1",
},
want: true,
wantErr: false,
},
// OperatorNotContainKey
{
name: "OperatorNotContainKey_Map",
clause: Clause{
LeftOperant: map[string]any{"key1": "value1"},
Op: OperatorNotContainKey,
RightOperant: "key2",
},
want: true,
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.clause.Resolve()
if (err != nil) != tc.wantErr {
t.Errorf("Clause.Resolve() error = %v, wantErr %v", err, tc.wantErr)
return
}
assert.Equal(t, tc.want, got)
})
}
}

View File

@@ -0,0 +1,182 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"fmt"
"reflect"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type Operator string
const (
OperatorEqual Operator = "="
OperatorNotEqual Operator = "!="
OperatorEmpty Operator = "empty"
OperatorNotEmpty Operator = "not_empty"
OperatorGreater Operator = ">"
OperatorGreaterOrEqual Operator = ">="
OperatorLesser Operator = "<"
OperatorLesserOrEqual Operator = "<="
OperatorIsTrue Operator = "true"
OperatorIsFalse Operator = "false"
OperatorLengthGreater Operator = "len >"
OperatorLengthGreaterOrEqual Operator = "len >="
OperatorLengthLesser Operator = "len <"
OperatorLengthLesserOrEqual Operator = "len <="
OperatorContain Operator = "contain"
OperatorNotContain Operator = "not_contain"
OperatorContainKey Operator = "contain_key"
OperatorNotContainKey Operator = "not_contain_key"
)
func (o *Operator) WillAccept(leftT, rightT reflect.Type) error {
switch *o {
case OperatorEqual, OperatorNotEqual:
if leftT == nil || rightT == nil {
return nil
}
if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) && leftT.Kind() != reflect.Bool && leftT.Kind() != reflect.String {
return fmt.Errorf("operator %v only accepts int64, float64, bool or string, not %v", *o, leftT)
}
if leftT.Kind() == reflect.Bool || leftT.Kind() != reflect.String {
if leftT != rightT {
return fmt.Errorf("operator %v left operant and right operant must be same type: %v, %v", *o, leftT, rightT)
}
}
if leftT == reflect.TypeOf(int64(0)) || leftT == reflect.TypeOf(float64(0)) {
if rightT != reflect.TypeOf(int64(0)) && rightT != reflect.TypeOf(float64(0)) {
return fmt.Errorf("operator %v right operant must be int64 or float64, not %v", *o, rightT)
}
}
case OperatorEmpty, OperatorNotEmpty:
if rightT != nil {
return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
}
case OperatorGreater, OperatorGreaterOrEqual, OperatorLesser, OperatorLesserOrEqual:
if leftT == nil {
return nil
}
if leftT != reflect.TypeOf(int64(0)) && leftT != reflect.TypeOf(float64(0)) {
return fmt.Errorf("operator %v only accepts float64 or int64, not %v", *o, leftT)
}
case OperatorIsTrue, OperatorIsFalse:
if leftT == nil {
return nil
}
if rightT != nil {
return fmt.Errorf("operator %v does not accept non-nil right operant: %v", *o, rightT)
}
if leftT.Kind() != reflect.Bool {
return fmt.Errorf("operator %v only accepts boolean, not %v", *o, leftT)
}
case OperatorLengthGreater, OperatorLengthGreaterOrEqual, OperatorLengthLesser, OperatorLengthLesserOrEqual:
if leftT == nil {
return nil
}
if leftT.Kind() != reflect.String && leftT.Kind() != reflect.Slice {
return fmt.Errorf("operator %v left operant only accepts string or slice, not %v", *o, leftT)
}
if rightT != reflect.TypeOf(int64(0)) {
return fmt.Errorf("operator %v right operant only accepts int64, not %v", *o, rightT)
}
case OperatorContain, OperatorNotContain:
if leftT == nil {
return nil
}
switch leftT.Kind() {
case reflect.String:
if rightT.Kind() != reflect.String {
return fmt.Errorf("operator %v whose left operant is string only accepts right operant of string, not %v", *o, rightT)
}
case reflect.Slice:
elemType := leftT.Elem()
if !rightT.AssignableTo(elemType) {
return fmt.Errorf("operator %v whose left operant is slice only accepts right operant of corresponding element type %v, not %v", *o, elemType, rightT)
}
default:
return fmt.Errorf("operator %v only accepts left operant of string or slice, not %v", *o, leftT)
}
case OperatorContainKey, OperatorNotContainKey:
if leftT == nil { // treat it as empty map
return nil
}
if leftT.Kind() != reflect.Map {
return fmt.Errorf("operator %v only accepts left operant of map, not %v", *o, leftT)
}
if rightT.Kind() != reflect.String {
return fmt.Errorf("operator %v only accepts right operant of string, not %v", *o, rightT)
}
default:
return fmt.Errorf("unknown operator: %d", o)
}
return nil
}
func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
switch *o {
case OperatorEqual:
return vo.Equal
case OperatorNotEqual:
return vo.NotEqual
case OperatorEmpty:
return vo.Empty
case OperatorNotEmpty:
return vo.NotEmpty
case OperatorGreater:
return vo.GreaterThan
case OperatorGreaterOrEqual:
return vo.GreaterThanEqual
case OperatorLesser:
return vo.LessThan
case OperatorLesserOrEqual:
return vo.LessThanEqual
case OperatorIsTrue:
return vo.True
case OperatorIsFalse:
return vo.False
case OperatorLengthGreater:
return vo.LengthGreaterThan
case OperatorLengthGreaterOrEqual:
return vo.LengthGreaterThanEqual
case OperatorLengthLesser:
return vo.LengthLessThan
case OperatorLengthLesserOrEqual:
return vo.LengthLessThanEqual
case OperatorContain:
return vo.Contain
case OperatorNotContain:
return vo.NotContain
case OperatorContainKey:
return vo.Contain
case OperatorNotContainKey:
return vo.NotContain
default:
panic(fmt.Sprintf("unknown operator: %+v", o))
}
}

View File

@@ -0,0 +1,397 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"reflect"
"testing"
)
// TestOperatorWillAccept tests the WillAccept method of the Operator struct.
func TestOperatorWillAccept(t *testing.T) {
testCases := []struct {
name string
operator Operator
leftType reflect.Type
rightType reflect.Type
wantErr bool
}{
// OperatorEqual
{
name: "OperatorEqual_Int64Match",
operator: OperatorEqual,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorEqual_InvalidType",
operator: OperatorEqual,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorNotEqual
{
name: "OperatorNotEqual_Float64Match",
operator: OperatorNotEqual,
leftType: reflect.TypeOf(float64(0)),
rightType: reflect.TypeOf(float64(0)),
wantErr: false,
},
{
name: "OperatorNotEqual_InvalidType",
operator: OperatorNotEqual,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorEmpty
{
name: "OperatorEmpty_Struct",
operator: OperatorEmpty,
leftType: reflect.TypeOf(map[string]int{}),
rightType: nil,
wantErr: false,
},
{
name: "OperatorEmpty_NonNilRight",
operator: OperatorEmpty,
leftType: reflect.TypeOf(map[string]int{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
// OperatorNotEmpty
{
name: "OperatorNotEmpty_Slice",
operator: OperatorNotEmpty,
leftType: reflect.TypeOf([]int{}),
rightType: nil,
wantErr: false,
},
{
name: "OperatorNotEmpty_NonNilRight",
operator: OperatorNotEmpty,
leftType: reflect.TypeOf([]int{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
// OperatorGreater
{
name: "OperatorGreater_Int64Match",
operator: OperatorGreater,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorGreater_InvalidType",
operator: OperatorGreater,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorGreaterOrEqual
{
name: "OperatorGreaterOrEqual_Float64Match",
operator: OperatorGreaterOrEqual,
leftType: reflect.TypeOf(float64(0)),
rightType: reflect.TypeOf(float64(0)),
wantErr: false,
},
{
name: "OperatorGreaterOrEqual_InvalidType",
operator: OperatorGreaterOrEqual,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorLesser
{
name: "OperatorLesser_Int64Match",
operator: OperatorLesser,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorLesser_InvalidType",
operator: OperatorLesser,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorLesserOrEqual
{
name: "OperatorLesserOrEqual_Float64Match",
operator: OperatorLesserOrEqual,
leftType: reflect.TypeOf(float64(0)),
rightType: reflect.TypeOf(float64(0)),
wantErr: false,
},
{
name: "OperatorLesserOrEqual_InvalidType",
operator: OperatorLesserOrEqual,
leftType: reflect.TypeOf(struct{}{}),
rightType: reflect.TypeOf(struct{}{}),
wantErr: true,
},
// OperatorIsTrue
{
name: "OperatorIsTrue_Bool",
operator: OperatorIsTrue,
leftType: reflect.TypeOf(true),
rightType: nil,
wantErr: false,
},
{
name: "OperatorIsTrue_NonNilRight",
operator: OperatorIsTrue,
leftType: reflect.TypeOf(true),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorIsTrue_InvalidType",
operator: OperatorIsTrue,
leftType: reflect.TypeOf(int64(0)),
rightType: nil,
wantErr: true,
},
// OperatorIsFalse
{
name: "OperatorIsFalse_Bool",
operator: OperatorIsFalse,
leftType: reflect.TypeOf(false),
rightType: nil,
wantErr: false,
},
{
name: "OperatorIsFalse_NonNilRight",
operator: OperatorIsFalse,
leftType: reflect.TypeOf(false),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorIsFalse_InvalidType",
operator: OperatorIsFalse,
leftType: reflect.TypeOf(int64(0)),
rightType: nil,
wantErr: true,
},
// OperatorLengthGreater
{
name: "OperatorLengthGreater_String",
operator: OperatorLengthGreater,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorLengthGreater_InvalidLeft",
operator: OperatorLengthGreater,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorLengthGreater_InvalidRight",
operator: OperatorLengthGreater,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(float64(0)),
wantErr: true,
},
// OperatorLengthGreaterOrEqual
{
name: "OperatorLengthGreaterOrEqual_Slice",
operator: OperatorLengthGreaterOrEqual,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorLengthGreaterOrEqual_InvalidLeft",
operator: OperatorLengthGreaterOrEqual,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorLengthGreaterOrEqual_InvalidRight",
operator: OperatorLengthGreaterOrEqual,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(float64(0)),
wantErr: true,
},
// OperatorLengthLesser
{
name: "OperatorLengthLesser_String",
operator: OperatorLengthLesser,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorLengthLesser_InvalidLeft",
operator: OperatorLengthLesser,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorLengthLesser_InvalidRight",
operator: OperatorLengthLesser,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(float64(0)),
wantErr: true,
},
// OperatorLengthLesserOrEqual
{
name: "OperatorLengthLesserOrEqual_Slice",
operator: OperatorLengthLesserOrEqual,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorLengthLesserOrEqual_InvalidLeft",
operator: OperatorLengthLesserOrEqual,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorLengthLesserOrEqual_InvalidRight",
operator: OperatorLengthLesserOrEqual,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(float64(0)),
wantErr: true,
},
// OperatorContain
{
name: "OperatorContain_String",
operator: OperatorContain,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(""),
wantErr: false,
},
{
name: "OperatorContain_Slice",
operator: OperatorContain,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorContain_InvalidLeft",
operator: OperatorContain,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(0),
wantErr: true,
},
{
name: "OperatorContain_InvalidRight",
operator: OperatorContain,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
// OperatorNotContain
{
name: "OperatorNotContain_String",
operator: OperatorNotContain,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(""),
wantErr: false,
},
{
name: "OperatorNotContain_Slice",
operator: OperatorNotContain,
leftType: reflect.TypeOf([]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: false,
},
{
name: "OperatorNotContain_InvalidLeft",
operator: OperatorNotContain,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
{
name: "OperatorNotContain_InvalidRight",
operator: OperatorNotContain,
leftType: reflect.TypeOf(""),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
// OperatorContainKey
{
name: "OperatorContainKey_Map",
operator: OperatorContainKey,
leftType: reflect.TypeOf(map[string]any{}),
rightType: reflect.TypeOf(""),
wantErr: false,
},
{
name: "OperatorContainKey_InvalidLeft",
operator: OperatorContainKey,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(""),
wantErr: true,
},
{
name: "OperatorContainKey_InvalidRight",
operator: OperatorContainKey,
leftType: reflect.TypeOf(map[string]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
// OperatorNotContainKey
{
name: "OperatorNotContainKey_Map",
operator: OperatorNotContainKey,
leftType: reflect.TypeOf(map[string]any{}),
rightType: reflect.TypeOf(""),
wantErr: false,
},
{
name: "OperatorNotContainKey_InvalidLeft",
operator: OperatorNotContainKey,
leftType: reflect.TypeOf(int64(0)),
rightType: reflect.TypeOf(""),
wantErr: true,
},
{
name: "OperatorNotContainKey_InvalidRight",
operator: OperatorNotContainKey,
leftType: reflect.TypeOf(map[string]any{}),
rightType: reflect.TypeOf(int64(0)),
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.operator.WillAccept(tc.leftType, tc.rightType)
if (err != nil) != tc.wantErr {
t.Errorf("Operator.WillAccept() error = %v, wantErr %v", err, tc.wantErr)
}
})
}
}

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type ClauseRelation string
const (
ClauseRelationAND ClauseRelation = "and"
ClauseRelationOR ClauseRelation = "or"
)
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
type OneClauseSchema struct {
Single *Operator `json:"single,omitempty"`
Multi *MultiClauseSchema `json:"multi,omitempty"`
}
type MultiClauseSchema struct {
Clauses []*Operator `json:"clauses"`
Relation ClauseRelation `json:"relation"`
}
func (c ClauseRelation) ToVOLogicType() vo.LogicType {
if c == ClauseRelationAND {
return vo.AND
} else if c == ClauseRelationOR {
return vo.OR
}
panic(fmt.Sprintf("unknown clause relation: %s", c))
}

View File

@@ -0,0 +1,209 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"context"
"fmt"
"strconv"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type Selector struct {
config *Config
}
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
if config == nil {
return nil, fmt.Errorf("config is nil")
}
if len(config.Clauses) == 0 {
return nil, fmt.Errorf("config clauses are empty")
}
for _, clause := range config.Clauses {
if clause.Single == nil && clause.Multi == nil {
return nil, fmt.Errorf("single clause and multi clause are both nil")
}
if clause.Single != nil && clause.Multi != nil {
return nil, fmt.Errorf("multi clause and single clause are both non-nil")
}
if clause.Multi != nil {
if len(clause.Multi.Clauses) == 0 {
return nil, fmt.Errorf("multi clause's single clauses are empty")
}
if clause.Multi.Relation != ClauseRelationAND && clause.Multi.Relation != ClauseRelationOR {
return nil, fmt.Errorf("multi clause and clauses are both non-AND-OR: %v", clause.Multi.Relation)
}
}
}
return &Selector{
config: config,
}, nil
}
type Operants struct {
Left any
Right any
Multi []*Operants
}
const (
LeftKey = "left"
RightKey = "right"
SelectKey = "selected"
)
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.SelectorInputConverter(input)
if err != nil {
return nil, err
}
predicates := make([]Predicate, 0, len(s.config.Clauses))
for i, oneConf := range s.config.Clauses {
if oneConf.Single != nil {
left := in[i].Left
right := in[i].Right
if right != nil {
predicates = append(predicates, &Clause{
LeftOperant: left,
Op: *oneConf.Single,
RightOperant: right,
})
} else {
predicates = append(predicates, &Clause{
LeftOperant: left,
Op: *oneConf.Single,
})
}
} else if oneConf.Multi != nil {
multiClause := &MultiClause{
Relation: oneConf.Multi.Relation,
}
for j, singleConf := range oneConf.Multi.Clauses {
left := in[i].Multi[j].Left
right := in[i].Multi[j].Right
if right != nil {
multiClause.Clauses = append(multiClause.Clauses, &Clause{
LeftOperant: left,
Op: *singleConf,
RightOperant: right,
})
} else {
multiClause.Clauses = append(multiClause.Clauses, &Clause{
LeftOperant: left,
Op: *singleConf,
})
}
}
predicates = append(predicates, multiClause)
} else {
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
}
}
for i, p := range predicates {
isTrue, err := p.Resolve()
if err != nil {
return nil, err
}
if isTrue {
return map[string]any{SelectKey: i}, nil
}
}
return map[string]any{SelectKey: len(in)}, nil // default choice
}
func (s *Selector) GetType() string {
return "Selector"
}
func (s *Selector) ConditionCount() int {
return len(s.config.Clauses)
}
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.config.Clauses
for i, oneConf := range conf {
if oneConf.Single != nil {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), RightKey})
if ok {
out = append(out, Operants{Left: left, Right: right})
} else {
out = append(out, Operants{Left: left})
}
} else if oneConf.Multi != nil {
multiClause := make([]*Operants, 0)
for j := range oneConf.Multi.Clauses {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), RightKey})
if ok {
multiClause = append(multiClause, &Operants{Left: left, Right: right})
} else {
multiClause = append(multiClause, &Operants{Left: left})
}
}
out = append(out, Operants{Multi: multiClause})
} else {
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
}
}
return out, nil
}
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
count := len(s.config.Clauses)
out := output[SelectKey].(int)
if out == count {
cOutput := map[string]any{"result": "pass to else branch"}
return &nodes.StructuredCallbackOutput{
Output: cOutput,
RawOutput: cOutput,
}, nil
}
if out >= 0 && out < count {
cOutput := map[string]any{"result": fmt.Sprintf("pass to condition %d branch", out+1)}
return &nodes.StructuredCallbackOutput{
Output: cOutput,
RawOutput: cOutput,
}, nil
}
return nil, fmt.Errorf("out of range: %d", out)
}

View File

@@ -0,0 +1,185 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
var KeyIsFinished = "\x1FKey is finished\x1F"
type Mode string
const (
Streaming Mode = "streaming"
NonStreaming Mode = "non-streaming"
)
type FieldStreamType string
const (
FieldIsStream FieldStreamType = "yes" // absolutely a stream
FieldNotStream FieldStreamType = "no" // absolutely not a stream
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
)
// SourceInfo contains stream type for a input field source of a node.
type SourceInfo struct {
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
IsIntermediate bool
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
FieldType FieldStreamType
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
FromNodeKey vo.NodeKey
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
FromPath compose.FieldPath
TypeInfo *vo.TypeInfo
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
SubSources map[string]*SourceInfo
}
type DynamicStreamContainer interface {
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
}
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
resolved := make(map[string]*SourceInfo, len(sources))
nodeKey2Skipped := make(map[vo.NodeKey]bool)
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
resolvedNode := &SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: sInfo.FieldType,
FromNodeKey: sInfo.FromNodeKey,
FromPath: sInfo.FromPath,
TypeInfo: sInfo.TypeInfo,
}
if len(sInfo.SubSources) > 0 {
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))
for k, subInfo := range sInfo.SubSources {
resolvedSub, err := resolver(k, subInfo)
if err != nil {
return nil, err
}
resolvedNode.SubSources[k] = resolvedSub
}
return resolvedNode, nil
}
if sInfo.FromNodeKey == "" { // static values and variables, always non-streaming and available
return resolvedNode, nil
}
var skipped, ok bool
if skipped, ok = nodeKey2Skipped[sInfo.FromNodeKey]; !ok {
_ = compose.ProcessState(ctx, func(ctx context.Context, state NodeExecuteStatusAware) error {
skipped = !state.NodeExecuted(sInfo.FromNodeKey)
return nil
})
nodeKey2Skipped[sInfo.FromNodeKey] = skipped
}
if skipped {
resolvedNode.FieldType = FieldSkipped
return resolvedNode, nil
}
if sInfo.FieldType == FieldMaybeStream {
if len(sInfo.SubSources) > 0 {
panic("a maybe stream field should not have sub sources")
}
var streamType FieldStreamType
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
var e error
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
return e
})
if err != nil {
return nil, err
}
return &SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: streamType,
FromNodeKey: sInfo.FromNodeKey,
FromPath: sInfo.FromPath,
SubSources: sInfo.SubSources,
TypeInfo: sInfo.TypeInfo,
}, nil
}
return resolvedNode, nil
}
for k, sInfo := range sources {
resolvedInfo, err := resolver(k, sInfo)
if err != nil {
return nil, err
}
resolved[k] = resolvedInfo
}
return resolved, nil
}
type NodeExecuteStatusAware interface {
NodeExecuted(key vo.NodeKey) bool
}
func (s *SourceInfo) Skipped() bool {
if !s.IsIntermediate {
return s.FieldType == FieldSkipped
}
for _, sub := range s.SubSources {
if !sub.Skipped() {
return false
}
}
return true
}
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
if !s.IsIntermediate {
return s.FromNodeKey == nodeKey
}
for _, sub := range s.SubSources {
if sub.FromNode(nodeKey) {
return true
}
}
return false
}

View File

@@ -0,0 +1,154 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package subworkflow
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type Config struct {
Runner compose.Runnable[map[string]any, map[string]any]
}
type SubWorkflow struct {
cfg *Config
}
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.Runner == nil {
return nil, errors.New("runnable is nil")
}
return &SubWorkflow{cfg: cfg}, nil
}
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
return nil, err
}
iEvent := &entity.InterruptEvent{
NodeKey: nodeKey,
NodeType: entity.NodeTypeSubWorkflow,
SubWorkflowInterruptInfo: interruptInfo,
}
err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.InterruptEventStore) error {
return setter.SetInterruptEvent(nodeKey, iEvent)
})
if err != nil {
return nil, err
}
return nil, compose.InterruptAndRerun
}
return out, nil
}
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
return nil, err
}
iEvent := &entity.InterruptEvent{
NodeKey: nodeKey,
NodeType: entity.NodeTypeSubWorkflow,
SubWorkflowInterruptInfo: interruptInfo,
}
err = compose.ProcessState(ctx, func(ctx context.Context, setter nodes.InterruptEventStore) error {
return setter.SetInterruptEvent(nodeKey, iEvent)
})
if err != nil {
return nil, err
}
return nil, compose.InterruptAndRerun
}
return out, nil
}
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
nestedOpts := options.GetOptsForNested()
exeCtx := execute.GetExeCtx(ctx)
if exeCtx == nil {
panic("impossible. exeCtx in sub workflow is nil")
}
checkPointID := exeCtx.CheckPointID
if len(checkPointID) > 0 {
newCheckpointID := checkPointID
if exeCtx.SubWorkflowCtx != nil {
newCheckpointID += "_" + strconv.Itoa(int(exeCtx.SubWorkflowCtx.SubExecuteID))
}
newCheckpointID += "_" + strconv.Itoa(int(exeCtx.NodeCtx.NodeExecuteID))
nestedOpts = append(nestedOpts, compose.WithCheckPointID(newCheckpointID))
}
if len(options.GetResumeIndexes()) > 0 {
if len(options.GetResumeIndexes()) != 1 {
return nil, "", fmt.Errorf("resume indexes for sub workflow length must be 1")
}
if _, ok := options.GetResumeIndexes()[0]; !ok {
return nil, "", fmt.Errorf("resume indexes for sub workflow must resume index 0")
}
stateModifier, ok := options.GetResumeIndexes()[0]
if ok {
nestedOpts = append(nestedOpts, compose.WithStateModifier(stateModifier))
}
}
return nestedOpts, exeCtx.NodeKey, nil
}

View File

@@ -0,0 +1,431 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/bytedance/sonic"
"github.com/bytedance/sonic/ast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type TemplatePart struct {
IsVariable bool
Value string
Root string
SubPathsBeforeSlice []string
JsonPath []any
literal string
}
var re = regexp.MustCompile(`{{\s*([^}]+)\s*}}`)
func ParseTemplate(template string) []TemplatePart {
matches := re.FindAllStringSubmatchIndex(template, -1)
parts := make([]TemplatePart, 0)
lastEnd := 0
loop:
for _, match := range matches {
start, end := match[0], match[1]
placeholderStart, placeholderEnd := match[2], match[3]
// Add the literal part before the current variable placeholder
if start > lastEnd {
parts = append(parts, TemplatePart{
IsVariable: false,
Value: template[lastEnd:start],
})
}
// Add the variable placeholder
val := template[placeholderStart:placeholderEnd]
segments := strings.Split(val, ".")
var subPaths []string
if !strings.Contains(segments[0], "[") {
for i := 1; i < len(segments); i++ {
if strings.Contains(segments[i], "[") {
break
}
subPaths = append(subPaths, segments[i])
}
}
var jsonPath []any
for _, segment := range segments {
// find the first '[' to separate the initial key from array accessors
firstBracket := strings.Index(segment, "[")
if firstBracket == -1 {
// No brackets, the whole segment is a key
jsonPath = append(jsonPath, segment)
continue
}
// Add the initial key part
key := segment[:firstBracket]
if key != "" {
jsonPath = append(jsonPath, key)
}
// Now, parse the array accessors like [1][2]
rest := segment[firstBracket:]
for strings.HasPrefix(rest, "[") {
closeBracket := strings.Index(rest, "]")
if closeBracket == -1 {
// Malformed, treat as literal
parts = append(parts, TemplatePart{IsVariable: false, Value: val})
continue loop
}
idxStr := rest[1:closeBracket]
idx, err := strconv.Atoi(idxStr)
if err != nil {
// Malformed, treat as literal
parts = append(parts, TemplatePart{IsVariable: false, Value: val})
continue loop
}
jsonPath = append(jsonPath, idx)
rest = rest[closeBracket+1:]
}
if rest != "" {
// Malformed, treat as literal
parts = append(parts, TemplatePart{IsVariable: false, Value: val})
continue loop
}
}
parts = append(parts, TemplatePart{
IsVariable: true,
Value: val,
Root: removeSlice(segments[0]),
SubPathsBeforeSlice: subPaths,
JsonPath: jsonPath,
literal: "{{" + val + "}}",
})
lastEnd = end
}
// Add the remaining literal part if there is any
if lastEnd < len(template) {
parts = append(parts, TemplatePart{
IsVariable: false,
Value: template[lastEnd:],
})
}
return parts
}
func removeSlice(s string) string {
i := strings.Index(s, "[")
if i != -1 {
return s[:i]
}
return s
}
type renderOptions struct {
type2CustomRenderer map[reflect.Type]func(any) (string, error)
reservedKey map[string]struct{}
nilRenderer func() (string, error)
}
func WithNilRender(fn func() (string, error)) RenderOption {
return func(opts *renderOptions) {
opts.nilRenderer = fn
}
}
type RenderOption func(options *renderOptions)
func WithCustomRender(rType reflect.Type, fn func(any) (string, error)) RenderOption {
return func(opts *renderOptions) {
if opts.type2CustomRenderer == nil {
opts.type2CustomRenderer = make(map[reflect.Type]func(any) (string, error))
}
opts.type2CustomRenderer[rType] = fn
}
}
func WithReservedKey(keys ...string) RenderOption {
return func(opts *renderOptions) {
if opts.reservedKey == nil {
opts.reservedKey = make(map[string]struct{})
}
for _, key := range keys {
opts.reservedKey[key] = struct{}{}
}
}
}
var renderConfig = sonic.Config{
SortMapKeys: true,
}.Froze()
func joinJsonPath(p []any) string {
var sb strings.Builder
for i := range p {
field, ok := p[i].(string)
if ok {
if i > 0 {
_, ok := p[i-1].(string)
if ok {
sb.WriteString(".")
}
}
sb.WriteString(field)
} else {
sb.WriteString(fmt.Sprintf("[%d]", p[i]))
}
}
return sb.String()
}
func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
options := &renderOptions{
type2CustomRenderer: make(map[reflect.Type]func(any) (string, error)),
}
for _, opt := range opts {
opt(options)
}
n, err := sonic.Get(m, tp.JsonPath...)
if err != nil {
notExist := errors.Is(err, ast.ErrNotExist)
var syntaxErr ast.SyntaxError
if notExist || errors.As(err, &syntaxErr) {
// get each path segments one by one until the first not found error
var segParent, current ast.Node
for i := range tp.JsonPath {
current, err = sonic.Get(m, tp.JsonPath[:i+1]...)
if err != nil {
if errors.Is(err, ast.ErrNotExist) { // first not found segment
segmentI, ok := tp.JsonPath[i].(int)
if ok {
if !segParent.Exists() {
panic("impossible")
} else {
segArr, err := segParent.Array()
if err != nil { // not taking elements from array
return tp.literal, nil
}
return "", vo.NewError(errno.ErrArrIndexOutOfRange,
errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
errorx.KV("req_index", strconv.Itoa(segmentI)),
errorx.KV("arr_len", strconv.Itoa(len(segArr))))
}
}
return tp.literal, nil // not array element not found, but object field, just print
} else if errors.As(err, &syntaxErr) {
segmentI, ok := tp.JsonPath[i].(int)
if ok {
return "", vo.NewError(errno.ErrIndexingNilArray,
errorx.KV("arr_name", joinJsonPath(tp.JsonPath[:i])),
errorx.KV("req_index", strconv.Itoa(segmentI)))
}
return tp.literal, nil // not array element not found, but object field, just print
}
return tp.literal, nil // not ErrNotExist, just print
} else {
segParent = current
}
}
}
return tp.literal, nil
}
i, err := n.InterfaceUseNumber()
if err != nil {
return tp.literal, nil
}
if i == nil {
if options.nilRenderer != nil {
return options.nilRenderer()
}
return "", nil
}
if len(options.type2CustomRenderer) > 0 {
rType := reflect.TypeOf(i)
if fn, ok := options.type2CustomRenderer[rType]; ok {
return fn(i)
}
}
switch i.(type) {
case string:
return i.(string), nil
case json.Number:
return i.(json.Number).String(), nil
case bool:
return strconv.FormatBool(i.(bool)), nil
default:
ms, err := renderConfig.MarshalToString(i) // keep order of the map keys
if err != nil {
return "", err
}
return ms, nil
}
}
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
return false, false
}
// examine along the TemplatePart's root and sub paths,
// trying to find a matching SourceInfo as far as possible.
// the result would be one of two cases:
// - a REAL field source is matched, just check if that field source is skipped
// - otherwise an INTERMEDIATE field source is matched, it can only be skipped if ALL its sub sources are skipped
matchingSource, ok := resolvedSources[tp.Root]
if !ok { // the user specified a non-existing source, it can never have any value, just skip it
return false, true
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
}
for _, subPath := range tp.SubPathsBeforeSlice {
subSource, ok := matchingSource.SubSources[subPath]
if !ok { // has gone deeper than the field source
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
return false, true
}
return matchingSource.FieldType == FieldSkipped, false
}
matchingSource = subSource
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
}
var checkSourceSkipped func(sInfo *SourceInfo) bool
checkSourceSkipped = func(sInfo *SourceInfo) bool {
if !sInfo.IsIntermediate {
return sInfo.FieldType == FieldSkipped
}
for _, subSource := range sInfo.SubSources {
if !checkSourceSkipped(subSource) {
return false
}
}
return true
}
return checkSourceSkipped(matchingSource), false
}
func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
if len(tp.SubPathsBeforeSlice) == 0 {
return types[tp.Root]
}
rootType, ok := types[tp.Root]
if !ok {
return nil
}
currentType := rootType
for _, subPath := range tp.SubPathsBeforeSlice {
if len(currentType.Properties) == 0 {
return nil
}
subType, ok := currentType.Properties[subPath]
if !ok {
return nil
}
currentType = subType
}
return currentType
}
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
mi, err := sonic.Marshal(input)
if err != nil {
return "", err
}
resolvedSources, err := ResolveStreamSources(ctx, sources)
if err != nil {
return "", err
}
options := &renderOptions{}
for _, opt := range opts {
opt(options)
}
var sb strings.Builder
parts := ParseTemplate(tpl)
for _, part := range parts {
if !part.IsVariable {
sb.WriteString(part.Value)
continue
}
if options.reservedKey != nil {
if _, ok := options.reservedKey[part.Root]; ok {
i, err := part.Render(mi, opts...)
if err != nil {
return "", err
}
sb.WriteString(i)
continue
}
}
skipped, invalid := part.Skipped(resolvedSources)
if skipped {
continue
}
if invalid {
sb.WriteString(part.literal)
continue
}
i, err := part.Render(mi, opts...)
if err != nil {
return "", err
}
sb.WriteString(i)
}
return sb.String(), nil
}

View File

@@ -0,0 +1,128 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package textprocessor
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type Type string
const (
ConcatText Type = "concat"
SplitText Type = "split"
)
type Config struct {
Type Type `json:"type"`
Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"`
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
}
type TextProcessor struct {
config *Config
}
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
if cfg == nil {
return nil, fmt.Errorf("config requried")
}
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
return nil, fmt.Errorf("config tpl requried")
}
return &TextProcessor{
config: cfg,
}, nil
}
const OutputKey = "output"
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
switch t.config.Type {
case ConcatText:
arrayRenderer := func(i any) (string, error) {
vs := i.([]any)
return join(vs, t.config.ConcatChar)
}
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
if err != nil {
return nil, err
}
return map[string]any{OutputKey: result}, nil
case SplitText:
value, ok := input["String"]
if !ok {
return nil, fmt.Errorf("input string requried")
}
valueString, ok := value.(string)
if !ok {
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
}
values := strings.Split(valueString, t.config.Separators[0])
// 对每个分隔符进行迭代处理
for _, sep := range t.config.Separators[1:] {
var tempParts []string
for _, part := range values {
tempParts = append(tempParts, strings.Split(part, sep)...)
}
values = tempParts
}
anyValues := make([]any, 0, len(values))
for _, v := range values {
anyValues = append(anyValues, v)
}
return map[string]any{OutputKey: anyValues}, nil
default:
return nil, fmt.Errorf("not support type %s", t.config.Type)
}
}
func join(vs []any, concatChar string) (string, error) {
as := make([]string, 0, len(vs))
for _, v := range vs {
if v == nil {
as = append(as, "")
continue
}
if _, ok := v.(map[string]any); ok {
bs, err := sonic.Marshal(v)
if err != nil {
return "", err
}
as = append(as, string(bs))
continue
}
as = append(as, fmt.Sprintf("%v", v))
}
return strings.Join(as, concatChar), nil
}

View File

@@ -0,0 +1,69 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package textprocessor
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewTextProcessorNodeGenerator(t *testing.T) {
ctx := context.Background()
t.Run("split", func(t *testing.T) {
cfg := &Config{
Type: SplitText,
Separators: []string{",", "|", "."},
}
p, err := NewTextProcessor(ctx, cfg)
assert.NoError(t, err)
result, err := p.Invoke(ctx, map[string]any{
"String": "a,b|c.d,e|f|g",
})
assert.NoError(t, err)
assert.Equal(t, result["output"], []any{"a", "b", "c", "d", "e", "f", "g"})
})
t.Run("concat", func(t *testing.T) {
in := map[string]any{
"a": []any{"1", map[string]any{
"1": 1,
}, 3},
"b": map[string]any{
"b1": []string{"1", "2", "3"},
"b2": []any{"1", 2, "3"},
},
"c": map[string]any{
"c1": "1",
},
}
cfg := &Config{
Type: ConcatText,
ConcatChar: `\t`,
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
}
p, err := NewTextProcessor(context.Background(), cfg)
result, err := p.Invoke(ctx, in)
assert.NoError(t, err)
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
})
}

View File

@@ -0,0 +1,283 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"fmt"
"maps"
"reflect"
"strings"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// TakeMapValue extracts the value for specified path from input map.
// Returns false if map key not exist for specified path.
func TakeMapValue(m map[string]any, path compose.FieldPath) (any, bool) {
if m == nil {
return nil, false
}
container := m
for _, p := range path[:len(path)-1] {
if _, ok := container[p]; !ok {
return nil, false
}
container = container[p].(map[string]any)
}
if v, ok := container[path[len(path)-1]]; ok {
return v, true
}
return nil, false
}
func SetMapValue(m map[string]any, path compose.FieldPath, v any) {
container := m
for _, p := range path[:len(path)-1] {
if _, ok := container[p]; !ok {
container[p] = make(map[string]any)
}
container = container[p].(map[string]any)
}
container[path[len(path)-1]] = v
}
func TemplateRender(template string, vals map[string]interface{}) (string, error) {
sb := strings.Builder{}
valsBytes, err := sonic.Marshal(vals)
if err != nil {
return "", vo.WrapError(errno.ErrSerializationDeserializationFail, err)
}
parts := ParseTemplate(template)
for idx := range parts {
part := parts[idx]
if !part.IsVariable {
sb.WriteString(part.Value)
} else {
renderString, err := part.Render(valsBytes)
if err != nil {
return "", err
}
sb.WriteString(renderString)
}
}
return sb.String(), nil
}
func ExtractJSONString(content string) string {
if strings.HasPrefix(content, "```") && strings.HasSuffix(content, "```") {
content = content[3 : len(content)-3]
}
if strings.HasPrefix(content, "json") {
content = content[4:]
}
return content
}
func ConcatTwoMaps(m1, m2 map[string]any) (map[string]any, error) {
merged := maps.Clone(m1)
for k, v := range m2 {
current, ok := merged[k]
if !ok || current == nil {
if vStr, ok := v.(string); ok {
if vStr == KeyIsFinished {
continue
}
}
merged[k] = v
continue
}
vStr, ok1 := v.(string)
currentStr, ok2 := current.(string)
if ok1 && ok2 {
if strings.HasSuffix(vStr, KeyIsFinished) {
vStr = strings.TrimSuffix(vStr, KeyIsFinished)
}
merged[k] = currentStr + vStr
continue
}
vMap, ok1 := v.(map[string]any)
currentMap, ok2 := current.(map[string]any)
if ok1 && ok2 {
concatenated, err := ConcatTwoMaps(currentMap, vMap)
if err != nil {
return nil, err
}
merged[k] = concatenated
continue
}
items, err := toSliceValue([]any{current, v})
if err != nil {
logs.Errorf("failed to convert to slice value: %v", err)
return nil, err
}
var cv reflect.Value
if reflect.TypeOf(v).Kind() == reflect.Map {
cv, err = concatMaps(items)
} else {
cv, err = concatSliceValue(items)
}
if err != nil {
return nil, err
}
merged[k] = cv.Interface()
}
return merged, nil
}
// the following codes are copied from github.com/cloudwego/eino
func concatMaps(ms reflect.Value) (reflect.Value, error) {
typ := ms.Type().Elem()
rms := reflect.MakeMap(reflect.MapOf(typ.Key(), reflect.TypeOf((*[]any)(nil)).Elem()))
ret := reflect.MakeMap(typ)
n := ms.Len()
for i := 0; i < n; i++ {
m := ms.Index(i)
for _, key := range m.MapKeys() {
vals := rms.MapIndex(key)
if !vals.IsValid() {
var s []any
vals = reflect.ValueOf(s)
}
val := m.MapIndex(key)
vals = reflect.Append(vals, val)
rms.SetMapIndex(key, vals)
}
}
for _, key := range rms.MapKeys() {
vals := rms.MapIndex(key)
anyVals := vals.Interface().([]any)
v, err := toSliceValue(anyVals)
if err != nil {
return reflect.Value{}, err
}
var cv reflect.Value
if v.Type().Elem().Kind() == reflect.Map {
cv, err = concatMaps(v)
} else {
cv, err = concatSliceValue(v)
}
if err != nil {
return reflect.Value{}, err
}
ret.SetMapIndex(key, cv)
}
return ret, nil
}
func concatSliceValue(val reflect.Value) (reflect.Value, error) {
elmType := val.Type().Elem()
if val.Len() == 1 {
return val.Index(0), nil
}
f := GetConcatFunc(elmType)
if f != nil {
return f(val)
}
// if all elements in the slice are empty, return an empty value
// if there is exactly one non-empty element in the slice, return that non-empty element
// otherwise, throw an error.
var filtered reflect.Value
for i := 0; i < val.Len(); i++ {
oneVal := val.Index(i)
if !oneVal.IsZero() {
if filtered.IsValid() {
return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType)
}
filtered = oneVal
}
}
if !filtered.IsValid() {
filtered = reflect.New(elmType).Elem()
}
return filtered, nil
}
func toSliceValue(vs []any) (reflect.Value, error) {
typ := reflect.TypeOf(vs[0])
ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs))
ret.Index(0).Set(reflect.ValueOf(vs[0]))
for i := 1; i < len(vs); i++ {
v := vs[i]
vt := reflect.TypeOf(v)
if typ != vt {
return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", typ, vt)
}
ret.Index(i).Set(reflect.ValueOf(v))
}
return ret, nil
}
var (
concatFunctions = map[reflect.Type]any{}
)
func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) {
concatFunctions[reflect.TypeOf((*T)(nil)).Elem()] = fn
}
func GetConcatFunc(typ reflect.Type) func(reflect.Value) (reflect.Value, error) {
if fn, ok := concatFunctions[typ]; ok {
return func(a reflect.Value) (reflect.Value, error) {
rvs := reflect.ValueOf(fn).Call([]reflect.Value{a})
var err error
if !rvs[1].IsNil() {
err = rvs[1].Interface().(error)
}
return rvs[0], err
}
}
return nil
}

View File

@@ -0,0 +1,596 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package variableaggregator
import (
"context"
"errors"
"fmt"
"io"
"maps"
"math"
"runtime/debug"
"slices"
"sort"
"strconv"
"strings"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type MergeStrategy uint
const (
FirstNotNullValue MergeStrategy = 1
)
type Config struct {
MergeStrategy MergeStrategy
GroupLen map[string]int
FullSources map[string]*nodes.SourceInfo
NodeKey vo.NodeKey
InputSources []*vo.FieldInfo
GroupOrder []string // the order the groups are declared in frontend canvas
}
type VariableAggregator struct {
config *Config
}
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
}
return &VariableAggregator{config: cfg}, nil
}
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
in, err := inputConverter(input)
if err != nil {
return nil, err
}
result := make(map[string]any)
groupToChoice := make(map[string]int)
for group, length := range v.config.GroupLen {
for i := 0; i < length; i++ {
if value, ok := in[group][i]; ok {
if value != nil {
result[group] = value
groupToChoice[group] = i
break
}
}
}
if _, ok := result[group]; !ok {
groupToChoice[group] = -1
}
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
return nil
})
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
} else {
groupChoices = append(groupChoices, choice)
}
}
ctxcache.Store(ctx, groupChoiceCacheKey, groupChoices)
return result, nil
}
const (
resolvedSourcesCacheKey = "resolved_sources"
groupChoiceTypeCacheKey = "group_choice_type"
groupChoiceCacheKey = "group_choice"
)
// Transform picks the first non-nil value from each group from a stream of map[group]items.
func (v *VariableAggregator) Transform(ctx context.Context, input *schema.StreamReader[map[string]any]) (
_ *schema.StreamReader[map[string]any], err error) {
inStream := streamInputConverter(input)
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolvesSources from ctx cache.")
}
groupToItems := make(map[string][]any)
groupToChoice := make(map[string]int)
type skipped struct{}
type null struct{}
type stream struct{}
defer func() {
if err == nil {
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
for group, choice := range groupToChoice {
if choice != -1 {
item := groupToItems[group][choice]
if _, ok := item.(stream); ok {
groupChoiceToStreamType[group] = nodes.FieldIsStream
}
}
}
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
} else {
groupChoices = append(groupChoices, choice)
}
}
// store group -> field type for use in callbacks.OnEnd
ctxcache.Store(ctx, groupChoiceTypeCacheKey, groupChoiceToStreamType)
ctxcache.Store(ctx, groupChoiceCacheKey, groupChoices)
}
}()
// goal: find the first non-nil element in each group. 'First' means the smallest index in each group's slice.
// For a stream element, if the stream source is not skipped, then this stream element is non-nil,
// even if there's no content in the stream.
// steps:
// - for each group, iterate over each element in order, check the element's stream type
// - if an element is skipped, move on
// - if an element is stream, pick it
// - if an element is not stream, actually receive from the stream to check if it's non-nil
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
for group, length := range v.config.GroupLen {
groupToItems[group] = make([]any, length)
groupToCurrentIndex[group] = math.MaxInt
for i := 0; i < length; i++ {
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
if fType == nodes.FieldSkipped {
groupToItems[group][i] = skipped{}
continue
}
if fType == nodes.FieldIsStream {
groupToItems[group][i] = stream{}
if ci, _ := groupToCurrentIndex[group]; i < ci {
groupToCurrentIndex[group] = i
}
}
}
hasUndecided := false
for i := 0; i < length; i++ {
if groupToItems[group][i] == nil {
hasUndecided = true
break
}
_, ok := groupToItems[group][i].(stream)
if ok { // if none of the elements before this one is none-stream, pick this first stream
groupToChoice[group] = i
break
}
}
if _, ok := groupToChoice[group]; !ok && !hasUndecided {
groupToChoice[group] = -1 // all of this group's elements are skipped, won't have any non-nil ones
}
}
allDone := func() bool {
for group := range v.config.GroupLen {
_, ok := groupToChoice[group]
if !ok {
return false
}
}
return true
}
alreadyDone := allDone()
if alreadyDone { // all groups have made their choices, no need to actually read input streams
result := make(map[string]any, len(v.config.GroupLen))
allSkip := true
for group := range groupToChoice {
choice := groupToChoice[group]
if choice == -1 {
result[group] = nil // all elements of this group are skipped
} else {
result[group] = choice
allSkip = false
}
}
if allSkip { // no need to convert input streams for the output, because all groups are skipped
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
return nil
})
return schema.StreamReaderFromArray([]map[string]any{result}), nil
}
}
outS := inStream
if !alreadyDone {
inCopy := inStream.Copy(2)
defer inCopy[0].Close()
outS = inCopy[1]
recvLoop:
for {
chunk, err := inCopy[0].Recv()
if err != nil {
if err == io.EOF {
panic("EOF reached before making choices for all groups")
}
return nil, err
}
for group, items := range chunk {
if _, ok := groupToChoice[group]; ok {
continue // already made the decision for the group.
}
for i := range items {
if i >= groupToCurrentIndex[group] {
continue
}
existing := groupToItems[group][i]
if existing != nil { // belongs to a stream element
continue
}
// now the item is always a non-stream element
item := items[i]
if item == nil {
groupToItems[group][i] = null{}
} else {
groupToItems[group][i] = item
}
groupToCurrentIndex[group] = i
finalized := true
for j := 0; j < i; j++ {
indexedItem := groupToItems[group][j]
if indexedItem == nil { // there exists non-finalized elements in front of the current item
finalized = false
break
}
}
if finalized {
if item == nil { // current item is nil, we need to find the first non-nil element in the group
foundNonNil := false
hasUndecided := false
for j := 0; j < len(groupToItems[group]); j++ {
indexedItem := groupToItems[group][j]
if indexedItem != nil {
_, ok := indexedItem.(skipped)
if ok {
continue
}
_, ok = indexedItem.(null)
if ok {
continue
}
groupToChoice[group] = j
foundNonNil = true
break
} else {
hasUndecided = true
break
}
}
if !foundNonNil && !hasUndecided {
groupToChoice[group] = -1 // this group does not have any non-nil value
}
} else {
groupToChoice[group] = i
}
if allDone() {
break recvLoop
}
}
}
}
}
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
return nil
})
actualStream := schema.StreamReaderWithConvert(outS, func(in map[string]map[int]any) (map[string]any, error) {
out := make(map[string]any)
for group, items := range in {
choice, ok := groupToChoice[group]
if !ok {
panic(fmt.Sprintf("group %s does not have choice", group))
}
if choice < 0 {
panic(fmt.Sprintf("group %s choice = %d, less than zero, but found actual item in stream", group, choice))
}
if _, ok := items[choice]; ok {
out[group] = items[choice]
}
}
if len(out) == 0 {
return nil, schema.ErrNoValue
}
return out, nil
})
nullGroups := make(map[string]any)
for group, choice := range groupToChoice {
if choice < 0 {
nullGroups[group] = nil
}
}
if len(nullGroups) > 0 {
nullStream := schema.StreamReaderFromArray([]map[string]any{nullGroups})
return schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{actualStream, nullStream}), nil
}
return actualStream, nil
}
func inputConverter(in map[string]any) (converted map[string]map[int]any, err error) {
converted = make(map[string]map[int]any)
for k, value := range in {
m, ok := value.(map[string]any)
if !ok {
return nil, errors.New("value is not a map[string]any")
}
converted[k] = make(map[int]any, len(m))
for i, sv := range m {
index, err := strconv.Atoi(i)
if err != nil {
return nil, fmt.Errorf(" converting %s to int failed, err=%v", i, err)
}
converted[k][index] = sv
}
}
return converted, nil
}
func streamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
defer func() {
if r := recover(); r != nil {
err = safego.NewPanicErr(r, debug.Stack())
}
}()
return inputConverter(input)
}
return schema.StreamReaderWithConvert(in, converter)
}
type vaCallbackInput struct {
Name string `json:"name"`
Variables []any `json:"variables"`
}
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
ctx = ctxcache.Init(ctx)
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}
type streamMarkerType string
const streamMarker streamMarkerType = "<Stream Data...>"
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolved_sources from ctx cache")
}
in, err := inputConverter(input)
if err != nil {
return nil, err
}
merged := make([]vaCallbackInput, 0, len(in))
groupLen := v.config.GroupLen
for groupName, vars := range in {
orderedVars := make([]any, groupLen[groupName])
for index := range vars {
orderedVars[index] = vars[index]
if len(resolvedSources) > 0 {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
// replace the streams with streamMarker,
// because we won't read, save to execution history, or display these streams to user
orderedVars[index] = streamMarker
}
}
}
merged = append(merged, vaCallbackInput{
Name: groupName,
Variables: orderedVars,
})
}
// Sort merged slice by Name
sort.Slice(merged, func(i, j int) bool {
return merged[i].Name < merged[j].Name
})
return map[string]any{
"mergeGroups": merged,
}, nil
}
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
if !ok {
panic("unable to get dynamic stream types from ctx cache")
}
groupChoices, ok := ctxcache.Get[[]any](ctx, groupChoiceCacheKey)
if !ok {
panic("unable to get group choices from ctx cache")
}
if len(dynamicStreamType) == 0 {
return &nodes.StructuredCallbackOutput{
Output: output,
RawOutput: output,
Extra: map[string]any{
"variable_select": groupChoices,
},
}, nil
}
newOut := maps.Clone(output)
for k := range output {
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
newOut[k] = streamMarker
}
}
return &nodes.StructuredCallbackOutput{
Output: newOut,
RawOutput: newOut,
Extra: map[string]any{
"variable_select": groupChoices,
},
}, nil
}
func concatVACallbackInputs(vs [][]vaCallbackInput) ([]vaCallbackInput, error) {
if len(vs) == 0 {
return nil, nil
}
init := slices.Clone(vs[0])
for i := 1; i < len(vs); i++ {
next := vs[i]
for j := 0; j < len(next); j++ {
oneGroup := next[j]
groupName := oneGroup.Name
var (
existingGroup *vaCallbackInput
nextIndex = len(init)
currentIndex int
)
for k := 0; k < len(init); k++ {
if init[k].Name == groupName {
existingGroup = ptr.Of(init[k])
currentIndex = k
} else if init[k].Name > groupName && k < nextIndex {
nextIndex = k
}
}
if existingGroup == nil {
after := slices.Clone(init[nextIndex:])
init = append(init[:nextIndex], oneGroup)
init = append(init, after...)
} else {
for vi := 0; vi < len(oneGroup.Variables); vi++ {
newV := oneGroup.Variables[vi]
if newV == nil {
if vi >= len(existingGroup.Variables) {
for i := len(existingGroup.Variables); i <= vi; i++ {
existingGroup.Variables = append(existingGroup.Variables, nil)
}
}
continue
}
if newStr, ok := newV.(string); ok {
if strings.HasSuffix(newStr, nodes.KeyIsFinished) {
newStr = strings.TrimSuffix(newStr, nodes.KeyIsFinished)
}
newV = newStr
}
for ei := len(existingGroup.Variables); ei <= vi; ei++ {
existingGroup.Variables = append(existingGroup.Variables, nil)
}
ev := existingGroup.Variables[vi]
if ev == nil {
existingGroup.Variables[vi] = oneGroup.Variables[vi]
} else {
if evStr, ok := ev.(streamMarkerType); !ok {
return nil, fmt.Errorf("multiple stream chunk when concating VACallbackInputs, variable %s is not string", ev)
} else {
if evStr != streamMarker || newV.(streamMarkerType) != streamMarker {
return nil, fmt.Errorf("multiple stream chunk when concating VACallbackInputs, variable %s is not streamMarker", ev)
}
existingGroup.Variables[vi] = evStr
}
}
}
init[currentIndex] = *existingGroup
}
}
}
return init, nil
}
func concatStreamMarkers(_ []streamMarkerType) (streamMarkerType, error) {
return streamMarker, nil
}
func init() {
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
}

View File

@@ -0,0 +1,146 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package variableassigner
import (
"context"
"fmt"
"strings"
"sync"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type AppVariables struct {
vars map[string]any
mu sync.RWMutex
}
func NewAppVariables() *AppVariables {
return &AppVariables{
vars: make(map[string]any),
}
}
func (av *AppVariables) Set(key string, value any) {
av.mu.Lock()
av.vars[key] = value
av.mu.Unlock()
}
func (av *AppVariables) Get(key string) (any, bool) {
av.mu.RLock()
defer av.mu.RUnlock()
if value, ok := av.vars[key]; ok {
return value, ok
}
return nil, false
}
type AppVariableStore interface {
GetAppVariableValue(key string) (any, bool)
SetAppVariableValue(key string, value any)
}
type VariableAssigner struct {
config *Config
}
type Config struct {
Pairs []*Pair
Handler *variable.Handler
}
type Pair struct {
Left vo.Reference
Right compose.FieldPath
}
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
for _, pair := range conf.Pairs {
if pair.Left.VariableType == nil {
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
}
if *pair.Left.VariableType == vo.GlobalSystem {
return nil, fmt.Errorf("cannot assign to global system variables in VariableAssigner because they are read-only, ref: %v", pair.Left)
}
vType := *pair.Left.VariableType
if vType != vo.GlobalAPP && vType != vo.GlobalUser {
return nil, fmt.Errorf("cannot assign to variable type %s in VariableAssigner", vType)
}
}
return &VariableAssigner{
config: conf,
}, nil
}
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.config.Pairs {
right, ok := nodes.TakeMapValue(in, pair.Right)
if !ok {
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
}
vType := *pair.Left.VariableType
switch vType {
case vo.GlobalAPP:
err := compose.ProcessState(ctx, func(ctx context.Context, appVarsStore AppVariableStore) error {
if len(pair.Left.FromPath) != 1 {
return fmt.Errorf("can only assign to top level variable: %v", pair.Left.FromPath)
}
appVarsStore.SetAppVariableValue(pair.Left.FromPath[0], right)
return nil
})
if err != nil {
return nil, err
}
case vo.GlobalUser:
opts := make([]variable.OptionFn, 0, 1)
if exeCtx := execute.GetExeCtx(ctx); exeCtx != nil {
exeCfg := exeCtx.RootCtx.ExeCfg
opts = append(opts, variable.WithStoreInfo(variable.StoreInfo{
AgentID: exeCfg.AgentID,
AppID: exeCfg.AppID,
ConnectorID: exeCfg.ConnectorID,
ConnectorUID: exeCfg.ConnectorUID,
}))
}
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
}
default:
panic("impossible")
}
}
// TODO if not error considered successful
return map[string]any{
"isSuccess": true,
}, nil
}

View File

@@ -0,0 +1,58 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package variableassigner
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type InLoop struct {
config *Config
intermediateVarStore variable.Store
}
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
return &InLoop{
config: conf,
intermediateVarStore: &nodes.ParentIntermediateStore{},
}, nil
}
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.config.Pairs {
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
}
right, ok := nodes.TakeMapValue(in, pair.Right)
if !ok {
return nil, fmt.Errorf("failed to extract right value for path %s", pair.Right)
}
err := v.intermediateVarStore.Set(ctx, pair.Left.FromPath, right)
if err != nil {
return nil, err
}
}
return map[string]any{}, nil
}

View File

@@ -0,0 +1,98 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package variableassigner
import (
"context"
"testing"
"github.com/cloudwego/eino/compose"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestVariableAssigner(t *testing.T) {
intVar := any(1)
strVar := any("str")
objVar := any(map[string]any{
"key": "value",
})
arrVar := any([]any{1, "2"})
va := &InLoop{
config: &Config{
Pairs: []*Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"int_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"int_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"str_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"str_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"obj_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"obj_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"arr_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"arr_var_t"},
},
},
},
intermediateVarStore: &nodes.ParentIntermediateStore{},
}
ctx := nodes.InitIntermediateVars(context.Background(), map[string]*any{
"int_var_s": &intVar,
"str_var_s": &strVar,
"obj_var_s": &objVar,
"arr_var_s": &arrVar,
}, nil)
_, err := va.Assign(ctx, map[string]any{
"int_var_t": 2,
"str_var_t": "str2",
"obj_var_t": map[string]any{
"key2": "value2",
},
"arr_var_t": []any{3, "4"},
})
assert.NoError(t, err)
assert.Equal(t, 2, intVar)
assert.Equal(t, "str2", strVar)
assert.Equal(t, map[string]any{
"key2": "value2",
}, objVar)
assert.Equal(t, []any{3, "4"}, arrVar)
}