refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -30,50 +30,108 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type Batch struct {
config *Config
outputs map[string]*vo.FieldSource
outputs map[string]*vo.FieldSource
innerWorkflow compose.Runnable[map[string]any, map[string]any]
key vo.NodeKey
inputArrays []string
}
type Config struct {
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
type Config struct{}
InputArrays []string `json:"input_arrays"`
Outputs []*vo.FieldInfo `json:"outputs"`
}
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
if config == nil {
return nil, errors.New("config is required")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
}
if len(config.InputArrays) == 0 {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeBatch,
Name: n.Data.Meta.Title,
Configs: c,
}
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
compose.FieldPath{MaxBatchSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(batchSizeField...)
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
compose.FieldPath{ConcurrentSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(concurrentSizeField...)
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
if err != nil {
return nil, err
}
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
if err != nil {
return nil, err
}
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
var inputArrays []string
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
inputArrays = append(inputArrays, key)
}
if len(inputArrays) == 0 {
return nil, errors.New("need to have at least one incoming array for batch")
}
if len(config.Outputs) == 0 {
if len(ns.OutputSources) == 0 {
return nil, errors.New("need to have at least one output variable for batch")
}
b := &Batch{
config: config,
outputs: make(map[string]*vo.FieldSource),
bo := schema.GetBuildOptions(opts...)
if bo.Inner == nil {
return nil, errors.New("need to have inner workflow for batch")
}
for i := range config.Outputs {
source := config.Outputs[i]
b := &Batch{
outputs: make(map[string]*vo.FieldSource),
innerWorkflow: bo.Inner,
key: ns.Key,
inputArrays: inputArrays,
}
for i := range ns.OutputSources {
source := ns.OutputSources[i]
path := source.Path
if len(path) != 1 {
return nil, fmt.Errorf("invalid path %q", path)
}
// from which inner node's which field does the batch's output fields come from
b.outputs[path[0]] = &source.Source
}
@@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
return out
}
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) {
arrays := make(map[string]any, len(b.config.InputArrays))
arrays := make(map[string]any, len(b.inputArrays))
minLen := math.MaxInt64
for _, arrayKey := range b.config.InputArrays {
for _, arrayKey := range b.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
}
}
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
input[string(b.key)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey, array := range arrays {
ele := reflect.ValueOf(array).Index(i).Interface()
items[arrayKey] = []any{ele}
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
currentKey := string(b.key) + "#" + arrayKey
// Recursively expand map[string]any elements
var expand func(prefix string, val interface{})
@@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
var existingCState *nodes.NestedWorkflowState
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
existingCState, _, e = getter.GetNestedWorkflowState(b.key)
if e != nil {
return e
}
@@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
mu.Unlock()
if subCheckpointID != "" {
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
i, b.config.BatchNodeKey, subCheckpointID)
i, b.key, subCheckpointID)
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
}
@@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
// the output then needs to be concatenated.
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
@@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
NodeKey: b.key,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e
}
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
return setter.SetInterruptEvent(b.key, iEvent)
})
if err != nil {
return nil, err
@@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e
}
@@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// although this invocation does not have new interruptions,
// this batch node previously have interrupts yet to be resumed.
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
NodeKey: b.key,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: existingCState.Index2InterruptInfo,
})
@@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
}
@@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
}
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(b.config.InputArrays))
for _, arrayKey := range b.config.InputArrays {
trimmed := make(map[string]any, len(b.inputArrays))
for _, arrayKey := range b.inputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}

View File

@@ -25,6 +25,10 @@ import (
"golang.org/x/exp/maps"
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
@@ -113,50 +117,77 @@ var pythonThirdPartyWhitelist = map[string]struct{}{
}
type Config struct {
Code string
Language coderunner.Language
OutputConfig map[string]*vo.TypeInfo
Runner coderunner.Runner
Code string
Language coderunner.Language
Runner coderunner.Runner
}
type CodeRunner struct {
config *Config
importError error
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeCodeRunner,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
code := inputs.Code
c.Code = code
language, err := convertCodeLanguage(inputs.Language)
if err != nil {
return nil, err
}
c.Language = language
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
func convertCodeLanguage(l int64) (coderunner.Language, error) {
switch l {
case 5:
return coderunner.JavaScript, nil
case 3:
return coderunner.Python, nil
default:
return "", fmt.Errorf("invalid language: %d", l)
}
}
if cfg.Language == "" {
return nil, errors.New("language is required")
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if cfg.Code == "" {
return nil, errors.New("code is required")
}
if cfg.Language != coderunner.Python {
if c.Language != coderunner.Python {
return nil, errors.New("only support python language")
}
if len(cfg.OutputConfig) == 0 {
return nil, errors.New("output config is required")
}
importErr := validatePythonImports(c.Code)
if cfg.Runner == nil {
return nil, errors.New("run coder is required")
}
importErr := validatePythonImports(cfg.Code)
return &CodeRunner{
config: cfg,
importError: importErr,
return &Runner{
code: c.Code,
language: c.Language,
outputConfig: ns.OutputTypes,
runner: code2.GetCodeRunner(),
importError: importErr,
}, nil
}
type Runner struct {
outputConfig map[string]*vo.TypeInfo
code string
language coderunner.Language
runner coderunner.Runner
importError error
}
func validatePythonImports(code string) error {
imports := parsePythonImports(code)
importErrors := make([]string, 0)
@@ -191,11 +222,11 @@ func validatePythonImports(code string) error {
return nil
}
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
if c.importError != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
}
response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input})
if err != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
@@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
result := response.Result
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
@@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
}
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
if !ok {
return nil, errors.New("raw output config is required")

View File

@@ -75,30 +75,29 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Language: coderunner.Python,
Code: codeTpl,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
},
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
c := &Runner{
language: coderunner.Python,
code: codeTpl,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
}},
},
Runner: mockRunner,
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})
@@ -145,38 +144,36 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: coderunner.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
c := &Runner{
code: codeTpl,
language: coderunner.Python,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
}},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
},
}},
}},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
},
}},
},
Runner: mockRunner,
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})
@@ -219,30 +216,28 @@ async def main(args:Args)->Output:
}
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: coderunner.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
}},
},
},
c := &Runner{
code: codeTpl,
language: coderunner.Python,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
"key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
}},
},
},
Runner: mockRunner,
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})

View File

@@ -0,0 +1,236 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// setDatabaseInputsForNodeSchema maps the canvas node's database parameters
// (select / insert / delete / update) onto the node schema's input types and
// input field sources. Only the parameter groups actually present on the node
// are applied; absent groups are skipped.
func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
	parent := n.Parent()
	inputs := n.Data.Inputs

	if sp := inputs.SelectParam; sp != nil {
		if err := applyDBConditionToSchema(ns, sp.Condition, parent); err != nil {
			return err
		}
	}

	if ip := inputs.InsertParam; ip != nil {
		if err := applyInsetFieldInfoToSchema(ns, ip.FieldInfo, parent); err != nil {
			return err
		}
	}

	if dp := inputs.DeleteParam; dp != nil {
		if err := applyDBConditionToSchema(ns, &dp.Condition, parent); err != nil {
			return err
		}
	}

	if up := inputs.UpdateParam; up != nil {
		// Update carries both a filter condition and the fields to set.
		if err := applyDBConditionToSchema(ns, &up.Condition, parent); err != nil {
			return err
		}
		if err := applyInsetFieldInfoToSchema(ns, up.FieldInfo, parent); err != nil {
			return err
		}
	}

	return nil
}
// applyDBConditionToSchema registers an input type and an input source on ns
// for the "right" operand of each clause in the condition list. The input is
// named "__condition_right_<idx>" by clause position. Clauses without a
// "right" param are skipped — presumably unary operators such as IS_NULL need
// no right operand (TODO confirm against canvas producers).
//
// A nil condition or an absent condition list is a no-op. The nil-condition
// guard matters because SelectParam.Condition is a pointer that callers pass
// through without checking.
func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error {
	if condition == nil || condition.ConditionList == nil {
		return nil
	}
	for idx, params := range condition.ConditionList {
		var right *vo.Param
		for _, param := range params {
			if param == nil {
				continue
			}
			if param.Name == "right" {
				right = param
				break
			}
		}
		if right == nil {
			continue
		}
		name := fmt.Sprintf("__condition_right_%d", idx)
		tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input)
		if err != nil {
			return err
		}
		ns.SetInputType(name, tInfo)
		sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode)
		if err != nil {
			return err
		}
		ns.AddInputSource(sources...)
	}
	return nil
}
// applyInsetFieldInfoToSchema registers an input type and an input source on
// ns for each field to be written. Each entry of fieldInfo is a list of
// params with two elements: the first names the field, the second carries its
// value. The resulting input is named "__setting_field_<name>".
//
// Malformed entries (fewer than two params, nil params, or a non-string field
// name) are reported as errors instead of panicking on out-of-range indexing
// or a failed type assertion.
func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error {
	if len(fieldInfo) == 0 {
		return nil
	}
	for _, params := range fieldInfo {
		if len(params) < 2 || params[0] == nil || params[1] == nil {
			return fmt.Errorf("field info requires a name param and a value param, got %d params", len(params))
		}
		p0 := params[0]
		p1 := params[1]
		rawName, ok := p0.Input.Value.Content.(string)
		if !ok {
			return fmt.Errorf("field name must be a string, got %T", p0.Input.Value.Content)
		}
		tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input)
		if err != nil {
			return err
		}
		name := "__setting_field_" + rawName
		ns.SetInputType(name, tInfo)
		sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode)
		if err != nil {
			return err
		}
		ns.AddInputSource(sources...)
	}
	return nil
}
// buildClauseGroupFromCondition converts a canvas DB condition into a
// database.ClauseGroup: a single clause when the condition list has exactly
// one entry, otherwise a multi-clause group joined by the condition's logic
// operator (AND / OR).
func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) {
	group := &database.ClauseGroup{}

	if len(condition.ConditionList) == 1 {
		single, err := buildClauseFromParams(condition.ConditionList[0])
		if err != nil {
			return nil, err
		}
		group.Single = single
		return group, nil
	}

	relation, err := convertLogicTypeToRelation(condition.Logic)
	if err != nil {
		return nil, err
	}
	clauses := make([]*database.Clause, 0, len(condition.ConditionList))
	for _, params := range condition.ConditionList {
		clause, err := buildClauseFromParams(params)
		if err != nil {
			return nil, err
		}
		clauses = append(clauses, clause)
	}
	group.Multi = &database.MultiClause{
		Clauses:  clauses,
		Relation: relation,
	}
	return group, nil
}
// buildClauseFromParams extracts the "left" column name and the "operation"
// operator from a clause's param list and assembles a database.Clause. The
// "right" operand is intentionally not handled here — it is wired into the
// node's inputs separately (see applyDBConditionToSchema).
//
// Non-string "left" or "operation" content is reported as an error instead of
// panicking on a failed type assertion.
func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
	var left, operation *vo.Param
	for _, p := range params {
		if p == nil {
			continue
		}
		switch p.Name {
		case "left":
			left = p
		case "operation":
			operation = p
		}
	}
	if left == nil {
		return nil, fmt.Errorf("left clause is required")
	}
	if operation == nil {
		return nil, fmt.Errorf("operation clause is required")
	}
	opStr, ok := operation.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("operation must be a string, got %T", operation.Input.Value.Content)
	}
	operator, err := operationToOperator(opStr)
	if err != nil {
		return nil, err
	}
	column, ok := left.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("left clause must be a string, got %T", left.Input.Value.Content)
	}
	return &database.Clause{
		Left:     column,
		Operator: operator,
	}, nil
}
// convertLogicTypeToRelation translates a canvas logic type (AND / OR) into
// the corresponding database clause relation. Any other value is an error.
func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) {
	if logicType == vo.DatabaseLogicAnd {
		return database.ClauseRelationAND, nil
	}
	if logicType == vo.DatabaseLogicOr {
		return database.ClauseRelationOR, nil
	}
	return "", fmt.Errorf("logic type %v is invalid", logicType)
}
// operationToOperator maps a canvas operation identifier (e.g. "EQUAL",
// "NOT_IN") to its database.Operator counterpart. An unknown identifier is
// reported in the returned error, quoted, so malformed canvas data can be
// diagnosed from the message alone.
func operationToOperator(s string) (database.Operator, error) {
	switch s {
	case "EQUAL":
		return database.OperatorEqual, nil
	case "NOT_EQUAL":
		return database.OperatorNotEqual, nil
	case "GREATER_THAN":
		return database.OperatorGreater, nil
	case "LESS_THAN":
		return database.OperatorLesser, nil
	case "GREATER_EQUAL":
		return database.OperatorGreaterOrEqual, nil
	case "LESS_EQUAL":
		return database.OperatorLesserOrEqual, nil
	case "IN":
		return database.OperatorIn, nil
	case "NOT_IN":
		return database.OperatorNotIn, nil
	case "IS_NULL":
		return database.OperatorIsNull, nil
	case "IS_NOT_NULL":
		return database.OperatorIsNotNull, nil
	case "LIKE":
		return database.OperatorLike, nil
	case "NOT_LIKE":
		return database.OperatorNotLike, nil
	default:
		return "", fmt.Errorf("%q is not a valid Operation string", s)
	}
}

View File

@@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.
return ret, nil
}
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
var (
rightValue any
ok bool
@@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa
return conditionGroup, nil
}
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
if err != nil {
return nil, err
}
fields := parseToInput(input)
inventory := &UpdateInventory{
inventory := &updateInventory{
ConditionGroup: conditionGroup,
Fields: fields,
}

View File

@@ -19,48 +19,89 @@ package database
import (
"context"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type CustomSQLConfig struct {
DatabaseInfoID int64
SQLTemplate string
OutputConfig map[string]*vo.TypeInfo
CustomSQLExecutor database.DatabaseOperator
DatabaseInfoID int64
SQLTemplate string
}
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseCustomSQL,
Name: n.Data.Meta.Title,
Configs: c,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
c.DatabaseInfoID = dsID
sql := n.Data.Inputs.SQL
if len(sql) == 0 {
return nil, fmt.Errorf("sql is requird")
}
c.SQLTemplate = sql
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.SQLTemplate == "" {
if c.SQLTemplate == "" {
return nil, errors.New("sql template is required")
}
if cfg.CustomSQLExecutor == nil {
return nil, errors.New("custom sqler is required")
}
return &CustomSQL{
config: cfg,
databaseInfoID: c.DatabaseInfoID,
sqlTemplate: c.SQLTemplate,
outputTypes: ns.OutputTypes,
customSQLExecutor: database.GetDatabaseOperator(),
}, nil
}
type CustomSQL struct {
config *CustomSQLConfig
databaseInfoID int64
sqlTemplate string
outputTypes map[string]*vo.TypeInfo
customSQLExecutor database.DatabaseOperator
}
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
req := &database.CustomSQLRequest{
DatabaseInfoID: c.config.DatabaseInfoID,
DatabaseInfoID: c.databaseInfoID,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
@@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
}
templateSQL := ""
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
templateParts := nodes.ParseTemplate(c.sqlTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil")
for _, templatePart := range templateParts {
@@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL
req.Params = sqlParams
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
response, err := c.customSQLExecutor.Execute(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(c.config.OutputConfig, response)
ret, err := responseFormatted(c.outputTypes, response)
if err != nil {
return nil, err
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type mockCustomSQLer struct {
@@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C
m.validate(request)
r := &database.Response{
Objects: []database.Object{
database.Object{
{
"v1": "v1_ret",
"v2": "v2_ret",
},
@@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{
database.SQLParam{Value: "v1_value"},
database.SQLParam{Value: "v2_value"},
database.SQLParam{Value: "v3_value"},
{Value: "v1_value"},
{Value: "v2_value"},
{Value: "v3_value"},
}
assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
@@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
cfg := &CustomSQLConfig{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
CustomSQLExecutor: mockDatabaseOperator,
OutputConfig: map[string]*vo.TypeInfo{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
}
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
}}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
cl := &CustomSQL{
config: cfg,
}
})
assert.NoError(t, err)
ret, err := cl.Execute(t.Context(), map[string]any{
ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{
"v1": "v1_value",
"v2": "v2_value",
"v3": "v3_value",

View File

@@ -20,61 +20,102 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type DeleteConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Deleter database.DatabaseOperator
}
type Delete struct {
config *DeleteConfig
}
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseDelete,
Name: n.Data.Meta.Title,
Configs: d,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
d.DatabaseInfoID = dsID
deleteParam := n.Data.Inputs.DeleteParam
clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition)
if err != nil {
return nil, err
}
d.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if d.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
if d.ClauseGroup == nil {
return nil, errors.New("clauseGroup is required")
}
if cfg.Deleter == nil {
return nil, errors.New("deleter is required")
}
return &Delete{
config: cfg,
databaseInfoID: d.DatabaseInfoID,
clauseGroup: d.ClauseGroup,
outputTypes: ns.OutputTypes,
deleter: database.GetDatabaseOperator(),
}, nil
}
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
type Delete struct {
databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
deleter database.DatabaseOperator
}
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in)
if err != nil {
return nil, err
}
request := &database.DeleteRequest{
DatabaseInfoID: d.config.DatabaseInfoID,
DatabaseInfoID: d.databaseInfoID,
ConditionGroup: conditionGroup,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := d.config.Deleter.Delete(ctx, request)
response, err := d.deleter.Delete(ctx, request)
if err != nil {
return nil, err
}
ret, err := responseFormatted(d.config.OutputConfig, response)
ret, err := responseFormatted(d.outputTypes, response)
if err != nil {
return nil, err
}
@@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any,
}
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in)
if err != nil {
return nil, err
}
@@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri
}
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
databaseID := d.config.DatabaseInfoID
databaseID := d.databaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@@ -20,54 +20,84 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type InsertConfig struct {
DatabaseInfoID int64
OutputConfig map[string]*vo.TypeInfo
Inserter database.DatabaseOperator
}
type Insert struct {
config *InsertConfig
}
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseInsert,
Name: n.Data.Meta.Title,
Configs: i,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
i.DatabaseInfoID = dsID
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if i.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Inserter == nil {
return nil, errors.New("inserter is required")
}
return &Insert{
config: cfg,
databaseInfoID: i.DatabaseInfoID,
outputTypes: ns.OutputTypes,
inserter: database.GetDatabaseOperator(),
}, nil
}
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {
type Insert struct {
databaseInfoID int64
outputTypes map[string]*vo.TypeInfo
inserter database.DatabaseOperator
}
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fields := parseToInput(input)
req := &database.InsertRequest{
DatabaseInfoID: is.config.DatabaseInfoID,
DatabaseInfoID: is.databaseInfoID,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := is.config.Inserter.Insert(ctx, req)
response, err := is.inserter.Insert(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(is.config.OutputConfig, response)
ret, err := responseFormatted(is.outputTypes, response)
if err != nil {
return nil, err
}
@@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]
}
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
databaseID := is.config.DatabaseInfoID
databaseID := is.databaseInfoID
fs := parseToInput(input)
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@@ -20,68 +20,137 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type QueryConfig struct {
DatabaseInfoID int64
QueryFields []string
OrderClauses []*database.OrderClause
OutputConfig map[string]*vo.TypeInfo
ClauseGroup *database.ClauseGroup
Limit int64
Op database.DatabaseOperator
}
type Query struct {
config *QueryConfig
}
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseQuery,
Name: n.Data.Meta.Title,
Configs: q,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
q.DatabaseInfoID = dsID
selectParam := n.Data.Inputs.SelectParam
q.Limit = selectParam.Limit
queryFields := make([]string, 0)
for _, v := range selectParam.FieldList {
queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10))
}
q.QueryFields = queryFields
orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList))
for _, o := range selectParam.OrderByList {
orderClauses = append(orderClauses, &database.OrderClause{
FieldID: strconv.FormatInt(o.FieldID, 10),
IsAsc: o.IsAsc,
})
}
q.OrderClauses = orderClauses
clauseGroup := &database.ClauseGroup{}
if selectParam.Condition != nil {
clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition)
if err != nil {
return nil, err
}
}
q.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if q.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Limit == 0 {
if q.Limit == 0 {
return nil, errors.New("limit is required and greater than 0")
}
if cfg.Op == nil {
return nil, errors.New("op is required")
}
return &Query{config: cfg}, nil
return &Query{
databaseInfoID: q.DatabaseInfoID,
queryFields: q.QueryFields,
orderClauses: q.OrderClauses,
outputTypes: ns.OutputTypes,
clauseGroup: q.ClauseGroup,
limit: q.Limit,
op: database.GetDatabaseOperator(),
}, nil
}
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
type Query struct {
databaseInfoID int64
queryFields []string
orderClauses []*database.OrderClause
outputTypes map[string]*vo.TypeInfo
clauseGroup *database.ClauseGroup
limit int64
op database.DatabaseOperator
}
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil {
return nil, err
}
req := &database.QueryRequest{
DatabaseInfoID: ds.config.DatabaseInfoID,
OrderClauses: ds.config.OrderClauses,
SelectFields: ds.config.QueryFields,
Limit: ds.config.Limit,
DatabaseInfoID: ds.databaseInfoID,
OrderClauses: ds.orderClauses,
SelectFields: ds.queryFields,
Limit: ds.limit,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
req.ConditionGroup = conditionGroup
response, err := ds.config.Op.Query(ctx, req)
response, err := ds.op.Query(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(ds.config.OutputConfig, response)
ret, err := responseFormatted(ds.outputTypes, response)
if err != nil {
return nil, err
}
@@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool {
}
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil {
return nil, err
}
return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
return ds.toDatabaseQueryCallbackInput(conditionGroup)
}
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
result := make(map[string]any)
databaseID := config.DatabaseInfoID
databaseID := ds.databaseInfoID
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["selectParam"] = map[string]any{}
@@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
FieldID string `json:"fieldId"`
IsDistinct bool `json:"isDistinct"`
}
fieldList := make([]Field, 0, len(config.QueryFields))
for _, f := range config.QueryFields {
fieldList := make([]Field, 0, len(ds.queryFields))
for _, f := range ds.queryFields {
fieldList = append(fieldList, Field{FieldID: f})
}
type Order struct {
@@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
}
OrderList := make([]Order, 0)
for _, c := range config.OrderClauses {
for _, c := range ds.orderClauses {
OrderList = append(OrderList, Order{
FieldID: c.FieldID,
IsAsc: c.IsAsc,
@@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
result["selectParam"] = map[string]any{
"condition": condition,
"fieldList": fieldList,
"limit": config.Limit,
"limit": ds.limit,
"orderByList": OrderList,
}
return result, nil
}
type ConditionItem struct {
@@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) {
return "AND", nil
default:
return "", fmt.Errorf("unknown clause relation %v", rel)
}
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type mockDsSelect struct {
@@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
@@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]interface{}{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
@@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) {
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
"__condition_right_1": 2,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
@@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
@@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, map[string]any{
@@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
@@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{"__condition_right_0": 1}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
@@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
object := result["outputList"].([]any)[0].(database.Object)
@@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
"v1": "1",

View File

@@ -20,47 +20,93 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type UpdateConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Updater database.DatabaseOperator
}
func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseUpdate,
Name: n.Data.Meta.Title,
Configs: u,
}
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
u.DatabaseInfoID = dsID
updateParam := n.Data.Inputs.UpdateParam
if updateParam == nil {
return nil, fmt.Errorf("update param is requird")
}
clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition)
if err != nil {
return nil, err
}
u.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if u.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if u.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
return &Update{
databaseInfoID: u.DatabaseInfoID,
clauseGroup: u.ClauseGroup,
outputTypes: ns.OutputTypes,
updater: database.GetDatabaseOperator(),
}, nil
}
type Update struct {
config *UpdateConfig
databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
updater database.DatabaseOperator
}
type UpdateInventory struct {
type updateInventory struct {
ConditionGroup *database.ConditionGroup
Fields map[string]any
}
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
if cfg.Updater == nil {
return nil, errors.New("updater is required")
}
return &Update{config: cfg}, nil
}
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in)
if err != nil {
return nil, err
}
@@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
}
req := &database.UpdateRequest{
DatabaseInfoID: u.config.DatabaseInfoID,
DatabaseInfoID: u.databaseInfoID,
ConditionGroup: inventory.ConditionGroup,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := u.config.Updater.Update(ctx, req)
response, err := u.updater.Update(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(u.config.OutputConfig, response)
ret, err := responseFormatted(u.outputTypes, response)
if err != nil {
return nil, err
}
@@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
}
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in)
if err != nil {
return nil, err
}
return u.toDatabaseUpdateCallbackInput(inventory)
}
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
databaseID := u.config.DatabaseInfoID
func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) {
databaseID := u.databaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["updateParam"] = map[string]any{}
@@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[
"condition": condition,
"fieldInfo": fieldInfo,
}
return result, nil
return result, nil
}

View File

@@ -18,7 +18,6 @@ package emitter
import (
"context"
"errors"
"fmt"
"io"
"strings"
@@ -26,28 +25,77 @@ import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type OutputEmitter struct {
cfg *Config
Template string
FullSources map[string]*schema2.SourceInfo
}
type Config struct {
Template string
FullSources map[string]*nodes.SourceInfo
Template string
}
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeOutputEmitter,
Name: n.Data.Meta.Title,
Configs: c,
}
content := n.Data.Inputs.Content
streamingOutput := n.Data.Inputs.StreamingOutput
if streamingOutput {
ns.StreamConfigs = &schema2.StreamConfig{
RequireStreamingInput: true,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
RequireStreamingInput: false,
}
}
if content != nil {
if content.Type != vo.VariableTypeString {
return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
}
if content.Value.Type != vo.BlockInputValueTypeLiteral {
return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
}
if content.Value.Content == nil {
c.Template = ""
} else {
template, ok := content.Value.Content.(string)
if !ok {
return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
}
c.Template = template
}
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
return &OutputEmitter{
cfg: cfg,
Template: c.Template,
FullSources: ns.FullSources,
}, nil
}
@@ -59,10 +107,10 @@ type cachedVal struct {
type cacheStore struct {
store map[string]*cachedVal
infos map[string]*nodes.SourceInfo
infos map[string]*schema2.SourceInfo
}
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
return &cacheStore{
store: make(map[string]*cachedVal),
infos: infos,
@@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
}
if !sInfo.IsIntermediate { // this is not an intermediate object container
isStream := sInfo.FieldType == nodes.FieldIsStream
isStream := sInfo.FieldType == schema2.FieldIsStream
if !isStream {
_, ok := c.store[k]
if !ok {
@@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
func (c *cacheStore) finished(k string) bool {
cached, ok := c.store[k]
if !ok {
return c.infos[k].FieldType == nodes.FieldSkipped
return c.infos[k].FieldType == schema2.FieldSkipped
}
if cached.finished {
@@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
return true
}
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
actualPath []string,
) {
rootCached, ok := c.store[part.Root]
@@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
hasErr bool, partFinished bool) {
cachedRoot, subCache, sourceInfo, _ := c.find(part)
if cachedRoot != nil && subCache != nil {
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
if hasErr {
return true, false
@@ -315,14 +363,14 @@ func merge(a, b any) any {
const outputKey = "output"
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
if err != nil {
return nil, err
}
sr, sw := schema.Pipe[map[string]any](0)
parts := nodes.ParseTemplate(e.cfg.Template)
parts := nodes.ParseTemplate(e.Template)
safego.Go(ctx, func() {
hasErr := false
defer func() {
@@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
shouldChangePart = true
}
} else {
if sourceInfo.FieldType == nodes.FieldIsStream {
if sourceInfo.FieldType == schema2.FieldIsStream {
currentV := v
for i := 0; i < len(actualPath)-1; i++ {
currentM, ok := currentV.(map[string]any)
@@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
return sr, nil
}
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
if err != nil {
return nil, err
}

View File

@@ -20,41 +20,74 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Config struct {
DefaultValues map[string]any
OutputTypes map[string]*vo.TypeInfo
}
type Entry struct {
cfg *Config
defaultValues map[string]any
}
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID)
}
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if n.ID != entity.EntryNodeKey {
return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID)
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Name: n.Data.Meta.Title,
Type: entity.NodeTypeEntry,
}
defaultValues := make(map[string]any, len(n.Data.Outputs))
for _, v := range n.Data.Outputs {
variable, err := vo.ParseVariable(v)
if err != nil {
return nil, err
}
if variable.DefaultValue != nil {
defaultValues[variable.Name] = variable.DefaultValue
}
}
c.DefaultValues = defaultValues
ns.Configs = c
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
return &Entry{
cfg: cfg,
defaultValues: defaultValues,
outputTypes: ns.OutputTypes,
}, nil
}
type Entry struct {
defaultValues map[string]any
outputTypes map[string]*vo.TypeInfo
}
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
for k, v := range e.defaultValues {
if val, ok := in[k]; ok {
tInfo := e.cfg.OutputTypes[k]
tInfo := e.outputTypes[k]
switch tInfo.Type {
case vo.DataTypeString:
if len(val.(string)) == 0 {

View File

@@ -0,0 +1,113 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package exit
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Config struct {
Template string
TerminatePlan vo.TerminatePlan
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID)
}
if n.ID != entity.ExitNodeKey {
return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID)
}
ns := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Name: n.Data.Meta.Title,
Configs: c,
}
var (
content *vo.BlockInput
streamingOutput bool
)
if n.Data.Inputs.OutputEmitter != nil {
content = n.Data.Inputs.Content
streamingOutput = n.Data.Inputs.StreamingOutput
}
if streamingOutput {
ns.StreamConfigs = &schema.StreamConfig{
RequireStreamingInput: true,
}
} else {
ns.StreamConfigs = &schema.StreamConfig{
RequireStreamingInput: false,
}
}
if content != nil {
if content.Type != vo.VariableTypeString {
return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
}
if content.Value.Type != vo.BlockInputValueTypeLiteral {
return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
}
c.Template = content.Value.Content.(string)
}
if n.Data.Inputs.TerminatePlan == nil {
return nil, fmt.Errorf("exit node requires a TerminatePlan")
}
c.TerminatePlan = *n.Data.Inputs.TerminatePlan
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.TerminatePlan == vo.ReturnVariables {
return &Exit{}, nil
}
return &emitter.OutputEmitter{
Template: c.Template,
FullSources: ns.FullSources,
}, nil
}
type Exit struct{}
func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
if in == nil {
return map[string]any{}, nil
}
return in, nil
}

View File

@@ -0,0 +1,340 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httprequester
import (
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`)
func extractBracesContent(s string) []string {
matches := extractBracesRegexp.FindAllStringSubmatch(s, -1)
var result []string
for _, match := range matches {
if len(match) >= 2 {
result = append(result, match[1])
}
}
return result
}
type ImplicitNodeDependency struct {
NodeID string
FieldPath compose.FieldPath
TypeInfo *vo.TypeInfo
}
func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) {
dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes))
url := node.Data.Inputs.APIInfo.URL
urlVars := extractBracesContent(url)
hasReferred := make(map[string]bool)
extractDependenciesFromVars := func(vars []string) error {
for _, v := range vars {
if strings.HasPrefix(v, "block_output_") {
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
if len(paths) < 2 {
return fmt.Errorf("invalid block_output_ variable: %s", v)
}
if hasReferred[v] {
continue
}
hasReferred[v] = true
dependencies = append(dependencies, &ImplicitNodeDependency{
NodeID: paths[0],
FieldPath: paths[1:],
})
}
}
return nil
}
err := extractDependenciesFromVars(urlVars)
if err != nil {
return nil, err
}
if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) {
jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
err = extractDependenciesFromVars(jsonVars)
if err != nil {
return nil, err
}
}
if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) {
rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
err = extractDependenciesFromVars(rawTextVars)
if err != nil {
return nil, err
}
}
var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node
nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node {
for i := range nodes {
if nodes[i].ID == nodeID {
return nodes[i]
}
if len(nodes[i].Blocks) > 0 {
if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil {
return n
}
}
}
return nil
}
for _, ds := range dependencies {
fNode := nodeFinder(canvas.Nodes, ds.NodeID)
if fNode == nil {
continue
}
tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs))
for _, vAny := range fNode.Data.Outputs {
v, err := vo.ParseVariable(vAny)
if err != nil {
return nil, err
}
tInfo, err := convert.CanvasVariableToTypeInfo(v)
if err != nil {
return nil, err
}
tInfoMap[v.Name] = tInfo
}
tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap)
if !ok {
return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath)
}
ds.TypeInfo = tInfo
}
return dependencies, nil
}
func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) {
if len(properties) == 0 {
if tInfo, ok := tInfoMap[root]; ok {
return tInfo, true
}
return nil, false
}
tInfo, ok := tInfoMap[root]
if !ok {
return nil, false
}
return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties)
}
var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`)
func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) {
inputs := n.Data.Inputs
implicitPathVars := make(map[string]bool)
addImplicitVarsSources := func(prefix string, vars []string) error {
for _, v := range vars {
if strings.HasPrefix(v, "block_output_") {
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
if len(paths) < 2 {
return fmt.Errorf("invalid implicit var : %s", v)
}
for _, dep := range implicitNodeDependencies {
if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") {
pathValue := prefix + crypto.MD5HexValue(v)
if _, visited := implicitPathVars[pathValue]; visited {
continue
}
implicitPathVars[pathValue] = true
ns.SetInputType(pathValue, dep.TypeInfo)
ns.AddInputSource(&vo.FieldInfo{
Path: []string{pathValue},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: vo.NodeKey(dep.NodeID),
FromPath: dep.FieldPath,
},
},
})
}
}
}
if strings.HasPrefix(v, "global_variable_") {
matches := globalVariableRegex.FindStringSubmatch(v)
if len(matches) < 2 {
continue
}
var varType vo.GlobalVarType
if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) {
varType = vo.GlobalAPP
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) {
varType = vo.GlobalUser
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) {
varType = vo.GlobalSystem
} else {
return fmt.Errorf("invalid global variable type: %s", v)
}
source := vo.FieldSource{
Ref: &vo.Reference{
VariableType: &varType,
FromPath: []string{matches[1]},
},
}
ns.AddInputSource(&vo.FieldInfo{
Path: []string{prefix + crypto.MD5HexValue(v)},
Source: source,
})
}
}
return nil
}
urlVars := extractBracesContent(inputs.APIInfo.URL)
err = addImplicitVarsSources("__apiInfo_url_", urlVars)
if err != nil {
return err
}
err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent())
if err != nil {
return err
}
err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent())
if err != nil {
return err
}
if inputs.Auth != nil && inputs.Auth.AuthOpen {
authData := inputs.Auth.AuthData
const bearerTokenKey = "__auth_authData_bearerTokenData_token"
if inputs.Auth.AuthType == "BEARER_AUTH" {
bearTokenParam := authData.BearerTokenData[0]
tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input)
if err != nil {
return err
}
ns.SetInputType(bearerTokenKey, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
if inputs.Auth.AuthType == "CUSTOM_AUTH" {
const (
customDataDataKey = "__auth_authData_customData_data_Key"
customDataDataValue = "__auth_authData_customData_data_Value"
)
dataParams := authData.CustomData.Data
keyParam := dataParams[0]
keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataKey, keyTypeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
valueParam := dataParams[1]
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataValue, valueTypeInfo)
sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
}
switch BodyType(inputs.Body.BodyType) {
case BodyTypeFormData:
err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent())
if err != nil {
return err
}
case BodyTypeFormURLEncoded:
err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent())
if err != nil {
return err
}
case BodyTypeBinary:
const fileURLName = "__body_bodyData_binary_fileURL"
fileURLInput := inputs.Body.BodyData.Binary.FileURL
ns.SetInputType(fileURLName, &vo.TypeInfo{
Type: vo.DataTypeString,
})
sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
case BodyTypeJSON:
jsonVars := extractBracesContent(inputs.Body.BodyData.Json)
err = addImplicitVarsSources("__body_bodyData_json_", jsonVars)
if err != nil {
return err
}
case BodyTypeRawText:
rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText)
err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars)
if err != nil {
return err
}
}
return nil
}
func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error {
for i := range params {
param := params[i]
name := param.Name
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return err
}
fieldName := prefix + crypto.MD5HexValue(name)
ns.SetInputType(fieldName, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{fieldName}, parentNode)
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
return nil
}

View File

@@ -31,9 +31,14 @@ import (
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -129,7 +134,7 @@ type Request struct {
FileURL *string
}
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`)
type MD5FieldMapping struct {
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
@@ -184,49 +189,188 @@ type Config struct {
Timeout time.Duration
RetryTimes uint64
IgnoreException bool
DefaultOutput map[string]any
MD5FieldMapping
}
type HTTPRequester struct {
client *http.Client
config *Config
}
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
options := nodes.GetAdaptOptions(opts...)
if options.Canvas == nil {
return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node")
}
if len(cfg.Method) == 0 {
implicitDeps, err := extractImplicitDependency(n, options.Canvas)
if err != nil {
return nil, err
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeHTTPRequester,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
md5FieldMapping := &MD5FieldMapping{}
method := inputs.APIInfo.Method
c.Method = method
reqURL := inputs.APIInfo.URL
c.URLConfig = URLConfig{
Tpl: strings.TrimSpace(reqURL),
}
urlVars := extractBracesContent(reqURL)
md5FieldMapping.SetURLFields(urlVars...)
md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string {
return a.Name
})...)
md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string {
return a.Name
})...)
if inputs.Auth != nil && inputs.Auth.AuthOpen {
auth := &AuthenticationConfig{}
ty, err := convertAuthType(inputs.Auth.AuthType)
if err != nil {
return nil, err
}
auth.Type = ty
location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo)
if err != nil {
return nil, err
}
auth.Location = location
c.AuthConfig = auth
}
bodyConfig := BodyConfig{}
bodyConfig.BodyType = BodyType(inputs.Body.BodyType)
switch BodyType(inputs.Body.BodyType) {
case BodyTypeJSON:
jsonTpl := inputs.Body.BodyData.Json
bodyConfig.TextJsonConfig = &TextJsonConfig{
Tpl: jsonTpl,
}
jsonVars := extractBracesContent(jsonTpl)
md5FieldMapping.SetBodyFields(jsonVars...)
case BodyTypeFormData:
bodyConfig.FormDataConfig = &FormDataConfig{
FileTypeMapping: map[string]bool{},
}
formDataVars := make([]string, 0)
for i := range inputs.Body.BodyData.FormData.Data {
p := inputs.Body.BodyData.FormData.Data[i]
formDataVars = append(formDataVars, p.Name)
if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime {
bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true
}
}
md5FieldMapping.SetBodyFields(formDataVars...)
case BodyTypeRawText:
TextTpl := inputs.Body.BodyData.RawText
bodyConfig.TextPlainConfig = &TextPlainConfig{
Tpl: TextTpl,
}
textPlainVars := extractBracesContent(TextTpl)
md5FieldMapping.SetBodyFields(textPlainVars...)
case BodyTypeFormURLEncoded:
formURLEncodedVars := make([]string, 0)
for _, p := range inputs.Body.BodyData.FormURLEncoded {
formURLEncodedVars = append(formURLEncodedVars, p.Name)
}
md5FieldMapping.SetBodyFields(formURLEncodedVars...)
}
c.BodyConfig = bodyConfig
c.MD5FieldMapping = *md5FieldMapping
if inputs.Setting != nil {
c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second
c.RetryTimes = uint64(inputs.Setting.RetryTimes)
}
if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func convertAuthType(auth string) (AuthType, error) {
switch auth {
case "CUSTOM_AUTH":
return Custom, nil
case "BEARER_AUTH":
return BearToken, nil
default:
return AuthType(0), fmt.Errorf("invalid auth type")
}
}
func convertLocation(l string) (Location, error) {
switch l {
case "header":
return Header, nil
case "query":
return QueryParam, nil
default:
return 0, fmt.Errorf("invalid location")
}
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if len(c.Method) == 0 {
return nil, fmt.Errorf("method is requried")
}
hg := &HTTPRequester{}
hg := &HTTPRequester{
urlConfig: c.URLConfig,
method: c.Method,
retryTimes: c.RetryTimes,
authConfig: c.AuthConfig,
bodyConfig: c.BodyConfig,
md5FieldMapping: c.MD5FieldMapping,
}
client := http.DefaultClient
if cfg.Timeout > 0 {
client.Timeout = cfg.Timeout
if c.Timeout > 0 {
client.Timeout = c.Timeout
}
hg.client = client
hg.config = cfg
return hg, nil
}
type HTTPRequester struct {
client *http.Client
urlConfig URLConfig
authConfig *AuthenticationConfig
bodyConfig BodyConfig
method string
retryTimes uint64
md5FieldMapping MD5FieldMapping
}
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
var (
req = &Request{}
method = hg.config.Method
retryTimes = hg.config.RetryTimes
method = hg.method
retryTimes = hg.retryTimes
body io.ReadCloser
contentType string
response *http.Response
)
req, err = hg.config.parserToRequest(input)
req, err = hg.parserToRequest(input)
if err != nil {
return nil, err
}
@@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
Header: http.Header{},
}
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars)
if err != nil {
return nil, err
}
@@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
params.Set(key, value)
}
if hg.config.AuthConfig != nil {
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if hg.authConfig != nil {
httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if err != nil {
return nil, err
}
@@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
u.RawQuery = params.Encode()
httpRequest.URL = u
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req)
if err != nil {
return nil, err
}
@@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) {
}
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
var (
request = &Request{}
config = hg.config
)
request, err := hg.config.parserToRequest(input)
var request = &Request{}
request, err := hg.parserToRequest(input)
if err != nil {
return nil, err
}
result := make(map[string]any)
result["method"] = config.Method
result["method"] = hg.method
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars)
if err != nil {
return nil, err
}
@@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["header"] = headers
result["auth"] = nil
if config.AuthConfig != nil {
if config.AuthConfig.Type == Custom {
if hg.authConfig != nil {
if hg.authConfig.Type == Custom {
result["auth"] = map[string]interface{}{
"Key": request.Authentication.Key,
"Value": request.Authentication.Value,
}
} else if config.AuthConfig.Type == BearToken {
} else if hg.authConfig.Type == BearToken {
result["auth"] = map[string]interface{}{
"token": request.Authentication.Token,
}
@@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["body"] = nil
switch config.BodyConfig.BodyType {
switch hg.bodyConfig.BodyType {
case BodyTypeJSON:
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars)
if err != nil {
return nil, err
}
@@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["body"] = ret
case BodyTypeRawText:
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
if err != nil {
return nil, err
@@ -569,7 +711,7 @@ const (
bodyBinaryFileURLPrefix = "binary_fileURL"
)
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) {
request := &Request{
URLVars: make(map[string]any),
Headers: make(map[string]string),
@@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
for key, value := range input {
if strings.HasPrefix(key, apiInfoURLPrefix) {
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok {
if strings.HasPrefix(urlKey, "global_variable_") {
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
}
@@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
}
if strings.HasPrefix(key, headersPrefix) {
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok {
request.Headers[headerKey] = value.(string)
}
}
if strings.HasPrefix(key, paramsPrefix) {
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok {
request.Params[paramKey] = value.(string)
}
}
@@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok {
if strings.HasPrefix(jsonKey, "global_variable_") {
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
}
@@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
}
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok {
request.FormDataVars[formDataKey] = value.(string)
}
@@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
}
}
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok {
if strings.HasPrefix(rawTextKey, "global_variable_") {
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
@@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) {
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
@@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) {
"__params_" + crypto.MD5HexValue("p2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) {
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])

View File

@@ -18,26 +18,167 @@ package intentdetector
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type Config struct {
Intents []string
SystemPrompt string
IsFastMode bool
ChatModel model.BaseChatModel
LLMParams *model.LLMParams
}
// Adapt converts the frontend intent-detector Node into a NodeSchema,
// extracting the LLM parameters, intent list and execution mode into c.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeIntentDetector,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	param := n.Data.Inputs.LLMParam
	if param == nil {
		return nil, fmt.Errorf("intent detector node's llmParam is nil")
	}

	llmParam, ok := param.(vo.IntentDetectorLLMParam)
	if !ok {
		// report the actual dynamic type of param; previously this printed the
		// zero-value assertion result, which is always empty and misleading
		return nil, fmt.Errorf("intent detector node's llmParam must be IntentDetectorLLMParam, got %T", param)
	}

	// round-trip through JSON to decode the loosely-typed param into a config struct
	paramBytes, err := sonic.Marshal(llmParam)
	if err != nil {
		return nil, err
	}

	intentDetectorConfig := &vo.IntentDetectorLLMConfig{}
	if err = sonic.Unmarshal(paramBytes, intentDetectorConfig); err != nil {
		return nil, err
	}

	systemPrompt, ok := intentDetectorConfig.SystemPrompt.Value.Content.(string)
	if !ok {
		// guard the assertion instead of panicking on an unexpected content type
		return nil, fmt.Errorf("intent detector node's system prompt content must be string, got %T",
			intentDetectorConfig.SystemPrompt.Value.Content)
	}

	c.LLMParams = &model.LLMParams{
		ModelType:      int64(intentDetectorConfig.ModelType),
		ModelName:      intentDetectorConfig.ModelName,
		TopP:           intentDetectorConfig.TopP,
		Temperature:    intentDetectorConfig.Temperature,
		MaxTokens:      intentDetectorConfig.MaxTokens,
		ResponseFormat: model.ResponseFormat(intentDetectorConfig.ResponseFormat),
		SystemPrompt:   systemPrompt,
	}
	c.SystemPrompt = systemPrompt

	intents := make([]string, 0, len(n.Data.Inputs.Intents))
	for _, it := range n.Data.Inputs.Intents {
		intents = append(intents, it.Name)
	}
	c.Intents = intents

	// "top_speed" on the canvas maps to fast mode (bare-id LLM answers)
	if n.Data.Inputs.Mode == "top_speed" {
		c.IsFastMode = true
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build compiles the intent-detection chain (system prompt template + chat
// model) and returns the runnable IntentDetector node implementation.
func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
	if !c.IsFastMode && c.LLMParams == nil {
		return nil, errors.New("config chat model is required")
	}
	if len(c.Intents) == 0 {
		return nil, errors.New("config intents is required")
	}

	chatModel, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
	if err != nil {
		return nil, err
	}

	// pick the system prompt template matching the detection mode
	promptTpl := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)

	intentStr, err := toIntentString(c.Intents)
	if err != nil {
		return nil, err
	}

	systemContent, err := nodes.TemplateRender(promptTpl, map[string]any{
		"intents": intentStr,
	})
	if err != nil {
		return nil, err
	}

	tpl := prompt.FromMessages(schema.Jinja2,
		&schema.Message{Role: schema.System, Content: systemContent},
		&schema.Message{Role: schema.User, Content: "{{query}}"})

	runner, err := compose.NewChain[map[string]any, *schema.Message]().
		AppendChatTemplate(tpl).
		AppendChatModel(chatModel).
		Compile(ctx)
	if err != nil {
		return nil, err
	}

	return &IntentDetector{
		isFastMode:   c.IsFastMode,
		systemPrompt: c.SystemPrompt,
		runner:       runner,
	}, nil
}
// BuildBranch returns the branch-selection function for the intent detector.
// The selector maps the node's classificationId output to a branch index:
// id 0 selects the default branch, id N (N>0) selects branch N-1.
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	selector := func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		raw, found := nodeOutput[classificationID]
		if !found {
			return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput)
		}

		id, isInt64 := raw.(int64)
		if !isInt64 {
			return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", raw)
		}

		if id == 0 {
			// zero means no intent matched: take the default branch
			return -1, true, nil
		}

		return id - 1, false, nil
	}

	return selector, true
}
// ExpectPorts lists the outgoing ports of the intent detector node:
// one default port followed by one branch port per configured intent.
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string {
	ports := make([]string, 0, len(n.Data.Inputs.Intents)+1)
	ports = append(ports, schema2.PortDefault)
	for i := range n.Data.Inputs.Intents {
		ports = append(ports, fmt.Sprintf(schema2.PortBranchFormat, i))
	}
	return ports
}
const SystemIntentPrompt = `
@@ -95,71 +236,39 @@ Note:
##Limit
- Please do not reply in text.`
const classificationID = "classificationId"
type IntentDetector struct {
config *Config
runner compose.Runnable[map[string]any, *schema.Message]
}
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if !cfg.IsFastMode && cfg.ChatModel == nil {
return nil, errors.New("config chat model is required")
}
if len(cfg.Intents) == 0 {
return nil, errors.New("config intents is required")
}
chain := compose.NewChain[map[string]any, *schema.Message]()
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": toIntentString(cfg.Intents),
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
config: cfg,
runner: r,
}, nil
isFastMode bool
systemPrompt string
runner compose.Runnable[map[string]any, *schema.Message]
}
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
nodeOutput := make(map[string]any)
nodeOutput["classificationId"] = 0
if content == "" {
return nodeOutput, errors.New("content is empty")
return nil, errors.New("intent detector's LLM output content is empty")
}
if id.config.IsFastMode {
if id.isFastMode {
cid, err := strconv.ParseInt(content, 10, 64)
if err != nil {
return nodeOutput, err
return nil, err
}
nodeOutput["classificationId"] = cid
return nodeOutput, nil
return map[string]any{
classificationID: cid,
}, nil
}
leftIndex := strings.Index(content, "{")
rightIndex := strings.Index(content, "}")
if leftIndex == -1 || rightIndex == -1 {
return nodeOutput, errors.New("content is invalid")
return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content)
}
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
var nodeOutput map[string]any
err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput)
if err != nil {
return nodeOutput, err
return nil, err
}
return nodeOutput, nil
@@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
vars := make(map[string]any)
vars["query"] = queryStr
if !id.config.IsFastMode {
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
if !id.isFastMode {
ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query})
if err != nil {
return nil, err
}
@@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
return id.parseToNodeOut(o.Content)
}
func toIntentString(its []string) string {
func toIntentString(its []string) (string, error) {
type IntentVariableItem struct {
ClassificationID int64 `json:"classificationId"`
Content string `json:"content"`
@@ -207,6 +316,6 @@ func toIntentString(its []string) string {
Content: it,
})
}
itsBytes, _ := json.Marshal(vs)
return string(itsBytes)
return sonic.MarshalString(vs)
}

View File

@@ -1,88 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package intentdetector
import (
"context"
"fmt"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
)
// mockChatModel is a stub chat model for the tests. With topSeed set it
// answers with a bare classification id, otherwise with a JSON result.
type mockChatModel struct {
	topSeed bool
}

// Generate returns a canned response depending on the mock's mode.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	content := `{"classificationId":1,"reason":"高兴"}`
	if m.topSeed {
		content = "1"
	}
	return &schema.Message{Content: content}, nil
}

// Stream is unused by these tests and returns no stream.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	return nil, nil
}

// BindTools is a no-op for the mock.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}
// TestNewIntentDetector exercises both detection modes against the mock model.
func TestNewIntentDetector(t *testing.T) {
	ctx := context.Background()

	t.Run("fast mode", func(t *testing.T) {
		detector, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: true,
			ChatModel:  &mockChatModel{topSeed: true},
		})
		assert.Nil(t, err)

		out, err := detector.Invoke(ctx, map[string]any{
			"query": "我考了100分",
		})
		assert.Nil(t, err)
		// fast mode parses the bare id string as an int64
		assert.Equal(t, out["classificationId"], int64(1))
	})

	t.Run("full mode", func(t *testing.T) {
		detector, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: false,
			ChatModel:  &mockChatModel{},
		})
		assert.Nil(t, err)

		out, err := detector.Invoke(ctx, map[string]any{
			"query": "我考了100分",
		})
		fmt.Println(err)
		assert.Nil(t, err)
		fmt.Println(out)
		// full mode unmarshals JSON, so numbers come back as float64
		assert.Equal(t, out["classificationId"], float64(1))
		assert.Equal(t, out["reason"], "高兴")
	})
}

View File

@@ -20,8 +20,11 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -34,32 +37,42 @@ const (
warningsKey = "deserialization_warnings"
)
type DeserializationConfig struct {
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
type DeserializationConfig struct{}
// Adapt converts the frontend JSON-deserialization Node into a NodeSchema,
// wiring up its input mappings and output types.
func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (
	*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Name:    n.Data.Meta.Title,
		Type:    entity.NodeTypeJsonDeserialization,
		Configs: d,
	}

	// input field types and mapping info
	var err error
	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	// output field type info
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
type Deserializer struct {
config *DeserializationConfig
typeInfo *vo.TypeInfo
}
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.OutputFields == nil {
return nil, fmt.Errorf("OutputFields is required for deserialization")
}
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
if typeInfo == nil {
func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization]
if !ok {
return nil, fmt.Errorf("no output field specified in deserialization config")
}
return &Deserializer{
config: cfg,
typeInfo: typeInfo,
}, nil
}
// Deserializer is the runnable JSON-deserialization node implementation.
type Deserializer struct {
	typeInfo *vo.TypeInfo // declared output type the decoded JSON is converted to
}
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
jsonStrValue := input[InputKeyDeserialization]

View File

@@ -24,6 +24,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -31,19 +32,9 @@ import (
func TestNewJsonDeserializer(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonDeserializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing OutputFields config
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "OutputFields is required")
// Test with missing output key in OutputFields
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
_, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: vo.DataTypeString},
},
})
@@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) {
assert.Contains(t, err.Error(), "no output field specified in deserialization config")
// Test with valid config
validConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
validConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString},
},
}
processor, err := NewJsonDeserializer(ctx, validConfig)
processor, err := (&DeserializationConfig{}).Build(ctx, validConfig)
assert.NoError(t, err)
assert.NotNil(t, processor)
}
@@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
ctx := context.Background()
// Base type test config
baseTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeString},
baseTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString},
},
}
// Object type test config
objectTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
objectTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"name": {Type: vo.DataTypeString, Required: true},
@@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
}
// Array type test config
arrayTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
arrayTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
}
// Nested array object test config
nestedArrayConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
nestedArrayConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
@@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
// Test cases
tests := []struct {
name string
config *DeserializationConfig
config *schema.NodeSchema
inputJSON string
expectedOutput any
expectErr bool
@@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test integer deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123`,
@@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test boolean deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
},
},
inputJSON: `true`,
@@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test type mismatch warning",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"not a number"`,
@@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test string to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123"`,
@@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test float to integer conversion (integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.0`,
@@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test float to integer conversion (non-integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.5`,
@@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test boolean to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `true`,
@@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test string to boolean conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
},
},
inputJSON: `"true"`,
@@ -252,10 +243,11 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion in nested object",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
name: "Test string to integer conversion in nested object",
inputJSON: `{"age":"456"}`,
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"age": {Type: vo.DataTypeInteger},
@@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
},
},
},
inputJSON: `{"age":"456"}`,
expectedOutput: map[string]any{"age": 456},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion for array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test string with non-numeric characters to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123abc"`,
@@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test type mismatch in nested object field",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"score": {Type: vo.DataTypeInteger},
@@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test partial conversion failure in array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
processor, err := NewJsonDeserializer(ctx, tt.config)
processor, err := (&DeserializationConfig{}).Build(ctx, tt.config)
assert.NoError(t, err)
ctxWithCache := ctxcache.Init(ctx)
input := map[string]any{"input": tt.inputJSON}
result, err := processor.Invoke(ctxWithCache, input)
result, err := processor.(*Deserializer).Invoke(ctxWithCache, input)
if tt.expectErr {
assert.Error(t, err)

View File

@@ -20,7 +20,11 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -29,28 +33,57 @@ const (
OutputKeySerialization = "output"
)
// SerializationConfig is the Config type for NodeTypeJsonSerialization.
// Each Node Type should have its own designated Config type,
// which should implement NodeAdaptor and NodeBuilder.
// NOTE: we didn't define any fields for this type,
// because this node is simple, we don't need to extract any SPECIFIC piece of info
// from frontend Node. In other cases we would need to do it, such as LLM's model configs.
type SerializationConfig struct {
InputTypes map[string]*vo.TypeInfo
// you can define ANY number of fields here,
// as long as these fields are SERIALIZABLE and EXPORTED.
// to store specific info extracted from frontend node.
// e.g.
// - LLM model configs
// - conditional expressions
// - fixed input fields such as MaxBatchSize
}
type JsonSerializer struct {
config *SerializationConfig
}
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.InputTypes == nil {
return nil, fmt.Errorf("InputTypes is required for serialization")
// Adapt provides conversion from Node to NodeSchema.
// NOTE: in this specific case, we don't need AdaptOption.
func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeJsonSerialization,
Name: n.Data.Meta.Title,
Configs: s, // remember to set the Node's Config Type to NodeSchema as well
}
return &JsonSerializer{
config: cfg,
}, nil
// this sets input fields' type and mapping info
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
// this sets output fields' type info
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Build returns the runnable Serializer node; this node needs no per-node
// state, so the config and schema are not consulted.
func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (
	any, error) {
	var node Serializer
	return &node, nil
}
// Serializer is the actual node implementation.
type Serializer struct {
// this can hold ANY data required for node execution
}
// Invoke implements the InvokableNode interface.
func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Directly use the input map for serialization
if input == nil {
return nil, fmt.Errorf("input data for serialization cannot be nil")

View File

@@ -23,44 +23,34 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNewJsonSerialize(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonSerializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing InputTypes config
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "InputTypes is required")
// Test with valid config
validConfig := &SerializationConfig{
s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
InputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: "string"},
},
}
processor, err := NewJsonSerializer(ctx, validConfig)
})
assert.NoError(t, err)
assert.NotNil(t, processor)
assert.NotNil(t, s)
}
func TestJsonSerialize_Invoke(t *testing.T) {
ctx := context.Background()
config := &SerializationConfig{
processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
InputTypes: map[string]*vo.TypeInfo{
"stringKey": {Type: "string"},
"intKey": {Type: "integer"},
"boolKey": {Type: "boolean"},
"objKey": {Type: "object"},
},
}
processor, err := NewJsonSerializer(ctx, config)
})
assert.NoError(t, err)
// Test cases
@@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processor.Invoke(ctx, tt.input)
result, err := processor.(*Serializer).Invoke(ctx, tt.input)
if tt.expectErr {
assert.Error(t, err)

View File

@@ -0,0 +1,57 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
)
// convertParsingType maps the canvas parsing-type string to the domain
// ParseMode. Supported values are "fast" and "accurate"; anything else
// yields an error.
func convertParsingType(p string) (knowledge.ParseMode, error) {
	switch p {
	case "fast":
		return knowledge.FastParseMode, nil
	case "accurate":
		return knowledge.AccurateParseMode, nil
	}
	return "", fmt.Errorf("invalid parsingType: %s", p)
}
// convertChunkType maps the canvas chunk-type string to the domain
// ChunkType. Supported values are "custom" and "default"; anything else
// yields an error.
func convertChunkType(p string) (knowledge.ChunkType, error) {
	switch p {
	case "custom":
		return knowledge.ChunkTypeCustom, nil
	case "default":
		return knowledge.ChunkTypeDefault, nil
	}
	return "", fmt.Errorf("invalid ChunkType: %s", p)
}
// convertRetrievalSearchType maps the canvas search-type code to the domain
// SearchType: 0 = semantic, 1 = hybrid, 20 = full-text. Any other code
// yields an error.
func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
	switch s {
	case 0:
		return knowledge.SearchTypeSemantic, nil
	case 1:
		return knowledge.SearchTypeHybrid, nil
	case 20:
		return knowledge.SearchTypeFullText, nil
	}
	return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
}

View File

@@ -21,27 +21,45 @@ import (
"errors"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type DeleterConfig struct {
KnowledgeID int64
KnowledgeDeleter knowledge.KnowledgeOperator
}
type DeleterConfig struct{}
type KnowledgeDeleter struct {
config *DeleterConfig
}
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
if cfg.KnowledgeDeleter == nil {
return nil, errors.New("knowledge deleter is required")
func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeKnowledgeDeleter,
Name: n.Data.Meta.Title,
Configs: d,
}
return &KnowledgeDeleter{
config: cfg,
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
// Build constructs the runtime Deleter node, wiring in the process-wide
// knowledge operator. Neither the schema nor the build options are needed.
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	deleter := &Deleter{knowledgeDeleter: knowledge.GetKnowledgeOperator()}
	return deleter, nil
}
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
// Deleter is the executable knowledge-delete node. Its Invoke method removes
// a document (identified by the "documentID" input) through the injected
// knowledge operator.
type Deleter struct {
	knowledgeDeleter knowledge.KnowledgeOperator
}
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
documentID, ok := input["documentID"].(string)
if !ok {
return nil, errors.New("documentID is required and must be a string")
@@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma
DocumentID: documentID,
}
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
response, err := k.knowledgeDeleter.Delete(ctx, req)
if err != nil {
return nil, err
}

View File

@@ -24,7 +24,14 @@ import (
"path/filepath"
"strings"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
@@ -32,30 +39,88 @@ type IndexerConfig struct {
KnowledgeID int64
ParsingStrategy *knowledge.ParsingStrategy
ChunkingStrategy *knowledge.ChunkingStrategy
KnowledgeIndexer knowledge.KnowledgeOperator
}
type KnowledgeIndexer struct {
config *IndexerConfig
// Adapt converts the canvas representation of a knowledge-indexer node into
// its NodeSchema. It resolves the target knowledge base ID plus the parsing
// and chunking strategies from the node's inputs and caches them on the
// config for Build to consume later.
func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeKnowledgeIndexer,
		Name:    n.Data.Meta.Title,
		Configs: i,
	}

	inputs := n.Data.Inputs
	// Guard against malformed canvas data: previously a missing dataset
	// param or a mistyped content value panicked on the unchecked index
	// and type assertion.
	if len(inputs.DatasetParam) == 0 {
		return nil, fmt.Errorf("dataset param is required")
	}
	datasetListInfoParam := inputs.DatasetParam[0]
	datasetIDs, ok := datasetListInfoParam.Input.Value.Content.([]any)
	if !ok || len(datasetIDs) == 0 {
		return nil, fmt.Errorf("dataset ids is required")
	}

	// Only the first dataset ID is honored: an indexer node writes into a
	// single knowledge base.
	knowledgeID, err := cast.ToInt64E(datasetIDs[0])
	if err != nil {
		return nil, err
	}
	i.KnowledgeID = knowledgeID

	ps := inputs.StrategyParam.ParsingStrategy
	parseMode, err := convertParsingType(ps.ParsingType)
	if err != nil {
		return nil, err
	}
	i.ParsingStrategy = &knowledge.ParsingStrategy{
		ParseMode:    parseMode,
		ImageOCR:     ps.ImageOcr,
		ExtractImage: ps.ImageExtraction,
		ExtractTable: ps.TableExtraction,
	}

	cs := inputs.StrategyParam.ChunkStrategy
	chunkType, err := convertChunkType(cs.ChunkType)
	if err != nil {
		return nil, err
	}
	i.ChunkingStrategy = &knowledge.ChunkingStrategy{
		ChunkType: chunkType,
		Separator: cs.Separator,
		ChunkSize: cs.MaxToken,
		// Overlap arrives as a fraction of the max token count; store it
		// as an absolute token count.
		Overlap: int64(cs.Overlap * float64(cs.MaxToken)),
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
if cfg.ParsingStrategy == nil {
func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if i.ParsingStrategy == nil {
return nil, errors.New("parsing strategy is required")
}
if cfg.ChunkingStrategy == nil {
if i.ChunkingStrategy == nil {
return nil, errors.New("chunking strategy is required")
}
if cfg.KnowledgeIndexer == nil {
return nil, errors.New("knowledge indexer is required")
}
return &KnowledgeIndexer{
config: cfg,
return &Indexer{
knowledgeID: i.KnowledgeID,
parsingStrategy: i.ParsingStrategy,
chunkingStrategy: i.ChunkingStrategy,
knowledgeIndexer: knowledge.GetKnowledgeOperator(),
}, nil
}
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
// Indexer is the executable knowledge-indexer node: it stores a document
// (referenced by a file URL input) into the configured knowledge base using
// the parsing and chunking strategies resolved at Adapt time.
type Indexer struct {
	knowledgeID      int64
	parsingStrategy  *knowledge.ParsingStrategy
	chunkingStrategy *knowledge.ChunkingStrategy
	knowledgeIndexer knowledge.KnowledgeOperator
}
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fileURL, ok := input["knowledge"].(string)
if !ok {
return nil, errors.New("knowledge is required")
@@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map
}
req := &knowledge.CreateDocumentRequest{
KnowledgeID: k.config.KnowledgeID,
ParsingStrategy: k.config.ParsingStrategy,
ChunkingStrategy: k.config.ChunkingStrategy,
KnowledgeID: k.knowledgeID,
ParsingStrategy: k.parsingStrategy,
ChunkingStrategy: k.chunkingStrategy,
FileURL: fileURL,
FileName: fileName,
FileExtension: ext,
}
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
response, err := k.knowledgeIndexer.Store(ctx, req)
if err != nil {
return nil, err
}

View File

@@ -20,7 +20,14 @@ import (
"context"
"errors"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
@@ -29,37 +36,136 @@ const outputList = "outputList"
type RetrieveConfig struct {
KnowledgeIDs []int64
RetrievalStrategy *knowledge.RetrievalStrategy
Retriever knowledge.KnowledgeOperator
}
type KnowledgeRetrieve struct {
config *RetrieveConfig
// Adapt converts the canvas representation of a knowledge-retriever node
// into its NodeSchema. It extracts the knowledge base IDs and the optional
// retrieval-strategy knobs (topK, rerank, rewrite, personal-only, NL2SQL,
// min score, search mode) from the node's dataset parameters.
func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeKnowledgeRetriever,
		Name:    n.Data.Meta.Title,
		Configs: r,
	}

	inputs := n.Data.Inputs
	// Guard against malformed canvas data: previously a missing dataset
	// param or a mistyped content value panicked on the unchecked index
	// and type assertion.
	if len(inputs.DatasetParam) == 0 {
		return nil, fmt.Errorf("dataset param is required")
	}
	// By convention the first dataset param carries the selected dataset IDs.
	datasetListInfoParam := inputs.DatasetParam[0]
	datasetIDs, ok := datasetListInfoParam.Input.Value.Content.([]any)
	if !ok {
		return nil, fmt.Errorf("dataset ids is required")
	}
	knowledgeIDs := make([]int64, 0, len(datasetIDs))
	for _, id := range datasetIDs {
		k, err := cast.ToInt64E(id)
		if err != nil {
			return nil, err
		}
		knowledgeIDs = append(knowledgeIDs, k)
	}
	r.KnowledgeIDs = knowledgeIDs

	retrievalStrategy := &knowledge.RetrievalStrategy{}
	// getDesignatedParamContent looks up an optional strategy parameter by
	// name among the dataset params; absence just leaves the default.
	var getDesignatedParamContent = func(name string) (any, bool) {
		for _, param := range inputs.DatasetParam {
			if param.Name == name {
				return param.Input.Value.Content, true
			}
		}
		return nil, false
	}
	if content, ok := getDesignatedParamContent("topK"); ok {
		topK, err := cast.ToInt64E(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.TopK = &topK
	}
	if content, ok := getDesignatedParamContent("useRerank"); ok {
		useRerank, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableRerank = useRerank
	}
	if content, ok := getDesignatedParamContent("useRewrite"); ok {
		useRewrite, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableQueryRewrite = useRewrite
	}
	if content, ok := getDesignatedParamContent("isPersonalOnly"); ok {
		isPersonalOnly, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.IsPersonalOnly = isPersonalOnly
	}
	if content, ok := getDesignatedParamContent("useNl2sql"); ok {
		useNl2sql, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableNL2SQL = useNl2sql
	}
	if content, ok := getDesignatedParamContent("minScore"); ok {
		minScore, err := cast.ToFloat64E(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.MinScore = &minScore
	}
	if content, ok := getDesignatedParamContent("strategy"); ok {
		strategy, err := cast.ToInt64E(content)
		if err != nil {
			return nil, err
		}
		searchType, err := convertRetrievalSearchType(strategy)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.SearchType = searchType
	}
	r.RetrievalStrategy = retrievalStrategy

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if len(r.KnowledgeIDs) == 0 {
return nil, errors.New("knowledge ids are required")
}
if cfg.Retriever == nil {
return nil, errors.New("retriever is required")
}
if len(cfg.KnowledgeIDs) == 0 {
return nil, errors.New("knowledgeI ids is required")
}
if cfg.RetrievalStrategy == nil {
if r.RetrievalStrategy == nil {
return nil, errors.New("retrieval strategy is required")
}
return &KnowledgeRetrieve{
config: cfg,
return &Retrieve{
knowledgeIDs: r.KnowledgeIDs,
retrievalStrategy: r.RetrievalStrategy,
retriever: knowledge.GetKnowledgeOperator(),
}, nil
}
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
// Retrieve is the executable knowledge-retrieval node: it issues a retrieval
// request for the incoming "Query" input against the configured knowledge
// bases through the injected knowledge operator.
type Retrieve struct {
	knowledgeIDs      []int64
	retrievalStrategy *knowledge.RetrievalStrategy
	retriever         knowledge.KnowledgeOperator
}
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
query, ok := input["Query"].(string)
if !ok {
return nil, errors.New("capital query key is required")
@@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any)
req := &knowledge.RetrieveRequest{
Query: query,
KnowledgeIDs: kr.config.KnowledgeIDs,
RetrievalStrategy: kr.config.RetrievalStrategy,
KnowledgeIDs: kr.knowledgeIDs,
RetrievalStrategy: kr.retrievalStrategy,
}
response, err := kr.config.Retriever.Retrieve(ctx, req)
response, err := kr.retriever.Retrieve(ctx, req)
if err != nil {
return nil, err
}

View File

@@ -34,13 +34,20 @@ import (
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -143,126 +150,408 @@ const (
)
type RetrievalStrategy struct {
RetrievalStrategy *crossknowledge.RetrievalStrategy
RetrievalStrategy *knowledge.RetrievalStrategy
NoReCallReplyMode NoReCallReplyMode
NoReCallReplyCustomizePrompt string
}
type KnowledgeRecallConfig struct {
ChatModel model.BaseChatModel
Retriever crossknowledge.KnowledgeOperator
Retriever knowledge.KnowledgeOperator
RetrievalStrategy *RetrievalStrategy
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
}
type Config struct {
ChatModel ModelWithInfo
Tools []tool.BaseTool
SystemPrompt string
UserPrompt string
OutputFormat Format
InputFields map[string]*vo.TypeInfo
OutputFields map[string]*vo.TypeInfo
ToolsReturnDirectly map[string]bool
KnowledgeRecallConfig *KnowledgeRecallConfig
FullSources map[string]*nodes.SourceInfo
SystemPrompt string
UserPrompt string
OutputFormat Format
LLMParams *crossmodel.LLMParams
FCParam *vo.FCParam
BackupLLMParams *crossmodel.LLMParams
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
outputFields map[string]*vo.TypeInfo
canStream bool
requireCheckpoint bool
fullSources map[string]*nodes.SourceInfo
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLLM,
Name: n.Data.Meta.Title,
Configs: c,
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
param := n.Data.Inputs.LLMParam
if param == nil {
return nil, fmt.Errorf("llm node's llmParam is nil")
}
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return nil, err
}
convertedLLMParam, err := llmParamsToLLMParam(llmParam)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
c.LLMParams = convertedLLMParam
c.SystemPrompt = convertedLLMParam.SystemPrompt
c.UserPrompt = convertedLLMParam.Prompt
var resFormat Format
switch convertedLLMParam.ResponseFormat {
case crossmodel.ResponseFormatText:
resFormat = FormatText
case crossmodel.ResponseFormatMarkdown:
resFormat = FormatMarkdown
case crossmodel.ResponseFormatJSON:
resFormat = FormatJSON
default:
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
c.OutputFormat = resFormat
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return r, nil
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
if resFormat == FormatJSON {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
}
}
}
if resFormat == FormatJSON {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: false,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: true,
}
}
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
c.FCParam = n.Data.Inputs.FCParam
}
if se := n.Data.Inputs.SettingOnError; se != nil {
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
var backupLLMParam vo.SimpleLLMParam
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
return nil, err
}
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
if err != nil {
return nil, err
}
c.BackupLLMParams = backupModel
}
}
return ns, nil
}
// llmParamsToLLMParam flattens the canvas LLMParam list into the cross-domain
// LLMParams struct. Numeric values arrive as strings and are parsed here;
// unknown parameter names are rejected. Malformed content values now return
// an error instead of panicking on an unchecked type assertion.
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
	p := &crossmodel.LLMParams{}
	for _, param := range params {
		content := param.Input.Value.Content
		switch param.Name {
		case "enableChatHistory":
			boolVal, ok := content.(bool)
			if !ok {
				return nil, fmt.Errorf("LLMParam %s: expected bool content, got %T", param.Name, content)
			}
			p.EnableChatHistory = boolVal
		case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
			// Recognized but deliberately ignored: these knobs have no
			// counterpart in crossmodel.LLMParams.
		case "temperature", "maxTokens", "responseFormat", "modleName", "modelType",
			"prompt", "systemPrompt", "topP":
			// All remaining parameters carry string content; check once.
			strVal, ok := content.(string)
			if !ok {
				return nil, fmt.Errorf("LLMParam %s: expected string content, got %T", param.Name, content)
			}
			switch param.Name {
			case "temperature":
				floatVal, err := strconv.ParseFloat(strVal, 64)
				if err != nil {
					return nil, err
				}
				p.Temperature = &floatVal
			case "maxTokens":
				intVal, err := strconv.Atoi(strVal)
				if err != nil {
					return nil, err
				}
				p.MaxTokens = intVal
			case "responseFormat":
				int64Val, err := strconv.ParseInt(strVal, 10, 64)
				if err != nil {
					return nil, err
				}
				p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
			case "modleName":
				// NOTE(review): "modleName" looks misspelled but presumably
				// matches the key in the canvas payload — verify against the
				// frontend before renaming.
				p.ModelName = strVal
			case "modelType":
				int64Val, err := strconv.ParseInt(strVal, 10, 64)
				if err != nil {
					return nil, err
				}
				p.ModelType = int64Val
			case "prompt":
				p.Prompt = strVal
			case "systemPrompt":
				p.SystemPrompt = strVal
			case "topP":
				floatVal, err := strconv.ParseFloat(strVal, 64)
				if err != nil {
					return nil, err
				}
				p.TopP = &floatVal
			}
		default:
			return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
		}
	}
	return p, nil
}
// simpleLLMParamsToLLMParams converts the backup-model settings stored on a
// node into a crossmodel.LLMParams. It never fails; the error return keeps
// the signature parallel with llmParamsToLLMParam.
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
	return &crossmodel.LLMParams{
		ModelName:      params.ModelName,
		ModelType:      params.ModelType,
		Temperature:    &params.Temperature,
		MaxTokens:      params.MaxTokens,
		TopP:           &params.TopP,
		ResponseFormat: params.ResponseFormat,
		SystemPrompt:   params.SystemPrompt,
	}, nil
}
// getReasoningContent returns the message's ReasoningContent field.
func getReasoningContent(message *schema.Message) string {
	return message.ReasoningContent
}
type Options struct {
nested []nodes.NestedWorkflowOption
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
var (
err error
chatModel, fallbackM model.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo ModelWithInfo
tools []tool.BaseTool
toolsReturnDirectly map[string]bool
knowledgeRecallConfig *KnowledgeRecallConfig
)
type Option func(o *Options)
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
return func(o *Options) {
o.nested = append(o.nested, nested...)
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
return func(o *Options) {
o.toolWorkflowSW = sw
exceptionConf := ns.ExceptionConfigs
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
backupModelParams := c.BackupLLMParams
if backupModelParams != nil {
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
}
}
type llmState = map[string]any
if fallbackM == nil {
modelWithInfo = NewModel(chatModel, info)
} else {
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
const agentModelName = "agent_model"
fcParams := c.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
tools = append(tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if toolsReturnDirectly == nil {
toolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
toolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
PluginEntity: plugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
tools = append(tools, inInvokableTools...)
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := knowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
knowledgeRecallConfig = &KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := toRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
RetrievalStrategy: &knowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
&knowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
}
}
func New(ctx context.Context, cfg *Config) (*LLM, error) {
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
return llmState{}
}))
var (
hasReasoning bool
canStream = true
)
var hasReasoning bool
format := cfg.OutputFormat
format := c.OutputFormat
if format == FormatJSON {
if len(cfg.OutputFields) == 1 {
for _, v := range cfg.OutputFields {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
format = FormatText
break
}
}
} else if len(cfg.OutputFields) == 2 {
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
for k, v := range cfg.OutputFields {
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
format = FormatText
break
@@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
}
}
userPrompt := cfg.UserPrompt
userPrompt := c.UserPrompt
switch format {
case FormatJSON:
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
if err != nil {
return nil, err
}
@@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
case FormatText:
}
if cfg.KnowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
if knowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
if err != nil {
return nil, err
}
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
inputs := maps.Clone(cfg.InputFields)
inputs := maps.Clone(ns.InputTypes)
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString,
}
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
template := newPrompts(sp, up, cfg.ChatModel)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
@@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else {
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
template := newPrompts(sp, up, cfg.ChatModel)
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template)
_ = g.AddEdge(compose.START, templateNodeKey)
}
if len(cfg.Tools) > 0 {
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
if len(tools) > 0 {
m, ok := modelWithInfo.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
reactConfig := react.AgentConfig{
ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName,
}
if len(cfg.ToolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
for k := range cfg.ToolsReturnDirectly {
if len(toolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
for k := range toolsReturnDirectly {
reactConfig.ToolReturnDirectly[k] = struct{}{}
}
}
@@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else {
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
}
_ = g.AddEdge(templateNodeKey, llmNodeKey)
if format == FormatJSON {
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
return jsonParse(ctx, msg.Content, cfg.OutputFields)
return jsonParse(ctx, msg.Content, ns.OutputTypes)
}
convertNode := compose.InvokableLambda(iConvert)
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
canStream = false
} else {
var outputKey string
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
panic("impossible")
}
for k, v := range cfg.OutputFields {
for k, v := range ns.OutputTypes {
if v.Type != vo.DataTypeString {
panic("impossible")
}
@@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(outputConvertNodeKey, compose.END)
requireCheckpoint := false
if len(cfg.Tools) > 0 {
if len(tools) > 0 {
requireCheckpoint = true
}
var opts []compose.GraphCompileOption
var compileOpts []compose.GraphCompileOption
if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
}
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
r, err := g.Compile(ctx, opts...)
r, err := g.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
@@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
llm := &LLM{
r: r,
outputFormat: format,
canStream: canStream,
requireCheckpoint: requireCheckpoint,
fullSources: cfg.FullSources,
fullSources: ns.FullSources,
}
return llm, nil
}
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
// RequireCheckpoint reports whether this LLM node needs a checkpoint store.
// Checkpointing is required only when function-call tools (workflows or
// plugins) are configured on the node.
func (c *Config) RequireCheckpoint() bool {
	fc := c.FCParam
	return fc != nil && (fc.WorkflowFCParam != nil || fc.PluginFCParam != nil)
}
// FieldStreamType reports whether a top-level output field of this LLM node
// produces a stream. A field streams only when the workflow requires
// streaming, the path has exactly one segment, and the node's outputs form
// the streamable shape: one non-reasoning string field, optionally
// accompanied by the reasoning field (all fields must be strings). In that
// shape, both the content field and the reasoning field stream.
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() || len(path) != 1 {
		return schema2.FieldNotStream, nil
	}

	outputs := ns.OutputTypes
	if len(outputs) != 1 && len(outputs) != 2 {
		return schema2.FieldNotStream, nil
	}

	// Validate the streamable shape and remember the single content key.
	contentKey := ""
	for key, output := range outputs {
		if output.Type != vo.DataTypeString {
			return schema2.FieldNotStream, nil
		}
		if key == ReasoningOutputKey {
			continue
		}
		if contentKey != "" {
			// Two non-reasoning fields: not the streamable shape.
			return schema2.FieldNotStream, nil
		}
		contentKey = key
	}

	if field := path[0]; field == ReasoningOutputKey || field == contentKey {
		return schema2.FieldIsStream, nil
	}
	return schema2.FieldNotStream, nil
}
// toRetrievalSearchType maps the knowledge-recall global setting's numeric
// search mode to the knowledge domain's SearchType:
// 0 = semantic, 1 = hybrid, 20 = full text. Other values are rejected.
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
	if s == 0 {
		return knowledge.SearchTypeSemantic, nil
	}
	if s == 1 {
		return knowledge.SearchTypeHybrid, nil
	}
	if s == 20 {
		return knowledge.SearchTypeFullText, nil
	}
	return "", fmt.Errorf("invalid retrieval search type %v", s)
}
// LLM is the executable LLM node. r is the compiled inner graph built in
// Build (prompt template -> chat model / react agent -> output conversion);
// outputFormat records how the model output is post-processed, and
// requireCheckpoint mirrors whether tools were attached at build time.
type LLM struct {
	r                 compose.Runnable[map[string]any, map[string]any]
	outputFormat      Format
	requireCheckpoint bool
	fullSources       map[string]*schema2.SourceInfo
}
// Context-cache key templates (formatted with the node key) under which the
// raw model output and the parse warning are stashed when JSON parsing of
// the model output fails; see jsonParse.
const (
	rawOutputKey = "llm_raw_output_%s"
	warningKey   = "llm_warning_%s"
)
// jsonParse extracts the JSON payload embedded in the model's raw text
// output, unmarshals it, and coerces the result to the node's declared
// output types. When unmarshalling fails inside a workflow execution, the
// failure is downgraded: the raw text and a warning are stashed in the
// context cache and an empty map is returned instead of an error. Outside
// an execution context the unmarshal error is returned as-is.
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
	cleaned := nodes.ExtractJSONString(data)

	var parsed map[string]any
	if err := sonic.UnmarshalString(cleaned, &parsed); err != nil {
		exeCtx := execute.GetExeCtx(ctx)
		if exeCtx == nil {
			return nil, err
		}
		logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, cleaned)
		nodeKey := exeCtx.NodeCtx.NodeKey
		ctxcache.Store(ctx, fmt.Sprintf(rawOutputKey, nodeKey), cleaned)
		ctxcache.Store(ctx, fmt.Sprintf(warningKey, nodeKey), vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
		return map[string]any{}, nil
	}

	converted, warnings, err := nodes.ConvertInputs(ctx, parsed, schema_)
	if err != nil {
		return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
	}
	if warnings != nil {
		logs.CtxWarnf(ctx, "convert inputs warnings: %v", *warnings)
	}
	return converted, nil
}
// llmOptions carries the LLM-node-specific call options, applied through
// nodes.WrapImplSpecificOptFn.
type llmOptions struct {
	// toolWorkflowSW, when set, receives the messages piped out of workflows
	// executed as tools during this LLM call (see prepare) — presumably for
	// surfacing intermediate tool output to the caller.
	toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
// WithToolWorkflowMessageWriter returns a node option that forwards messages
// emitted by tool workflows to the given stream writer.
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
	return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
		o.toolWorkflowSW = sw
	})
}
// llmState is the per-run local state of the inner LLM graph; it starts
// empty and is populated via state handlers.
type llmState = map[string]any

// agentModelName names the chat-model node inside the react agent graph.
const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx)
if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent
@@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
}
llmOpts := &Options{}
for _, opt := range opts {
opt(llmOpts)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
nestedOpts := &nodes.NestedWorkflowOptions{}
for _, opt := range llmOpts.nested {
opt(nestedOpts)
}
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
composeOpts = append(composeOpts, options.GetOptsForNested()...)
if resumingEvent != nil {
var (
@@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
}
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt)
@@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
return compose.NewInterruptAndRerunErr(ie)
}
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
return out, nil
}
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
modelPredictionIDs := strings.Split(input.Content, ",")
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
return strconv.Itoa(int(e.ID)), e.ID
})
recallKnowledgeIDs := make([]int64, 0)
@@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
return make(map[string]any), nil
}
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
Query: userPrompt,
KnowledgeIDs: recallKnowledgeIDs,
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,

View File

@@ -26,6 +26,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
}
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*nodes.SourceInfo,
sources map[string]*schema2.SourceInfo,
supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 {
@@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
}
sk := fmt.Sprintf(sourceKey, nodeKey)
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
if !ok {
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
}

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
package loop
package _break
import (
"context"
@@ -22,21 +22,36 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// Break implements the loop-break node: when invoked it raises the break
// flag in the enclosing loop's intermediate-variable store.
type Break struct {
	// parentIntermediateStore is the variable store shared with the enclosing
	// loop node; Invoke sets BreakKey to true in it.
	parentIntermediateStore variable.Store
}
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
// Config is the (empty) configuration for the Break node; it also serves as
// the node's adaptor (see Adapt) and builder.
type Config struct{}
// Adapt converts the frontend canvas node n into the backend NodeSchema for
// a Break node. Break carries no inputs or outputs of its own, so only the
// identity fields are populated.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	return &schema.NodeSchema{
		Key: vo.NodeKey(n.ID),
		Type: entity.NodeTypeBreak,
		Name: n.Data.Meta.Title,
		Configs: c,
	}, nil
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Break{
parentIntermediateStore: store,
parentIntermediateStore: &nodes.ParentIntermediateStore{},
}, nil
}
// BreakKey is the reserved intermediate-variable name the enclosing loop
// inspects to decide whether to stop iterating.
const BreakKey = "$break"
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) {
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
if err != nil {
return nil, err

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package _continue
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// Continue implements the loop-continue node; its Invoke is a pure pass-through.
type Continue struct{}

// Config is the (empty) configuration for the Continue node; it also serves
// as the node's adaptor (see Adapt) and builder (see Build).
type Config struct{}

// Adapt converts the frontend canvas node n into the backend NodeSchema for
// a Continue node. Continue carries no inputs or outputs of its own.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	return &schema.NodeSchema{
		Key: vo.NodeKey(n.ID),
		Type: entity.NodeTypeContinue,
		Name: n.Data.Meta.Title,
		Configs: c,
	}, nil
}

// Build constructs the runnable Continue node; it needs no state.
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	return &Continue{}, nil
}

// Invoke passes its input through unchanged.
// NOTE(review): the continue semantics are presumably enforced by the loop's
// control flow rather than here — confirm.
func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
	return in, nil
}

View File

@@ -27,53 +27,150 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Loop struct {
config *Config
outputs map[string]*vo.FieldSource
outputVars map[string]string
inner compose.Runnable[map[string]any, map[string]any]
nodeKey vo.NodeKey
loopType Type
inputArrays []string
intermediateVars map[string]*vo.TypeInfo
}
type Config struct {
LoopNodeKey vo.NodeKey
LoopType Type
InputArrays []string
IntermediateVars map[string]*vo.TypeInfo
Outputs []*vo.FieldInfo
Inner compose.Runnable[map[string]any, map[string]any]
}
type Type string
const (
ByArray Type = "by_array"
ByIteration Type = "by_iteration"
Infinite Type = "infinite"
)
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
if conf == nil {
return nil, errors.New("config is nil")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
}
if conf.LoopType == ByArray {
if len(conf.InputArrays) == 0 {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLoop,
Name: n.Data.Meta.Title,
Configs: c,
}
loopType, err := toLoopType(n.Data.Inputs.LoopType)
if err != nil {
return nil, err
}
c.LoopType = loopType
intermediateVars := make(map[string]*vo.TypeInfo)
for _, param := range n.Data.Inputs.VariableParameters {
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return nil, err
}
intermediateVars[param.Name] = tInfo
ns.SetInputType(param.Name, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
c.IntermediateVars = intermediateVars
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
for _, fieldInfo := range ns.OutputSources {
if fieldInfo.Source.Ref != nil {
if len(fieldInfo.Source.Ref.FromPath) == 1 {
if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
}
}
}
}
loopCount := n.Data.Inputs.LoopCount
if loopCount != nil {
typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
if err != nil {
return nil, err
}
ns.SetInputType(Count, typeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
continue
}
c.InputArrays = append(c.InputArrays, key)
}
return ns, nil
}
// toLoopType converts the canvas-level LoopType enum into this package's
// loop Type. Unknown values yield an error.
func toLoopType(l vo.LoopType) (Type, error) {
	loopTypes := map[vo.LoopType]Type{
		vo.LoopTypeArray:    ByArray,
		vo.LoopTypeCount:    ByIteration,
		vo.LoopTypeInfinite: Infinite,
	}
	t, ok := loopTypes[l]
	if !ok {
		return "", fmt.Errorf("unsupported loop type: %s", l)
	}
	return t, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
if c.LoopType == ByArray {
if len(c.InputArrays) == 0 {
return nil, errors.New("input arrays is empty when loop type is ByArray")
}
}
loop := &Loop{
config: conf,
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
options := schema.GetBuildOptions(opts...)
if options.Inner == nil {
return nil, errors.New("inner workflow is required for Loop Node")
}
for _, info := range conf.Outputs {
loop := &Loop{
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
inputArrays: c.InputArrays,
nodeKey: ns.Key,
intermediateVars: c.IntermediateVars,
inner: options.Inner,
loopType: c.LoopType,
}
for _, info := range ns.OutputSources {
if len(info.Path) != 1 {
return nil, fmt.Errorf("invalid output path: %s", info.Path)
}
@@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
}
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
}
@@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return loop, nil
}
// Type enumerates the loop execution modes (see getMaxIter for how each
// bounds the iteration count).
type Type string

const (
	// ByArray iterates once per element of the input arrays.
	ByArray Type = "by_array"
	// ByIteration runs the number of iterations given by the Count input.
	ByIteration Type = "by_iteration"
	// Infinite iterates without an upper bound until broken out of.
	Infinite Type = "infinite"
)

const (
	// Count is the input field holding the iteration count for ByIteration loops.
	Count = "loopCount"
)
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) {
maxIter, err := l.getMaxIter(in)
if err != nil {
return nil, err
}
arrays := make(map[string][]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
arrays := make(map[string][]any, len(l.inputArrays))
for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
arrays[arrayKey] = a.([]any)
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
var (
existingCState *nodes.NestedWorkflowState
@@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
)
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
if e != nil {
return e
}
@@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
for k := range existingCState.IntermediateVars {
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
}
intermediateVars[BreakKey] = &hasBreak
intermediateVars[_break.BreakKey] = &hasBreak
} else {
output = make(map[string]any, len(l.outputs))
for k := range l.outputs {
output[k] = make([]any, 0)
}
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
for varKey := range l.config.IntermediateVars {
intermediateVars = make(map[string]*any, len(l.intermediateVars))
for varKey := range l.intermediateVars {
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
if !ok {
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
@@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
intermediateVars[varKey] = &v
}
intermediateVars[BreakKey] = &hasBreak
intermediateVars[_break.BreakKey] = &hasBreak
}
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)
getIthInput := func(i int) (map[string]any, map[string]any, error) {
input := make(map[string]any)
@@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
input[k] = v
}
input[string(l.config.LoopNodeKey)+"#index"] = int64(i)
input[string(l.nodeKey)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey := range arrays {
ele := arrays[arrayKey][i]
items[arrayKey] = ele
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey
currentKey := string(l.nodeKey) + "#" + arrayKey
// Recursively expand map[string]any elements
var expand func(prefix string, val interface{})
@@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
}
}
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
@@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: l.config.LoopNodeKey,
NodeKey: l.nodeKey,
NodeType: entity.NodeTypeLoop,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
return e
}
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
return setter.SetInterruptEvent(l.nodeKey, iEvent)
})
if err != nil {
return nil, err
}
fmt.Println("save interruptEvent in state within loop: ", iEvent)
fmt.Println("save composite info in state within loop: ", compState)
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
return setter.SaveNestedWorkflowState(l.nodeKey, compState)
})
if err != nil {
return nil, err
@@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
}
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
panic("impossible")
panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
}
for outputVarKey, intermediateVarKey := range l.outputVars {
@@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter := math.MaxInt
switch l.config.LoopType {
switch l.loopType {
case ByArray:
for _, arrayKey := range l.config.InputArrays {
for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter = int(iter.(int64))
case Infinite:
default:
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
}
return maxIter, nil
@@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any {
}
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
trimmed := make(map[string]any, len(l.inputArrays))
for _, arrayKey := range l.inputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}

View File

@@ -1,90 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type NestedWorkflowOptions struct {
optsForNested []compose.Option
toResumeIndexes map[int]compose.StateModifier
optsForIndexed map[int][]compose.Option
}
type NestedWorkflowOption func(*NestedWorkflowOptions)
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
o.optsForNested = append(o.optsForNested, opts...)
}
}
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
return c.optsForNested
}
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.toResumeIndexes == nil {
o.toResumeIndexes = map[int]compose.StateModifier{}
}
o.toResumeIndexes[i] = m
}
}
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
return c.toResumeIndexes
}
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.optsForIndexed == nil {
o.optsForIndexed = map[int][]compose.Option{}
}
o.optsForIndexed[index] = opts
}
}
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
if c.optsForIndexed == nil {
return nil
}
return c.optsForIndexed[index]
}
type NestedWorkflowState struct {
Index2Done map[int]bool `json:"index_2_done,omitempty"`
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
FullOutput map[string]any `json:"full_output,omitempty"`
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
func (c *NestedWorkflowState) String() string {
s, _ := sonic.MarshalIndent(c, "", " ")
return string(s)
}
type NestedWorkflowAware interface {
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
InterruptEventStore
}

View File

@@ -0,0 +1,194 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
einoschema "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// InvokableNode is a basic workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It does not accept any options.
// Most nodes implement this, such as NodeTypePlugin.
type InvokableNode interface {
	Invoke(ctx context.Context, input map[string]any) (
		output map[string]any, err error)
}

// InvokableNodeWOpt is a workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this.
type InvokableNodeWOpt interface {
	Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) (
		map[string]any, error)
}

// StreamableNode is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It does not accept any options.
// Currently NO Node implements this.
// A potential example would be streamable plugin for NodeTypePlugin.
type StreamableNode interface {
	Stream(ctx context.Context, in map[string]any) (
		*einoschema.StreamReader[map[string]any], error)
}

// StreamableNodeWOpt is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM implements this.
type StreamableNodeWOpt interface {
	Stream(ctx context.Context, in map[string]any, opts ...NodeOption) (
		*einoschema.StreamReader[map[string]any], error)
}

// CollectableNode is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It does not accept any options.
// Currently NO Node implements this.
// A potential example would be a new condition node that makes decisions
// based on streaming input.
type CollectableNode interface {
	Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
		map[string]any, error)
}

// CollectableNodeWOpt is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It accepts NodeOption.
// Currently NO Node implements this.
// A potential example would be a new batch node that accepts streaming input,
// process them, and finally returns non-stream aggregation of results.
type CollectableNodeWOpt interface {
	Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
		map[string]any, error)
}

// TransformableNode is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It does not accept any options.
// e.g.
// NodeTypeVariableAggregator implements TransformableNode.
type TransformableNode interface {
	Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
		*einoschema.StreamReader[map[string]any], error)
}

// TransformableNodeWOpt is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It accepts NodeOption.
// Currently NO Node implements this.
// A potential example would be an audio processing node that
// transforms input audio clips, but within the node is a graph
// composed by Eino, and the audio processing node needs to carry
// options for this inner graph.
type TransformableNodeWOpt interface {
	Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
		*einoschema.StreamReader[map[string]any], error)
}
// CallbackInputConverted converts node input to a form better suited for UI.
// The converted input will be displayed on canvas when test run,
// and will be returned when querying the node's input through OpenAPI.
type CallbackInputConverted interface {
	ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error)
}

// CallbackOutputConverted converts node output to a form better suited for UI.
// The converted output will be displayed on canvas when test run,
// and will be returned when querying the node's output through OpenAPI.
type CallbackOutputConverted interface {
	ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error)
}

// Initializer lets a node run setup before execution; Init may return a
// derived context to be used for the rest of the run.
type Initializer interface {
	Init(ctx context.Context) (context.Context, error)
}
// AdaptOptions carries optional inputs for NodeAdaptor.Adapt.
type AdaptOptions struct {
	// Canvas is the full frontend canvas the node belongs to, for adaptors
	// that need context beyond the single node.
	Canvas *vo.Canvas
}

// AdaptOption mutates an AdaptOptions.
type AdaptOption func(*AdaptOptions)

// WithCanvas supplies the enclosing canvas to Adapt.
func WithCanvas(canvas *vo.Canvas) AdaptOption {
	return func(opts *AdaptOptions) {
		opts.Canvas = canvas
	}
}

// GetAdaptOptions folds opts into a fresh AdaptOptions.
func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions {
	options := &AdaptOptions{}
	for _, opt := range opts {
		opt(options)
	}
	return options
}
// NodeAdaptor provides conversion from frontend Node to backend NodeSchema.
type NodeAdaptor interface {
	Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) (
		*schema.NodeSchema, error)
}

// BranchAdaptor provides validation and conversion from frontend port to backend port.
type BranchAdaptor interface {
	ExpectPorts(ctx context.Context, n *vo.Node) []string
}

// Registries mapping node types to adaptor factories.
// NOTE(review): plain maps with no locking — this assumes all Register* calls
// happen during package init, before any concurrent Get* — confirm.
var (
	nodeAdaptors = map[entity.NodeType]func() NodeAdaptor{}
	branchAdaptors = map[entity.NodeType]func() BranchAdaptor{}
)

// RegisterNodeAdaptor registers the NodeAdaptor factory for node type et.
func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) {
	nodeAdaptors[et] = f
}

// GetNodeAdaptor returns a fresh NodeAdaptor for et.
// Unlike GetBranchAdaptor, a missing registration panics (treated as a
// programmer error), so the returned bool is always true.
func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) {
	na, ok := nodeAdaptors[et]
	if !ok {
		panic(fmt.Sprintf("node type %s not registered", et))
	}
	return na(), ok
}

// RegisterBranchAdaptor registers the BranchAdaptor factory for node type et.
func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) {
	branchAdaptors[et] = f
}

// GetBranchAdaptor returns a fresh BranchAdaptor for et, or (nil, false)
// when the node type has no branch adaptor registered.
func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) {
	na, ok := branchAdaptors[et]
	if !ok {
		return nil, false
	}
	return na(), ok
}

// StreamGenerator reports the stream type (stream vs. non-stream) that a
// node produces for a given output field path.
type StreamGenerator interface {
	FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
		sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
}

View File

@@ -0,0 +1,170 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NodeOptions are the common (non-implementation-specific) options a node
// can receive at execution time; see GetCommonOptions.
type NodeOptions struct {
	// Nested carries options destined for a node's nested workflow
	// (e.g. the inner workflow of a loop or batch node).
	Nested *NestedWorkflowOptions
}

// NestedWorkflowOptions configures execution of a nested workflow.
type NestedWorkflowOptions struct {
	// optsForNested apply to every run of the nested workflow.
	optsForNested []compose.Option
	// toResumeIndexes maps an iteration index to the state modifier used to
	// resume that iteration after an interrupt.
	toResumeIndexes map[int]compose.StateModifier
	// optsForIndexed apply only to the run at a specific iteration index.
	optsForIndexed map[int][]compose.Option
}

// NodeOption is a functional option for node execution. At most one of the
// two fields is set: apply for common options, implSpecificOptFn for
// implementation-specific options (see WrapImplSpecificOptFn).
type NodeOption struct {
	apply func(opts *NodeOptions)
	implSpecificOptFn any
}

// NestedWorkflowOption mutates NestedWorkflowOptions directly.
// NOTE(review): nothing in this file produces one anymore (constructors now
// return NodeOption) — possibly a leftover; confirm external usage before removing.
type NestedWorkflowOption func(*NestedWorkflowOptions)
func WithOptsForNested(opts ...compose.Option) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
options.Nested.optsForNested = append(options.Nested.optsForNested, opts...)
},
}
}
func (c *NodeOptions) GetOptsForNested() []compose.Option {
if c.Nested == nil {
return nil
}
return c.Nested.optsForNested
}
func WithResumeIndex(i int, m compose.StateModifier) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
if options.Nested.toResumeIndexes == nil {
options.Nested.toResumeIndexes = map[int]compose.StateModifier{}
}
options.Nested.toResumeIndexes[i] = m
},
}
}
func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier {
if c.Nested == nil {
return nil
}
return c.Nested.toResumeIndexes
}
func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
if options.Nested.optsForIndexed == nil {
options.Nested.optsForIndexed = map[int][]compose.Option{}
}
options.Nested.optsForIndexed[index] = opts
},
}
}
func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option {
if c.Nested == nil {
return nil
}
return c.Nested.optsForIndexed[index]
}
// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption {
	return NodeOption{
		implSpecificOptFn: optFn,
	}
}

// GetCommonOptions extracts the common NodeOptions from the NodeOption list,
// optionally providing a base NodeOptions with default values.
func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions {
	if base == nil {
		base = &NodeOptions{}
	}
	for i := range opts {
		opt := opts[i]
		if opt.apply != nil {
			opt.apply(base)
		}
	}
	return base
}

// GetImplSpecificOptions extracts the implementation specific options from the NodeOption list,
// optionally providing a base options with default values. Options whose
// wrapped function targets a different concrete type are silently skipped.
// e.g.
//
//	myOption := &MyOption{
//		Field1: "default_value",
//	}
//
//	myOption := nodes.GetImplSpecificOptions(myOption, opts...)
func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T {
	if base == nil {
		base = new(T)
	}
	for i := range opts {
		opt := opts[i]
		if opt.implSpecificOptFn != nil {
			optFn, ok := opt.implSpecificOptFn.(func(*T))
			if ok {
				optFn(base)
			}
		}
	}
	return base
}
type NestedWorkflowState struct {
Index2Done map[int]bool `json:"index_2_done,omitempty"`
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
FullOutput map[string]any `json:"full_output,omitempty"`
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
// String renders the state as indented JSON, primarily for logging and
// debugging; marshal errors are deliberately ignored.
func (c *NestedWorkflowState) String() string {
	data, _ := sonic.MarshalIndent(c, "", " ")
	return string(data)
}
// NestedWorkflowAware is implemented by workflow state containers that can
// persist and restore nested-workflow progress across interrupts.
type NestedWorkflowAware interface {
	// SaveNestedWorkflowState stores the nested state for the given node key.
	SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
	// GetNestedWorkflowState returns the stored state for the node key and
	// whether one exists.
	GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
	InterruptEventStore
}

View File

@@ -18,16 +18,21 @@ package plugin
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@@ -35,29 +40,76 @@ type Config struct {
PluginID int64
ToolID int64
PluginVersion string
}
PluginService plugin.Service
// Adapt converts the canvas plugin node definition into a NodeSchema,
// extracting the pluginID, apiID (tool ID) and pluginVersion API params
// into the Config, then wiring the node's input sources and output types.
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypePlugin,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	inputs := n.Data.Inputs

	apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
		return e.Name, e
	})

	ps, ok := apiParams["pluginID"]
	if !ok {
		return nil, fmt.Errorf("plugin id param is not found")
	}
	// BUG FIX: the unchecked assertion could panic, and the ParseInt error
	// was previously ignored, silently leaving PluginID as zero.
	pIDStr, ok := ps.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("plugin id param is not a string: %v", ps.Input.Value.Content)
	}
	pID, err := strconv.ParseInt(pIDStr, 10, 64)
	if err != nil {
		return nil, err
	}
	c.PluginID = pID

	ps, ok = apiParams["apiID"]
	if !ok {
		// BUG FIX: message previously said "plugin id param is not found",
		// copy-pasted from the pluginID lookup above.
		return nil, fmt.Errorf("api id param is not found")
	}
	tIDStr, ok := ps.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("api id param is not a string: %v", ps.Input.Value.Content)
	}
	tID, err := strconv.ParseInt(tIDStr, 10, 64)
	if err != nil {
		return nil, err
	}
	c.ToolID = tID

	ps, ok = apiParams["pluginVersion"]
	if !ok {
		return nil, fmt.Errorf("plugin version param is not found")
	}
	version, ok := ps.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("plugin version param is not a string: %v", ps.Input.Value.Content)
	}
	c.PluginVersion = version

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build constructs the runnable Plugin node from the adapted Config,
// resolving the plugin service at build time.
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	p := &Plugin{
		pluginID:      c.PluginID,
		toolID:        c.ToolID,
		pluginVersion: c.PluginVersion,
		pluginService: plugin.GetPluginService(),
	}
	return p, nil
}
type Plugin struct {
config *Config
}
pluginID int64
toolID int64
pluginVersion string
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.PluginID == 0 {
return nil, errors.New("plugin id is required")
}
if cfg.ToolID == 0 {
return nil, errors.New("tool id is required")
}
if cfg.PluginService == nil {
return nil, errors.New("tool service is required")
}
return &Plugin{config: cfg}, nil
pluginService plugin.Service
}
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
@@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
exeCfg = ctxExeCfg.ExeCfg
}
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.config.PluginID,
PluginVersion: ptr.Of(p.config.PluginVersion),
}, p.config.ToolID, exeCfg)
result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.pluginID,
PluginVersion: ptr.Of(p.pluginVersion),
}, p.toolID, exeCfg)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now

View File

@@ -29,9 +29,12 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -39,8 +42,21 @@ import (
)
type QuestionAnswer struct {
config *Config
model model.BaseChatModel
nodeMeta entity.NodeTypeMeta
questionTpl string
answerType AnswerType
choiceType ChoiceType
fixedChoices []string
needExtractFromAnswer bool
additionalSystemPromptTpl string
maxAnswerCount int
nodeKey vo.NodeKey
outputFields map[string]*vo.TypeInfo
}
type Config struct {
@@ -51,15 +67,249 @@ type Config struct {
FixedChoices []string
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
Model model.BaseChatModel
LLMParams *crossmodel.LLMParams
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
ExtractFromAnswer bool
AdditionalSystemPromptTpl string
MaxAnswerCount int
OutputFields map[string]*vo.TypeInfo
}
NodeKey vo.NodeKey
// Adapt converts the canvas Question-Answer node definition into a
// NodeSchema, populating the Config with the question template, answer/choice
// types, LLM params and (for dynamic choices) the node's input sources.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeQuestionAnswer,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	qaConf := n.Data.Inputs.QA
	if qaConf == nil {
		return nil, fmt.Errorf("qa config is nil")
	}

	c.QuestionTpl = qaConf.Question

	// The canvas LLM param is an opaque structure; round-trip it through JSON
	// to obtain a typed SimpleLLMParam before converting to the cross-domain form.
	var llmParams *crossmodel.LLMParams
	if n.Data.Inputs.LLMParam != nil {
		llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
		if err != nil {
			return nil, err
		}
		var qaLLMParams vo.SimpleLLMParam
		err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
		if err != nil {
			return nil, err
		}
		llmParams, err = convertLLMParams(qaLLMParams)
		if err != nil {
			return nil, err
		}
		c.LLMParams = llmParams
	}

	answerType, err := convertAnswerType(qaConf.AnswerType)
	if err != nil {
		return nil, err
	}
	c.AnswerType = answerType

	var choiceType ChoiceType
	if len(qaConf.OptionType) > 0 {
		choiceType, err = convertChoiceType(qaConf.OptionType)
		if err != nil {
			return nil, err
		}
		c.ChoiceType = choiceType
	}

	if answerType == AnswerByChoices {
		switch choiceType {
		case FixedChoices:
			// Fixed choices are taken verbatim from the canvas options.
			var options []string
			for _, option := range qaConf.Options {
				options = append(options, option.Name)
			}
			c.FixedChoices = options
		case DynamicChoices:
			// Dynamic choices arrive at runtime via an input field mapping.
			inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			ns.AddInputSource(inputSources...)
			inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
			if err != nil {
				return nil, err
			}
			ns.SetInputType(DynamicChoicesKey, inputTypes)
		default:
			return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
		}
	} else if answerType == AnswerDirectly {
		c.ExtractFromAnswer = qaConf.ExtractOutput
		if qaConf.ExtractOutput {
			// Structured extraction requires a model; output types define the
			// JSON schema the model's answer is extracted into.
			if llmParams == nil {
				return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
			}
			c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
			c.MaxAnswerCount = qaConf.Limit
			if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
				return nil, err
			}
		}
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// convertLLMParams maps the canvas-level SimpleLLMParam onto the cross-domain
// LLMParams structure. Temperature and TopP are passed by address of the
// local copy, matching the optional-pointer fields on the target.
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
	return &crossmodel.LLMParams{
		ModelName:      params.ModelName,
		ModelType:      params.ModelType,
		Temperature:    &params.Temperature,
		MaxTokens:      params.MaxTokens,
		TopP:           &params.TopP,
		ResponseFormat: params.ResponseFormat,
		SystemPrompt:   params.SystemPrompt,
	}, nil
}
// convertAnswerType maps the canvas QAAnswerType to the node-level AnswerType.
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
	if t == vo.QAAnswerTypeOption {
		return AnswerByChoices, nil
	}
	if t == vo.QAAnswerTypeText {
		return AnswerDirectly, nil
	}
	return "", fmt.Errorf("invalid QAAnswerType: %s", t)
}
// convertChoiceType maps the canvas QAOptionType to the node-level ChoiceType.
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
	if t == vo.QAOptionTypeStatic {
		return FixedChoices, nil
	}
	if t == vo.QAOptionTypeDynamic {
		return DynamicChoices, nil
	}
	return "", fmt.Errorf("invalid QAOptionType: %s", t)
}
// Build validates the adapted Config and constructs the runnable
// QuestionAnswer node, resolving the chat model from LLMParams when
// structured extraction or intent detection is needed.
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
	if c.AnswerType == AnswerDirectly {
		if c.ExtractFromAnswer {
			if c.LLMParams == nil {
				return nil, errors.New("model is required when extract from answer")
			}
			if len(ns.OutputTypes) == 0 {
				return nil, errors.New("output fields is required when extract from answer")
			}
		}
	} else if c.AnswerType == AnswerByChoices {
		if c.ChoiceType == FixedChoices {
			if len(c.FixedChoices) == 0 {
				// BUG FIX: message previously said "when extract from answer",
				// copy-pasted from the AnswerDirectly branch above.
				return nil, errors.New("fixed choices are required when answering by choices")
			}
		}
	} else {
		return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
	}

	nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
	if nodeMeta == nil {
		return nil, errors.New("node meta not found for question answer")
	}

	// The model is optional: only resolved when LLMParams were configured
	// (extraction or dynamic-choice intent detection).
	var (
		m   model.BaseChatModel
		err error
	)
	if c.LLMParams != nil {
		m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
		if err != nil {
			return nil, err
		}
	}

	return &QuestionAnswer{
		model:                     m,
		nodeMeta:                  *nodeMeta,
		questionTpl:               c.QuestionTpl,
		answerType:                c.AnswerType,
		choiceType:                c.ChoiceType,
		fixedChoices:              c.FixedChoices,
		needExtractFromAnswer:     c.ExtractFromAnswer,
		additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
		maxAnswerCount:            c.MaxAnswerCount,
		nodeKey:                   ns.Key,
		outputFields:              ns.OutputTypes,
	}, nil
}
// BuildBranch returns a branch-selection function for choice-based QA nodes,
// and reports whether this node has branches at all. The returned function
// maps the node output's option ID to a branch index: (-1, true) selects the
// default ("other") branch, otherwise the index of the chosen option
// (always 0 for dynamic choices).
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	if c.AnswerType != AnswerByChoices {
		return nil, false
	}

	return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		optionID, ok := nodeOutput[OptionIDKey]
		if !ok {
			return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
		}

		// BUG FIX: previously unchecked optionID.(string) assertions would
		// panic on a non-string value; now return an error instead.
		optionIDStr, ok := optionID.(string)
		if !ok {
			return -1, false, fmt.Errorf("option id is not a string: %v", optionID)
		}

		// "other" always routes to the default branch, for both choice types.
		if optionIDStr == "other" {
			return -1, true, nil
		}

		// Dynamic choices have a single non-default branch.
		if c.ChoiceType == DynamicChoices {
			return 0, false, nil
		}

		optionIDInt, ok := AlphabetToInt(optionIDStr)
		if !ok {
			return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID)
		}

		return optionIDInt, false, nil
	}, true
}
// ExpectPorts lists the outgoing ports the canvas is expected to wire for
// this QA node: one branch port per static option (or a single branch port
// for dynamic options) plus the default port. Non-option nodes have none.
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
	qa := n.Data.Inputs.QA
	if qa.AnswerType != vo.QAAnswerTypeOption {
		return expects
	}
	switch qa.OptionType {
	case vo.QAOptionTypeStatic:
		for index := range qa.Options {
			expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
		}
		expects = append(expects, schema2.PortDefault)
	case vo.QAOptionTypeDynamic:
		expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
		expects = append(expects, schema2.PortDefault)
	}
	return expects
}
// RequireCheckpoint reports that the QA node always needs checkpointing,
// since it interrupts execution to wait for the user's answer and must be
// resumable afterwards.
func (c *Config) RequireCheckpoint() bool {
	return true
}
type AnswerType string
@@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
)
// NewQuestionAnswer validates conf and constructs a QuestionAnswer node.
// Direct-answer extraction requires a model and output fields; fixed-choice
// answering requires a non-empty choice list.
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
	if conf == nil {
		return nil, errors.New("config is nil")
	}
	if conf.AnswerType == AnswerDirectly {
		if conf.ExtractFromAnswer {
			if conf.Model == nil {
				return nil, errors.New("model is required when extract from answer")
			}
			if len(conf.OutputFields) == 0 {
				return nil, errors.New("output fields is required when extract from answer")
			}
		}
	} else if conf.AnswerType == AnswerByChoices {
		if conf.ChoiceType == FixedChoices {
			if len(conf.FixedChoices) == 0 {
				// NOTE(review): message appears copy-pasted from the
				// extract-from-answer checks above; likely should mention choices.
				return nil, errors.New("fixed choices is required when extract from answer")
			}
		}
	} else {
		return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
	}
	nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
	if nodeMeta == nil {
		return nil, errors.New("node meta not found for question answer")
	}
	return &QuestionAnswer{
		config:   conf,
		nodeMeta: *nodeMeta,
	}, nil
}
type Question struct {
Question string
Choices []string
@@ -182,10 +397,10 @@ type message struct {
ID string `json:"id,omitempty"`
}
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
// Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
var (
questions []*Question
answers []string
@@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
out[QuestionsKey] = questions
out[AnswersKey] = answers
switch q.config.AnswerType {
switch q.answerType {
case AnswerDirectly:
if isFirst { // first execution, ask the question
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
@@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
}
if q.config.ExtractFromAnswer {
if q.needExtractFromAnswer {
return q.extractFromAnswer(ctx, in, questions, answers)
}
@@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
}
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
var formattedChoices []string
switch q.config.ChoiceType {
switch q.choiceType {
case FixedChoices:
for _, choice := range q.config.FixedChoices {
for _, choice := range q.fixedChoices {
formattedChoice, err := nodes.TemplateRender(choice, in)
if err != nil {
return nil, err
@@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
formattedChoices = append(formattedChoices, c)
}
default:
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
}
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
default:
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
}
}
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
fieldInfo := "FieldInfo"
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
if err != nil {
return nil, err
}
@@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
var requiredFields []string
for fName, tInfo := range q.config.OutputFields {
for fName, tInfo := range q.outputFields {
if tInfo.Required {
requiredFields = append(requiredFields, fName)
}
}
var formattedAdditionalPrompt string
if len(q.config.AdditionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
if len(q.additionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
if err != nil {
return nil, err
}
@@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
messages = append(messages, schema.UserMessage(answer))
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return nil, err
}
@@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
if ok {
nextQuestionStr, ok := nextQuestion.(string)
if ok && len(nextQuestionStr) > 0 {
if len(answers) >= q.config.MaxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
if len(answers) >= q.maxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
}
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
@@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
return nil, fmt.Errorf("field %s not found", fieldInfo)
}
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
@@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
schema.UserMessage(answer),
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return -1, err
}
@@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
event := &entity.InterruptEvent{
ID: eventID,
NodeKey: q.config.NodeKey,
NodeKey: q.nodeKey,
NodeType: entity.NodeTypeQuestionAnswer,
NodeTitle: q.nodeMeta.Name,
NodeIcon: q.nodeMeta.IconURL,
@@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
}
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
setter.AddQuestion(q.config.NodeKey, &Question{
setter.AddQuestion(q.nodeKey, &Question{
Question: newQuestion,
Choices: choices,
})
@@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
return ""
}
func AlphabetToInt(str string) (int, bool) {
func AlphabetToInt(str string) (int64, bool) {
if len(str) != 1 {
return 0, false
}
char := rune(str[0])
char = unicode.ToUpper(char)
if char >= 'A' && char <= 'Z' {
return int(char - 'A'), true
return int64(char - 'A'), true
}
return 0, false
}
@@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
for i := 0; i < len(oldQuestions); i++ {
oldQuestion := oldQuestions[i]
oldAnswer := oldAnswers[i]
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
questionMsg := &message{
Type: "question",
ContentType: contentType,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
}
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
questionMsg.Content = optionContent{
Options: conv(oldQuestion.Choices),
Question: oldQuestion.Question,
@@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Type: "answer",
ContentType: contentType,
Content: oldAnswer,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
}
history = append(history, questionMsg, answerMsg)
}
if newQuestion != nil {
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
history = append(history, &message{
Type: "question",
ContentType: "option",
@@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Options: conv(choices),
Question: *newQuestion,
},
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
} else {
history = append(history, &message{
Type: "question",
ContentType: "text",
Content: *newQuestion,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
}
}

View File

@@ -27,8 +27,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -37,19 +39,27 @@ import (
)
type Config struct {
OutputTypes map[string]*vo.TypeInfo
NodeKey vo.NodeKey
OutputSchema string
}
type InputReceiver struct {
outputTypes map[string]*vo.TypeInfo
interruptData string
nodeKey vo.NodeKey
nodeMeta entity.NodeTypeMeta
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
c.OutputSchema = n.Data.Inputs.OutputSchema
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeInputReceiver,
Name: n.Data.Meta.Title,
Configs: c,
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
if nodeMeta == nil {
return nil, errors.New("node meta not found for input receiver")
@@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
interruptData := map[string]string{
"content_type": "form_schema",
"content": cfg.OutputSchema,
"content": c.OutputSchema,
}
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
@@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
}
return &InputReceiver{
outputTypes: cfg.OutputTypes,
outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution
nodeMeta: *nodeMeta,
nodeKey: cfg.NodeKey,
nodeKey: ns.Key,
interruptData: interruptDataStr,
}, nil
}
func (c *Config) RequireCheckpoint() bool {
return true
}
type InputReceiver struct {
outputTypes map[string]*vo.TypeInfo
interruptData string
nodeKey vo.NodeKey
nodeMeta entity.NodeTypeMeta
}
const (
ReceivedDataKey = "$received_data"
receiverWarningKey = "receiver_warning_%d_%s"

View File

@@ -0,0 +1,190 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
// selectorCallbackField describes one side (left or right operand) of a
// selector condition as reported to execution callbacks: the source key it
// was resolved from, its declared type, and the resolved runtime value.
type selectorCallbackField struct {
	Key   string      `json:"key"`
	Type  vo.DataType `json:"type"`
	Value any         `json:"value"`
}

// selectorCondition is a single operator application (left op right) within
// a branch; Right is omitted for unary operators.
type selectorCondition struct {
	Left     selectorCallbackField  `json:"left"`
	Operator vo.OperatorType        `json:"operator"`
	Right    *selectorCallbackField `json:"right,omitempty"`
}

// selectorBranch groups the conditions of one selector branch together with
// the logic (AND/OR) that combines them.
type selectorBranch struct {
	Conditions []*selectorCondition `json:"conditions"`
	Logic      vo.LogicType         `json:"logic"`
	Name       string               `json:"name"`
}
// ToCallbackInput converts the selector's resolved input values into the
// branch/condition structure reported to execution callbacks. It walks the
// node's input sources, whose field paths encode either
// [branchIndex, Left|Right] for single-condition branches or
// [branchIndex, clauseIndex, Left|Right] for multi-condition branches, and
// lazily materializes one selectorBranch per branch index.
func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
	count := len(s.clauses)
	output := make([]*selectorBranch, count)
	for _, source := range s.ns.InputSources {
		targetPath := source.Path
		if len(targetPath) == 2 {
			// Single-condition branch: path is [branchIndex, Left|Right].
			indexStr := targetPath[0]
			index, err := strconv.Atoi(indexStr)
			if err != nil {
				return nil, err
			}
			branch := output[index]
			if branch == nil {
				// First field seen for this branch: create it with its operator
				// and an implicit AND logic (only one condition exists).
				output[index] = &selectorBranch{
					Conditions: []*selectorCondition{
						{
							Operator: s.clauses[index].Single.ToCanvasOperatorType(),
						},
					},
					Logic: ClauseRelationAND.ToVOLogicType(),
				}
			}
			if targetPath[1] == LeftKey {
				leftV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take left value of %s", targetPath)
				}
				if source.Source.Ref.VariableType != nil {
					if *source.Source.Ref.VariableType == vo.ParentIntermediate {
						// Value refers to a parent composite node's intermediate
						// variable; key it by the parent node's name.
						parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
						if !ok {
							return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
						}
						parentNode := s.ws.GetNode(parentNodeKey)
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					} else {
						// Other variable kinds carry no node-qualified key.
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   "",
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					}
				} else {
					// Plain node-output reference: key by source node name + path.
					output[index].Conditions[0].Left = selectorCallbackField{
						Key:   s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
						Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
						Value: leftV,
					}
				}
			} else if targetPath[1] == RightKey {
				rightV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take right value of %s", targetPath)
				}
				output[index].Conditions[0].Right = &selectorCallbackField{
					Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
					Value: rightV,
				}
			}
		} else if len(targetPath) == 3 {
			// Multi-condition branch: path is [branchIndex, clauseIndex, Left|Right].
			indexStr := targetPath[0]
			index, err := strconv.Atoi(indexStr)
			if err != nil {
				return nil, err
			}
			multi := s.clauses[index].Multi
			branch := output[index]
			if branch == nil {
				output[index] = &selectorBranch{
					Conditions: make([]*selectorCondition, len(multi.Clauses)),
					Logic:      multi.Relation.ToVOLogicType(),
				}
			}
			clauseIndexStr := targetPath[1]
			clauseIndex, err := strconv.Atoi(clauseIndexStr)
			if err != nil {
				return nil, err
			}
			clause := multi.Clauses[clauseIndex]
			if output[index].Conditions[clauseIndex] == nil {
				output[index].Conditions[clauseIndex] = &selectorCondition{
					Operator: clause.ToCanvasOperatorType(),
				}
			}
			if targetPath[2] == LeftKey {
				leftV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take left value of %s", targetPath)
				}
				if source.Source.Ref.VariableType != nil {
					if *source.Source.Ref.VariableType == vo.ParentIntermediate {
						parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
						if !ok {
							return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
						}
						parentNode := s.ws.GetNode(parentNodeKey)
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					} else {
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   "",
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					}
				} else {
					output[index].Conditions[clauseIndex].Left = selectorCallbackField{
						Key:   s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
						Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
						Value: leftV,
					}
				}
			} else if targetPath[2] == RightKey {
				rightV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take right value of %s", targetPath)
				}
				output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
					Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
					Value: rightV,
				}
			}
		}
	}
	return map[string]any{"branches": output}, nil
}

View File

@@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
panic(fmt.Sprintf("unknown operator: %+v", o))
}
}
// ToSelectorOperator maps a canvas operator type onto the selector's internal
// Operator. leftType disambiguates containment checks: on objects they become
// key-containment operators.
func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) {
	var op Operator
	switch o {
	case vo.Equal:
		op = OperatorEqual
	case vo.NotEqual:
		op = OperatorNotEqual
	case vo.LengthGreaterThan:
		op = OperatorLengthGreater
	case vo.LengthGreaterThanEqual:
		op = OperatorLengthGreaterOrEqual
	case vo.LengthLessThan:
		op = OperatorLengthLesser
	case vo.LengthLessThanEqual:
		op = OperatorLengthLesserOrEqual
	case vo.Contain:
		op = OperatorContain
		if leftType.Type == vo.DataTypeObject {
			op = OperatorContainKey
		}
	case vo.NotContain:
		op = OperatorNotContain
		if leftType.Type == vo.DataTypeObject {
			op = OperatorNotContainKey
		}
	case vo.Empty:
		op = OperatorEmpty
	case vo.NotEmpty:
		op = OperatorNotEmpty
	case vo.True:
		op = OperatorIsTrue
	case vo.False:
		op = OperatorIsFalse
	case vo.GreaterThan:
		op = OperatorGreater
	case vo.GreaterThanEqual:
		op = OperatorGreaterOrEqual
	case vo.LessThan:
		op = OperatorLesser
	case vo.LessThanEqual:
		op = OperatorLesserOrEqual
	default:
		return "", fmt.Errorf("unsupported operator type: %d", o)
	}
	return op, nil
}

View File

@@ -17,9 +17,16 @@
package selector
import (
"context"
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type ClauseRelation string
@@ -29,10 +36,6 @@ const (
ClauseRelationOR ClauseRelation = "or"
)
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
type OneClauseSchema struct {
Single *Operator `json:"single,omitempty"`
Multi *MultiClauseSchema `json:"multi,omitempty"`
@@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType {
panic(fmt.Sprintf("unknown clause relation: %s", c))
}
// Adapt converts the canvas selector node definition into a NodeSchema.
// Each canvas branch becomes one OneClauseSchema: a Single operator when the
// branch has exactly one condition, otherwise a Multi clause combining the
// conditions with the branch's AND/OR logic. Input types and field sources
// are registered under paths keyed by branch index (and clause index).
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	clauses := make([]*OneClauseSchema, 0)

	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Name:    n.Data.Meta.Title,
		Type:    entity.NodeTypeSelector,
		Configs: c,
	}

	for i, branchCond := range n.Data.Inputs.Branches {
		// One object-typed input per branch, with Left/Right (or per-clause
		// sub-objects) as its properties.
		inputType := &vo.TypeInfo{
			Type:       vo.DataTypeObject,
			Properties: map[string]*vo.TypeInfo{},
		}

		if len(branchCond.Condition.Conditions) == 1 { // single condition
			cond := branchCond.Condition.Conditions[0]
			left := cond.Left
			if left == nil {
				return nil, fmt.Errorf("operator left is nil")
			}
			leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
			if err != nil {
				return nil, err
			}
			leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			inputType.Properties[LeftKey] = leftType
			ns.AddInputSource(leftSources...)
			// The left operand's type disambiguates the operator mapping
			// (e.g. object containment vs. list containment).
			op, err := ToSelectorOperator(cond.Operator, leftType)
			if err != nil {
				return nil, err
			}
			// Right operand is optional: unary operators have none.
			if cond.Right != nil {
				rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
				if err != nil {
					return nil, err
				}
				rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent())
				if err != nil {
					return nil, err
				}
				inputType.Properties[RightKey] = rightType
				ns.AddInputSource(rightSources...)
			}
			ns.SetInputType(fmt.Sprintf("%d", i), inputType)
			clauses = append(clauses, &OneClauseSchema{
				Single: &op,
			})
			continue
		}

		var relation ClauseRelation
		logic := branchCond.Condition.Logic
		if logic == vo.OR {
			relation = ClauseRelationOR
		} else if logic == vo.AND {
			relation = ClauseRelationAND
		}

		var ops []*Operator
		for j, cond := range branchCond.Condition.Conditions {
			left := cond.Left
			if left == nil {
				return nil, fmt.Errorf("operator left is nil")
			}
			leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
			if err != nil {
				return nil, err
			}
			leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			// Each clause gets its own object-typed slot under the branch input.
			inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					LeftKey: leftType,
				},
			}
			ns.AddInputSource(leftSources...)
			op, err := ToSelectorOperator(cond.Operator, leftType)
			if err != nil {
				return nil, err
			}
			ops = append(ops, &op)
			if cond.Right != nil {
				rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
				if err != nil {
					return nil, err
				}
				rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent())
				if err != nil {
					return nil, err
				}
				inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType
				ns.AddInputSource(rightSources...)
			}
		}
		ns.SetInputType(fmt.Sprintf("%d", i), inputType)
		clauses = append(clauses, &OneClauseSchema{
			Multi: &MultiClauseSchema{
				Clauses:  ops,
				Relation: relation,
			},
		})
	}

	c.Clauses = clauses

	return ns, nil
}

View File

@@ -23,23 +23,32 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Selector struct {
config *Config
clauses []*OneClauseSchema
ns *schema.NodeSchema
ws *schema.WorkflowSchema
}
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
if config == nil {
return nil, fmt.Errorf("config is nil")
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
ws := schema.GetBuildOptions(opts...).WS
if ws == nil {
return nil, fmt.Errorf("workflow schema is required")
}
if len(config.Clauses) == 0 {
if len(c.Clauses) == 0 {
return nil, fmt.Errorf("config clauses are empty")
}
for _, clause := range config.Clauses {
for _, clause := range c.Clauses {
if clause.Single == nil && clause.Multi == nil {
return nil, fmt.Errorf("single clause and multi clause are both nil")
}
@@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) {
}
return &Selector{
config: config,
clauses: c.Clauses,
ns: ns,
ws: ws,
}, nil
}
// BuildBranch returns the branch-selection function for the selector node and
// reports (via the second result) that this node produces a branch at all.
//
// The returned function maps the node's output (the chosen clause index under
// SelectKey) to a port: index i selects condition port i, while index
// len(c.Clauses) means no condition matched and the default ("else") port is
// taken (signalled by the bool result).
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		choice := nodeOutput[SelectKey].(int64)
		// Valid choices are 0..len(Clauses): one per clause plus the default.
		// The previous bound (len(Clauses)+1) wrongly accepted len(Clauses)+1
		// and would have returned it as a non-existent port.
		if choice < 0 || choice > int64(len(c.Clauses)) {
			return -1, false, fmt.Errorf("selector choice out of range: %d", choice)
		}
		if choice == int64(len(c.Clauses)) { // default
			return -1, true, nil
		}
		return choice, false, nil
	}, true
}
// ExpectPorts lists the outgoing port names the selector node must expose:
// one "false" default port, then "true" for the first declared branch and
// "true_<i>" for every subsequent branch i.
func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string {
	branchCount := len(n.Data.Inputs.Branches)
	ports := make([]string, 0, branchCount+1)
	ports = append(ports, "false") // default branch
	for idx := 0; idx < branchCount; idx++ {
		if idx == 0 {
			ports = append(ports, "true") // first condition
			continue
		}
		ports = append(ports, "true_"+strconv.Itoa(idx)) // other conditions
	}
	return ports
}
type Operants struct {
Left any
Right any
@@ -76,14 +117,14 @@ const (
SelectKey = "selected"
)
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.SelectorInputConverter(input)
func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.selectorInputConverter(input)
if err != nil {
return nil, err
}
predicates := make([]Predicate, 0, len(s.config.Clauses))
for i, oneConf := range s.config.Clauses {
predicates := make([]Predicate, 0, len(s.clauses))
for i, oneConf := range s.clauses {
if oneConf.Single != nil {
left := in[i].Left
right := in[i].Right
@@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri
}
if isTrue {
return map[string]any{SelectKey: i}, nil
return map[string]any{SelectKey: int64(i)}, nil
}
}
return map[string]any{SelectKey: len(in)}, nil // default choice
return map[string]any{SelectKey: int64(len(in))}, nil // default choice
}
func (s *Selector) GetType() string {
return "Selector"
}
func (s *Selector) ConditionCount() int {
return len(s.config.Clauses)
}
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.config.Clauses
func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.clauses
for i, oneConf := range conf {
if oneConf.Single != nil {
@@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er
}
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
count := len(s.config.Clauses)
out := output[SelectKey].(int)
count := int64(len(s.clauses))
out := output[SelectKey].(int64)
if out == count {
cOutput := map[string]any{"result": "pass to else branch"}
return &nodes.StructuredCallbackOutput{

View File

@@ -22,57 +22,27 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
var KeyIsFinished = "\x1FKey is finished\x1F"
type Mode string
const (
Streaming Mode = "streaming"
NonStreaming Mode = "non-streaming"
)
type FieldStreamType string
const (
FieldIsStream FieldStreamType = "yes" // absolutely a stream
FieldNotStream FieldStreamType = "no" // absolutely not a stream
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
)
// SourceInfo contains stream type for a input field source of a node.
type SourceInfo struct {
	// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
	IsIntermediate bool
	// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
	FieldType FieldStreamType
	// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
	FromNodeKey vo.NodeKey
	// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
	FromPath compose.FieldPath
	// TypeInfo is the declared type of the field's value.
	TypeInfo *vo.TypeInfo
	// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
	SubSources map[string]*SourceInfo
}
type DynamicStreamContainer interface {
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error)
}
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
resolved := make(map[string]*SourceInfo, len(sources))
func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) {
resolved := make(map[string]*schema.SourceInfo, len(sources))
nodeKey2Skipped := make(map[vo.NodeKey]bool)
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
resolvedNode := &SourceInfo{
var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error)
resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) {
resolvedNode := &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: sInfo.FieldType,
FromNodeKey: sInfo.FromNodeKey,
@@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
}
if len(sInfo.SubSources) > 0 {
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))
resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources))
for k, subInfo := range sInfo.SubSources {
resolvedSub, err := resolver(k, subInfo)
@@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
}
if skipped {
resolvedNode.FieldType = FieldSkipped
resolvedNode.FieldType = schema.FieldSkipped
return resolvedNode, nil
}
if sInfo.FieldType == FieldMaybeStream {
if sInfo.FieldType == schema.FieldMaybeStream {
if len(sInfo.SubSources) > 0 {
panic("a maybe stream field should not have sub sources")
}
var streamType FieldStreamType
var streamType schema.FieldStreamType
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
var e error
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
@@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
return nil, err
}
return &SourceInfo{
return &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: streamType,
FromNodeKey: sInfo.FromNodeKey,
@@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface {
NodeExecuted(key vo.NodeKey) bool
}
func (s *SourceInfo) Skipped() bool {
if !s.IsIntermediate {
return s.FieldType == FieldSkipped
func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath,
sc *schema.WorkflowSchema) (schema.FieldStreamType, error) {
sg, ok := s.Configs.(StreamGenerator)
if !ok {
return schema.FieldNotStream, nil
}
for _, sub := range s.SubSources {
if !sub.Skipped() {
return false
}
}
return true
}
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
if !s.IsIntermediate {
return s.FromNodeKey == nodeKey
}
for _, sub := range s.SubSources {
if sub.FromNode(nodeKey) {
return true
}
}
return false
return sg.FieldStreamType(path, s, sc)
}

View File

@@ -18,7 +18,6 @@ package subworkflow
import (
"context"
"errors"
"fmt"
"strconv"
@@ -29,35 +28,56 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// Config holds what a sub-workflow node needs at runtime: the compiled inner
// workflow and the identity of the workflow it executes.
type Config struct {
	// Runner is the compiled inner workflow invoked/streamed by this node.
	Runner compose.Runnable[map[string]any, map[string]any]

	// WorkflowID identifies the referenced workflow.
	WorkflowID int64
	// WorkflowVersion is the referenced workflow's version string.
	// NOTE(review): presumably empty means "latest/draft" — confirm with callers.
	WorkflowVersion string
}
// FieldStreamType decides whether the sub-workflow node's output field at path
// is produced as a stream. It streams only when the surrounding workflow
// requires streaming, the inner workflow itself streams, the inner exit node
// uses streaming answer output (not ReturnVariables), and the requested field
// is exactly "output".
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() {
		return schema2.FieldNotStream, nil
	}

	innerWF := ns.SubWorkflowSchema
	if !innerWF.RequireStreaming() {
		return schema2.FieldNotStream, nil
	}

	innerExit := innerWF.GetNode(entity.ExitNodeKey)
	if innerExit.Configs.(*exit.Config).TerminatePlan == vo.ReturnVariables {
		return schema2.FieldNotStream, nil
	}

	if !innerExit.StreamConfigs.RequireStreamingInput {
		return schema2.FieldNotStream, nil
	}

	// Fix: the previous check (len(path) > 1) panicked on an empty path when
	// indexing path[0]. An empty path is equally invalid here, so require the
	// path to be exactly the single segment "output".
	if len(path) != 1 || path[0] != "output" {
		return schema2.FieldNotStream, fmt.Errorf(
			"streaming answering sub-workflow node can only have out field 'output'")
	}

	return schema2.FieldIsStream, nil
}
type SubWorkflow struct {
cfg *Config
Runner compose.Runnable[map[string]any, map[string]any]
}
// NewSubWorkflow validates cfg and wraps it in a SubWorkflow node executor.
// It fails when the config or its compiled runner is missing.
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
	switch {
	case cfg == nil:
		return nil, errors.New("config is nil")
	case cfg.Runner == nil:
		return nil, errors.New("runnable is nil")
	}
	return &SubWorkflow{cfg: cfg}, nil
}
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
out, err := s.Runner.Invoke(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
@@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod
return out, nil
}
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
out, err := s.Runner.Stream(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
@@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod
return out, nil
}
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) {
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
nestedOpts := options.GetOptsForNested()

View File

@@ -30,6 +30,7 @@ import (
"github.com/bytedance/sonic/ast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@@ -156,7 +157,7 @@ func removeSlice(s string) string {
type renderOptions struct {
type2CustomRenderer map[reflect.Type]func(any) (string, error)
reservedKey map[string]struct{}
reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping
nilRenderer func() (string, error)
}
@@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
}
}
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
return false, false
}
@@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
for _, subPath := range tp.SubPathsBeforeSlice {
@@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
return false, true
}
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
matchingSource = subSource
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
var checkSourceSkipped func(sInfo *SourceInfo) bool
checkSourceSkipped = func(sInfo *SourceInfo) bool {
var checkSourceSkipped func(sInfo *schema.SourceInfo) bool
checkSourceSkipped = func(sInfo *schema.SourceInfo) bool {
if !sInfo.IsIntermediate {
return sInfo.FieldType == FieldSkipped
return sInfo.FieldType == schema.FieldSkipped
}
for _, subSource := range sInfo.SubSources {
if !checkSourceSkipped(subSource) {
@@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
return currentType
}
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
mi, err := sonic.Marshal(input)
if err != nil {
return "", err

View File

@@ -22,7 +22,11 @@ import (
"reflect"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -34,42 +38,92 @@ const (
)
type Config struct {
Type Type `json:"type"`
Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"`
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
Type Type `json:"type"`
Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"`
}
type TextProcessor struct {
config *Config
}
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
if cfg == nil {
return nil, fmt.Errorf("config requried")
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeTextProcessor,
Name: n.Data.Meta.Title,
Configs: c,
}
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
if n.Data.Inputs.Method == vo.Concat {
c.Type = ConcatText
params := n.Data.Inputs.ConcatParams
for _, param := range params {
if param.Name == "concatResult" {
c.Tpl = param.Input.Value.Content.(string)
} else if param.Name == "arrayItemConcatChar" {
c.ConcatChar = param.Input.Value.Content.(string)
}
}
} else if n.Data.Inputs.Method == vo.Split {
c.Type = SplitText
params := n.Data.Inputs.SplitParams
separators := make([]string, 0, len(params))
for _, param := range params {
if param.Name == "delimiters" {
delimiters := param.Input.Value.Content.([]any)
for _, d := range delimiters {
separators = append(separators, d.(string))
}
}
}
c.Separators = separators
} else {
return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method)
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
// Build validates the adapted configuration and constructs the runtime
// TextProcessor executor. A concat node must carry a non-empty template.
//
// Fixes: removed a stray leftover field initializer ("config: cfg,") from the
// old constructor that broke the composite literal, and corrected the
// "requried" typo in the error message.
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	if c.Type == ConcatText && len(c.Tpl) == 0 {
		return nil, fmt.Errorf("config tpl required")
	}
	return &TextProcessor{
		typ:         c.Type,
		tpl:         c.Tpl,
		concatChar:  c.ConcatChar,
		separators:  c.Separators,
		fullSources: ns.FullSources,
	}, nil
}
// TextProcessor is the runtime executor for a text-processor node, built by
// Config.Build. It either renders a concat template or splits a string by a
// list of separators, depending on typ.
type TextProcessor struct {
	typ        Type     // ConcatText or SplitText
	tpl        string   // concat template (ConcatText only)
	concatChar string   // join character for array-valued inputs (ConcatText only)
	separators []string // delimiters applied in order (SplitText only)
	// fullSources describes the stream/skip state of each input field source,
	// consulted when rendering the template.
	fullSources map[string]*schema.SourceInfo
}
const OutputKey = "output"
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
switch t.config.Type {
switch t.typ {
case ConcatText:
arrayRenderer := func(i any) (string, error) {
vs := i.([]any)
return join(vs, t.config.ConcatChar)
return join(vs, t.concatChar)
}
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
result, err := nodes.Render(ctx, t.tpl, input, t.fullSources,
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
if err != nil {
return nil, err
@@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
if !ok {
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
}
values := strings.Split(valueString, t.config.Separators[0])
values := strings.Split(valueString, t.separators[0])
// Iterate over each delimiter
for _, sep := range t.config.Separators[1:] {
for _, sep := range t.separators[1:] {
var tempParts []string
for _, part := range values {
tempParts = append(tempParts, strings.Split(part, sep)...)
@@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
return map[string]any{OutputKey: anyValues}, nil
default:
return nil, fmt.Errorf("not support type %s", t.config.Type)
return nil, fmt.Errorf("not support type %s", t.typ)
}
}

View File

@@ -21,6 +21,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNewTextProcessorNodeGenerator(t *testing.T) {
@@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
Type: SplitText,
Separators: []string{",", "|", "."},
}
p, err := NewTextProcessor(ctx, cfg)
p, err := cfg.Build(ctx, &schema2.NodeSchema{})
assert.NoError(t, err)
result, err := p.Invoke(ctx, map[string]any{
result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{
"String": "a,b|c.d,e|f|g",
})
@@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
ConcatChar: `\t`,
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
}
p, err := NewTextProcessor(context.Background(), cfg)
p, err := cfg.Build(context.Background(), &schema2.NodeSchema{})
result, err := p.Invoke(ctx, in)
result, err := p.(*TextProcessor).Invoke(ctx, in)
assert.NoError(t, err)
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
})

View File

@@ -32,8 +32,11 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -48,24 +51,147 @@ const (
type Config struct {
MergeStrategy MergeStrategy
GroupLen map[string]int
FullSources map[string]*nodes.SourceInfo
NodeKey vo.NodeKey
InputSources []*vo.FieldInfo
GroupOrder []string // the order the groups are declared in frontend canvas
}
type VariableAggregator struct {
config *Config
// Adapt converts the canvas variable-aggregator node n into a NodeSchema.
// Every merge group becomes an object input type whose properties are the
// group's variables keyed by their index; field sources are registered for
// each variable. Group sizes and canvas declaration order are recorded on c.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAggregator,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	// The only merge strategy the canvas currently produces.
	c.MergeStrategy = FirstNotNullValue

	mergeGroups := n.Data.Inputs.VariableAggregator.MergeGroups
	groupToLen := make(map[string]int, len(mergeGroups))
	groupOrder := make([]string, 0, len(mergeGroups))
	for gi := range mergeGroups {
		group := mergeGroups[gi]
		groupType := &vo.TypeInfo{
			Type:       vo.DataTypeObject,
			Properties: make(map[string]*vo.TypeInfo),
		}
		ns.SetInputType(group.Name, groupType)
		for vi, variable := range group.Variables {
			fieldName := strconv.Itoa(vi)
			fieldType, err := convert.CanvasBlockInputToTypeInfo(variable)
			if err != nil {
				return nil, err
			}
			groupType.Properties[fieldName] = fieldType
			sources, err := convert.CanvasBlockInputToFieldInfo(variable, compose.FieldPath{group.Name, fieldName}, n.Parent())
			if err != nil {
				return nil, err
			}
			ns.AddInputSource(sources...)
		}
		groupToLen[group.Name] = len(group.Variables)
		groupOrder = append(groupOrder, group.Name)
	}
	c.GroupLen = groupToLen
	c.GroupOrder = groupOrder

	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if c.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
}
if cfg.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
return &VariableAggregator{
groupLen: c.GroupLen,
fullSources: ns.FullSources,
nodeKey: ns.Key,
}, nil
}
// FieldStreamType decides whether the variable aggregator's output at path is
// a stream. A 2-segment path ({group, index}) resolves the single matching
// input source; a 1-segment path ({group}) aggregates over the group's
// sources: all-stream => stream, none-stream => not stream, otherwise the
// answer depends on runtime choice (maybe-stream).
//
// Fix: removed a stray leftover line from the removed constructor
// ("return &VariableAggregator{config: cfg}, nil") that sat after the first
// guard and made the rest of the function unreachable/uncompilable.
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() {
		return schema2.FieldNotStream, nil
	}

	if len(path) == 2 { // asking about a specific index within a group
		for _, fInfo := range ns.InputSources {
			if len(fInfo.Path) == len(path) {
				equal := true
				for i := range fInfo.Path {
					if fInfo.Path[i] != path[i] {
						equal = false
						break
					}
				}
				if equal {
					if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
						return schema2.FieldNotStream, nil // variables or static values
					}
					fromNodeKey := fInfo.Source.Ref.FromNodeKey
					fromNode := sc.GetNode(fromNodeKey)
					if fromNode == nil {
						return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
					}
					// Delegate to the producing node's own stream typing.
					return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
				}
			}
		}
		// NOTE(review): a 2-segment path with no matching source falls through
		// to the generic error below — confirm that is the intended behavior.
	} else if len(path) == 1 { // asking about the entire group
		var streamCount, notStreamCount int
		for _, fInfo := range ns.InputSources {
			if fInfo.Path[0] == path[0] { // belong to the group
				if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
					fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
					if fromNode == nil {
						return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
					}
					subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
					if err != nil {
						return schema2.FieldNotStream, err
					}
					if subStreamType == schema2.FieldMaybeStream {
						// One undecidable source makes the whole group undecidable.
						return schema2.FieldMaybeStream, nil
					} else if subStreamType == schema2.FieldIsStream {
						streamCount++
					} else {
						notStreamCount++
					}
				}
			}
		}
		if streamCount > 0 && notStreamCount == 0 {
			return schema2.FieldIsStream, nil
		}
		if streamCount == 0 && notStreamCount > 0 {
			return schema2.FieldNotStream, nil
		}
		return schema2.FieldMaybeStream, nil
	}

	return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
}
// VariableAggregator is the runtime executor for a variable-aggregator node,
// built by Config.Build. For each merge group it picks the first non-nil
// variable and records the chosen index as a dynamic choice in workflow state.
type VariableAggregator struct {
	groupLen map[string]int // number of variables declared in each group
	// fullSources describes the stream/skip state of every input field source,
	// resolved at request time before reading inputs.
	fullSources map[string]*schema2.SourceInfo
	nodeKey     vo.NodeKey // this node's key, used when saving dynamic choices
	groupOrder  []string   // the order the groups are declared in frontend canvas
}
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
@@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
result := make(map[string]any)
groupToChoice := make(map[string]int)
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
for i := 0; i < length; i++ {
if value, ok := in[group][i]; ok {
if value != nil {
@@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
_ *schema.StreamReader[map[string]any], err error) {
inStream := streamInputConverter(input)
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolvesSources from ctx cache.")
}
@@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
defer func() {
if err == nil {
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
for group, choice := range groupToChoice {
if choice != -1 {
item := groupToItems[group][choice]
if _, ok := item.(stream); ok {
groupChoiceToStreamType[group] = nodes.FieldIsStream
groupChoiceToStreamType[group] = schema2.FieldIsStream
}
}
}
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
// - if an element is not stream, actually receive from the stream to check if it's non-nil
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
groupToItems[group] = make([]any, length)
groupToCurrentIndex[group] = math.MaxInt
for i := 0; i < length; i++ {
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
if fType == nodes.FieldSkipped {
if fType == schema2.FieldSkipped {
groupToItems[group][i] = skipped{}
continue
}
if fType == nodes.FieldIsStream {
if fType == schema2.FieldIsStream {
groupToItems[group][i] = stream{}
if ci, _ := groupToCurrentIndex[group]; i < ci {
groupToCurrentIndex[group] = i
@@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
allDone := func() bool {
for group := range v.config.GroupLen {
for group := range v.groupLen {
_, ok := groupToChoice[group]
if !ok {
return false
@@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
alreadyDone := allDone()
if alreadyDone { // all groups have made their choices, no need to actually read input streams
result := make(map[string]any, len(v.config.GroupLen))
result := make(map[string]any, len(v.groupLen))
allSkip := true
for group := range groupToChoice {
choice := groupToChoice[group]
@@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
if allSkip { // no need to convert input streams for the output, because all groups are skipped
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
return schema.StreamReaderFromArray([]map[string]any{result}), nil
@@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
@@ -416,26 +542,12 @@ type vaCallbackInput struct {
Variables []any `json:"variables"`
}
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
ctx = ctxcache.Init(ctx)
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}
type streamMarkerType string
const streamMarker streamMarkerType = "<Stream Data...>"
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolved_sources from ctx cache")
}
@@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
merged := make([]vaCallbackInput, 0, len(in))
groupLen := v.config.GroupLen
groupLen := v.groupLen
for groupName, vars := range in {
orderedVars := make([]any, groupLen[groupName])
for index := range vars {
orderedVars[index] = vars[index]
if len(resolvedSources) > 0 {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
// replace the streams with streamMarker,
// because we won't read, save to execution history, or display these streams to user
orderedVars[index] = streamMarker
@@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
}
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
if !ok {
panic("unable to get dynamic stream types from ctx cache")
}
@@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
newOut := maps.Clone(output)
for k := range output {
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
newOut[k] = streamMarker
}
}
@@ -594,3 +706,15 @@ func init() {
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
}
// Init resolves this node's stream sources once, up front, and stashes the
// result in the context cache so it is available before execution begins.
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
	sources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
	if err != nil {
		return nil, err
	}
	// callbacks.OnStart needs this info, so it is cached here within Init().
	ctxcache.Store(ctx, resolvedSourcesCacheKey, sources)
	return ctx, nil
}

View File

@@ -25,29 +25,75 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type VariableAssigner struct {
config *Config
pairs []*Pair
handler *variable.Handler
}
type Config struct {
Pairs []*Pair
Handler *variable.Handler
Pairs []*Pair
}
type Pair struct {
Left vo.Reference
Right compose.FieldPath
// Adapt converts a canvas variable-assigner node definition into a NodeSchema.
// For each input parameter it validates the left (assignment target) reference,
// registers the right-hand value as an input source on the schema, and records
// the resulting Left/Right pairs on the config for the later Build step.
//
// Returns an error when a parameter is missing its left/input halves, when the
// left side does not resolve to a single variable reference, or when the target
// is the read-only GlobalSystem variable space.
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAssigner,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	pairs := make([]*Pair, 0, len(n.Data.Inputs.InputParameters))
	for i, param := range n.Data.Inputs.InputParameters {
		if param.Left == nil || param.Input == nil {
			return nil, fmt.Errorf("variable assigner node's param left or input is nil")
		}

		// The synthetic "left_%d" path is only a placeholder key; what matters
		// is the reference the left side resolves to.
		leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
		if err != nil {
			return nil, err
		}

		// Guard before indexing leftSources[0]; mirrors the identical check in
		// InLoopConfig.Adapt and avoids a panic on an empty/multi result.
		if len(leftSources) != 1 {
			return nil, fmt.Errorf("variable assigner node's param left is not a single source")
		}

		if leftSources[0].Source.Ref == nil {
			return nil, fmt.Errorf("variable assigner node's param left source ref is nil")
		}

		if leftSources[0].Source.Ref.VariableType == nil {
			return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil")
		}

		if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem {
			return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem")
		}

		// The right-hand value is mapped into the node input under the left
		// reference's FromPath, so Invoke can look it up by that path.
		inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent())
		if err != nil {
			return nil, err
		}

		ns.AddInputSource(inputSource...)

		pair := &Pair{
			Left:  *leftSources[0].Source.Ref,
			Right: inputSource[0].Path,
		}

		pairs = append(pairs, pair)
	}

	c.Pairs = pairs

	return ns, nil
}
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
for _, pair := range conf.Pairs {
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
for _, pair := range c.Pairs {
if pair.Left.VariableType == nil {
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
}
@@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er
}
return &VariableAssigner{
config: conf,
pairs: c.Pairs,
handler: variable.GetVariableHandler(),
}, nil
}
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.config.Pairs {
// Pair binds one assignment target to the location of its value.
type Pair struct {
	// Left references the destination variable to write to (its VariableType
	// selects the variable space and FromPath the variable's path).
	Left vo.Reference
	// Right is the path within the node's input map where the value to assign
	// is found at runtime.
	Right compose.FieldPath
}
func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.pairs {
right, ok := nodes.TakeMapValue(in, pair.Right)
if !ok {
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
@@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
ConnectorUID: exeCfg.ConnectorUID,
}))
}
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
}

View File

@@ -20,25 +20,93 @@ import (
"context"
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type InLoop struct {
config *Config
intermediateVarStore variable.Store
type InLoopConfig struct {
Pairs []*Pair
}
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
// Adapt converts a canvas "set variable within loop" node into a NodeSchema.
// Each parameter's left side must reference a parent-loop intermediate
// variable; the right side is registered as an input source and the resulting
// Left/Right pairs are stored on the config for the later Build step.
//
// The receiver is named c (not i) so the loop index below cannot shadow it.
func (c *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	// This node type only makes sense nested inside a loop node.
	if n.Parent() == nil {
		return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID)
	}

	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAssignerWithinLoop,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	var pairs []*Pair
	for i, param := range n.Data.Inputs.InputParameters {
		if param.Left == nil || param.Right == nil {
			return nil, fmt.Errorf("loop set variable node's param left or right is nil")
		}

		// The synthetic "left_%d" path is only a placeholder key; what matters
		// is the reference the left side resolves to.
		leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
		if err != nil {
			return nil, err
		}

		if len(leftSources) != 1 {
			return nil, fmt.Errorf("loop set variable node's param left is not a single source")
		}

		if leftSources[0].Source.Ref == nil {
			return nil, fmt.Errorf("loop set variable node's param left's ref is nil")
		}

		if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate {
			return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate")
		}

		rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent())
		if err != nil {
			return nil, err
		}

		// Validate before mutating the schema so a bad parameter cannot leave
		// a partially-registered input source behind.
		if len(rightSources) != 1 {
			return nil, fmt.Errorf("loop set variable node's param right is not a single source")
		}

		ns.AddInputSource(rightSources...)

		pair := &Pair{
			Left:  *leftSources[0].Source.Ref,
			Right: rightSources[0].Path,
		}

		pairs = append(pairs, pair)
	}

	c.Pairs = pairs

	return ns, nil
}
func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &InLoop{
config: conf,
pairs: i.Pairs,
intermediateVarStore: &nodes.ParentIntermediateStore{},
}, nil
}
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.config.Pairs {
// InLoop assigns values to intermediate variables owned by the parent loop node.
type InLoop struct {
	// pairs maps each assignment target (Left) to the input path (Right)
	// holding the value to assign.
	pairs []*Pair
	// intermediateVarStore is the store backing the parent loop's
	// intermediate variables; writes in Invoke go through it.
	intermediateVarStore variable.Store
}
func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.pairs {
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
}

View File

@@ -37,36 +37,34 @@ func TestVariableAssigner(t *testing.T) {
arrVar := any([]any{1, "2"})
va := &InLoop{
config: &Config{
Pairs: []*Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"int_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"int_var_t"},
pairs: []*Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"int_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"str_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"str_var_t"},
Right: compose.FieldPath{"int_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"str_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"obj_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"obj_var_t"},
Right: compose.FieldPath{"str_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"obj_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"arr_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"arr_var_t"},
Right: compose.FieldPath{"obj_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"arr_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"arr_var_t"},
},
},
intermediateVarStore: &nodes.ParentIntermediateStore{},
@@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) {
"arr_var_s": &arrVar,
}, nil)
_, err := va.Assign(ctx, map[string]any{
_, err := va.Invoke(ctx, map[string]any{
"int_var_t": 2,
"str_var_t": "str2",
"obj_var_t": map[string]any{