refactor: how to add a node type in workflow (#558)
This commit is contained in:
@@ -30,50 +30,108 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type Batch struct {
|
||||
config *Config
|
||||
outputs map[string]*vo.FieldSource
|
||||
outputs map[string]*vo.FieldSource
|
||||
innerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
key vo.NodeKey
|
||||
inputArrays []string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
|
||||
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
type Config struct{}
|
||||
|
||||
InputArrays []string `json:"input_arrays"`
|
||||
Outputs []*vo.FieldInfo `json:"outputs"`
|
||||
}
|
||||
|
||||
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if len(config.InputArrays) == 0 {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeBatch,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
|
||||
compose.FieldPath{MaxBatchSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(batchSizeField...)
|
||||
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
|
||||
compose.FieldPath{ConcurrentSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(concurrentSizeField...)
|
||||
|
||||
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
|
||||
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
var inputArrays []string
|
||||
for key, tInfo := range ns.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
inputArrays = append(inputArrays, key)
|
||||
}
|
||||
|
||||
if len(inputArrays) == 0 {
|
||||
return nil, errors.New("need to have at least one incoming array for batch")
|
||||
}
|
||||
|
||||
if len(config.Outputs) == 0 {
|
||||
if len(ns.OutputSources) == 0 {
|
||||
return nil, errors.New("need to have at least one output variable for batch")
|
||||
}
|
||||
|
||||
b := &Batch{
|
||||
config: config,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
bo := schema.GetBuildOptions(opts...)
|
||||
if bo.Inner == nil {
|
||||
return nil, errors.New("need to have inner workflow for batch")
|
||||
}
|
||||
|
||||
for i := range config.Outputs {
|
||||
source := config.Outputs[i]
|
||||
b := &Batch{
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
innerWorkflow: bo.Inner,
|
||||
key: ns.Key,
|
||||
inputArrays: inputArrays,
|
||||
}
|
||||
|
||||
for i := range ns.OutputSources {
|
||||
source := ns.OutputSources[i]
|
||||
path := source.Path
|
||||
if len(path) != 1 {
|
||||
return nil, fmt.Errorf("invalid path %q", path)
|
||||
}
|
||||
|
||||
// from which inner node's which field does the batch's output fields come from
|
||||
b.outputs[path[0]] = &source.Source
|
||||
}
|
||||
|
||||
@@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
|
||||
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
|
||||
out map[string]any, err error) {
|
||||
arrays := make(map[string]any, len(b.config.InputArrays))
|
||||
arrays := make(map[string]any, len(b.inputArrays))
|
||||
minLen := math.MaxInt64
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
@@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
}
|
||||
}
|
||||
|
||||
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
|
||||
input[string(b.key)+"#index"] = int64(i)
|
||||
|
||||
items := make(map[string]any)
|
||||
for arrayKey, array := range arrays {
|
||||
ele := reflect.ValueOf(array).Index(i).Interface()
|
||||
items[arrayKey] = []any{ele}
|
||||
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
|
||||
currentKey := string(b.key) + "#" + arrayKey
|
||||
|
||||
// Recursively expand map[string]any elements
|
||||
var expand func(prefix string, val interface{})
|
||||
@@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
return nil
|
||||
}
|
||||
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
var existingCState *nodes.NestedWorkflowState
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
|
||||
var e error
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.key)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
@@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
mu.Unlock()
|
||||
if subCheckpointID != "" {
|
||||
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
|
||||
i, b.config.BatchNodeKey, subCheckpointID)
|
||||
i, b.key, subCheckpointID)
|
||||
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
|
||||
}
|
||||
|
||||
@@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
|
||||
// the output then needs to be concatenated.
|
||||
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
if err != nil {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
@@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
|
||||
}
|
||||
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
|
||||
return setter.SetInterruptEvent(b.key, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
// although this invocation does not have new interruptions,
|
||||
// this batch node previously have interrupts yet to be resumed.
|
||||
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: existingCState.Index2InterruptInfo,
|
||||
})
|
||||
@@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
|
||||
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
|
||||
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
|
||||
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
|
||||
}
|
||||
|
||||
@@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
||||
}
|
||||
|
||||
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(b.config.InputArrays))
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
trimmed := make(map[string]any, len(b.inputArrays))
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
|
||||
@@ -25,6 +25,10 @@ import (
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
@@ -113,50 +117,77 @@ var pythonThirdPartyWhitelist = map[string]struct{}{
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Code string
|
||||
Language coderunner.Language
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
Runner coderunner.Runner
|
||||
Code string
|
||||
Language coderunner.Language
|
||||
|
||||
Runner coderunner.Runner
|
||||
}
|
||||
|
||||
type CodeRunner struct {
|
||||
config *Config
|
||||
importError error
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeCodeRunner,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
code := inputs.Code
|
||||
c.Code = code
|
||||
|
||||
language, err := convertCodeLanguage(inputs.Language)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Language = language
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
func convertCodeLanguage(l int64) (coderunner.Language, error) {
|
||||
switch l {
|
||||
case 5:
|
||||
return coderunner.JavaScript, nil
|
||||
case 3:
|
||||
return coderunner.Python, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid language: %d", l)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Language == "" {
|
||||
return nil, errors.New("language is required")
|
||||
}
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
|
||||
if cfg.Code == "" {
|
||||
return nil, errors.New("code is required")
|
||||
}
|
||||
|
||||
if cfg.Language != coderunner.Python {
|
||||
if c.Language != coderunner.Python {
|
||||
return nil, errors.New("only support python language")
|
||||
}
|
||||
|
||||
if len(cfg.OutputConfig) == 0 {
|
||||
return nil, errors.New("output config is required")
|
||||
}
|
||||
importErr := validatePythonImports(c.Code)
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("run coder is required")
|
||||
}
|
||||
|
||||
importErr := validatePythonImports(cfg.Code)
|
||||
|
||||
return &CodeRunner{
|
||||
config: cfg,
|
||||
importError: importErr,
|
||||
return &Runner{
|
||||
code: c.Code,
|
||||
language: c.Language,
|
||||
outputConfig: ns.OutputTypes,
|
||||
runner: code2.GetCodeRunner(),
|
||||
importError: importErr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
outputConfig map[string]*vo.TypeInfo
|
||||
code string
|
||||
language coderunner.Language
|
||||
runner coderunner.Runner
|
||||
importError error
|
||||
}
|
||||
|
||||
func validatePythonImports(code string) error {
|
||||
imports := parsePythonImports(code)
|
||||
importErrors := make([]string, 0)
|
||||
@@ -191,11 +222,11 @@ func validatePythonImports(code string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
|
||||
func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
|
||||
if c.importError != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
|
||||
}
|
||||
response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
|
||||
response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input})
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
@@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
|
||||
result := response.Result
|
||||
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
|
||||
|
||||
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
|
||||
output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig)
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
@@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
|
||||
|
||||
}
|
||||
|
||||
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
|
||||
if !ok {
|
||||
return nil, errors.New("raw output config is required")
|
||||
|
||||
@@ -75,30 +75,29 @@ async def main(args:Args)->Output:
|
||||
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Language: coderunner.Python,
|
||||
Code: codeTpl,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
}},
|
||||
},
|
||||
},
|
||||
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
|
||||
c := &Runner{
|
||||
language: coderunner.Python,
|
||||
code: codeTpl,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
}},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
@@ -145,38 +144,36 @@ async def main(args:Args)->Output:
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: coderunner.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
}},
|
||||
c := &Runner{
|
||||
code: codeTpl,
|
||||
language: coderunner.Python,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
}},
|
||||
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
@@ -219,30 +216,28 @@ async def main(args:Args)->Output:
|
||||
}
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: coderunner.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
c := &Runner{
|
||||
code: codeTpl,
|
||||
language: coderunner.Python,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
"key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
|
||||
236
backend/domain/workflow/internal/nodes/database/adapt.go
Normal file
236
backend/domain/workflow/internal/nodes/database/adapt.go
Normal file
@@ -0,0 +1,236 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) (err error) {
|
||||
selectParam := n.Data.Inputs.SelectParam
|
||||
if selectParam != nil {
|
||||
err = applyDBConditionToSchema(ns, selectParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
insertParam := n.Data.Inputs.InsertParam
|
||||
if insertParam != nil {
|
||||
err = applyInsetFieldInfoToSchema(ns, insertParam.FieldInfo, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
deleteParam := n.Data.Inputs.DeleteParam
|
||||
if deleteParam != nil {
|
||||
err = applyDBConditionToSchema(ns, &deleteParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
updateParam := n.Data.Inputs.UpdateParam
|
||||
if updateParam != nil {
|
||||
err = applyDBConditionToSchema(ns, &updateParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = applyInsetFieldInfoToSchema(ns, updateParam.FieldInfo, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error {
|
||||
if condition.ConditionList == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for idx, params := range condition.ConditionList {
|
||||
var right *vo.Param
|
||||
for _, param := range params {
|
||||
if param == nil {
|
||||
continue
|
||||
}
|
||||
if param.Name == "right" {
|
||||
right = param
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if right == nil {
|
||||
continue
|
||||
}
|
||||
name := fmt.Sprintf("__condition_right_%d", idx)
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error {
|
||||
if len(fieldInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, params := range fieldInfo {
|
||||
// Each FieldInfo is list params, containing two elements.
|
||||
// The first is to set the name of the field and the second is the corresponding value.
|
||||
p0 := params[0]
|
||||
p1 := params[1]
|
||||
name := p0.Input.Value.Content.(string) // must string type
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name = "__setting_field_" + name
|
||||
ns.SetInputType(name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) {
|
||||
clauseGroup := &database.ClauseGroup{}
|
||||
if len(condition.ConditionList) == 1 {
|
||||
params := condition.ConditionList[0]
|
||||
clause, err := buildClauseFromParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Single = clause
|
||||
} else {
|
||||
relation, err := convertLogicTypeToRelation(condition.Logic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Multi = &database.MultiClause{
|
||||
Clauses: make([]*database.Clause, 0, len(condition.ConditionList)),
|
||||
Relation: relation,
|
||||
}
|
||||
for i := range condition.ConditionList {
|
||||
params := condition.ConditionList[i]
|
||||
clause, err := buildClauseFromParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Multi.Clauses = append(clauseGroup.Multi.Clauses, clause)
|
||||
}
|
||||
}
|
||||
|
||||
return clauseGroup, nil
|
||||
}
|
||||
|
||||
func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
|
||||
var left, operation *vo.Param
|
||||
for _, p := range params {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.Name == "left" {
|
||||
left = p
|
||||
continue
|
||||
}
|
||||
if p.Name == "operation" {
|
||||
operation = p
|
||||
continue
|
||||
}
|
||||
}
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("left clause is required")
|
||||
}
|
||||
if operation == nil {
|
||||
return nil, fmt.Errorf("operation clause is required")
|
||||
}
|
||||
operator, err := operationToOperator(operation.Input.Value.Content.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clause := &database.Clause{
|
||||
Left: left.Input.Value.Content.(string),
|
||||
Operator: operator,
|
||||
}
|
||||
|
||||
return clause, nil
|
||||
}
|
||||
|
||||
func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) {
|
||||
switch logicType {
|
||||
case vo.DatabaseLogicAnd:
|
||||
return database.ClauseRelationAND, nil
|
||||
case vo.DatabaseLogicOr:
|
||||
return database.ClauseRelationOR, nil
|
||||
default:
|
||||
return "", fmt.Errorf("logic type %v is invalid", logicType)
|
||||
}
|
||||
}
|
||||
|
||||
func operationToOperator(s string) (database.Operator, error) {
|
||||
switch s {
|
||||
case "EQUAL":
|
||||
return database.OperatorEqual, nil
|
||||
case "NOT_EQUAL":
|
||||
return database.OperatorNotEqual, nil
|
||||
case "GREATER_THAN":
|
||||
return database.OperatorGreater, nil
|
||||
case "LESS_THAN":
|
||||
return database.OperatorLesser, nil
|
||||
case "GREATER_EQUAL":
|
||||
return database.OperatorGreaterOrEqual, nil
|
||||
case "LESS_EQUAL":
|
||||
return database.OperatorLesserOrEqual, nil
|
||||
case "IN":
|
||||
return database.OperatorIn, nil
|
||||
case "NOT_IN":
|
||||
return database.OperatorNotIn, nil
|
||||
case "IS_NULL":
|
||||
return database.OperatorIsNull, nil
|
||||
case "IS_NOT_NULL":
|
||||
return database.OperatorIsNotNull, nil
|
||||
case "LIKE":
|
||||
return database.OperatorLike, nil
|
||||
case "NOT_LIKE":
|
||||
return database.OperatorNotLike, nil
|
||||
}
|
||||
return "", fmt.Errorf("not a valid Operation string")
|
||||
}
|
||||
@@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
|
||||
func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
|
||||
var (
|
||||
rightValue any
|
||||
ok bool
|
||||
@@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa
|
||||
return conditionGroup, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
|
||||
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fields := parseToInput(input)
|
||||
inventory := &UpdateInventory{
|
||||
inventory := &updateInventory{
|
||||
ConditionGroup: conditionGroup,
|
||||
Fields: fields,
|
||||
}
|
||||
|
||||
@@ -19,48 +19,89 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type CustomSQLConfig struct {
|
||||
DatabaseInfoID int64
|
||||
SQLTemplate string
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
CustomSQLExecutor database.DatabaseOperator
|
||||
DatabaseInfoID int64
|
||||
SQLTemplate string
|
||||
}
|
||||
|
||||
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseCustomSQL,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.DatabaseInfoID = dsID
|
||||
|
||||
sql := n.Data.Inputs.SQL
|
||||
if len(sql) == 0 {
|
||||
return nil, fmt.Errorf("sql is requird")
|
||||
}
|
||||
|
||||
c.SQLTemplate = sql
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
if cfg.SQLTemplate == "" {
|
||||
if c.SQLTemplate == "" {
|
||||
return nil, errors.New("sql template is required")
|
||||
}
|
||||
if cfg.CustomSQLExecutor == nil {
|
||||
return nil, errors.New("custom sqler is required")
|
||||
}
|
||||
|
||||
return &CustomSQL{
|
||||
config: cfg,
|
||||
databaseInfoID: c.DatabaseInfoID,
|
||||
sqlTemplate: c.SQLTemplate,
|
||||
outputTypes: ns.OutputTypes,
|
||||
customSQLExecutor: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type CustomSQL struct {
|
||||
config *CustomSQLConfig
|
||||
databaseInfoID int64
|
||||
sqlTemplate string
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
customSQLExecutor database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
||||
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
req := &database.CustomSQLRequest{
|
||||
DatabaseInfoID: c.config.DatabaseInfoID,
|
||||
DatabaseInfoID: c.databaseInfoID,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
@@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
|
||||
}
|
||||
|
||||
templateSQL := ""
|
||||
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
|
||||
templateParts := nodes.ParseTemplate(c.sqlTemplate)
|
||||
sqlParams := make([]database.SQLParam, 0, len(templateParts))
|
||||
var nilError = errors.New("field is nil")
|
||||
for _, templatePart := range templateParts {
|
||||
@@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
|
||||
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
|
||||
req.SQL = templateSQL
|
||||
req.Params = sqlParams
|
||||
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
|
||||
response, err := c.customSQLExecutor.Execute(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(c.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(c.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type mockCustomSQLer struct {
|
||||
@@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C
|
||||
m.validate(request)
|
||||
r := &database.Response{
|
||||
Objects: []database.Object{
|
||||
database.Object{
|
||||
{
|
||||
"v1": "v1_ret",
|
||||
"v2": "v2_ret",
|
||||
},
|
||||
@@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) {
|
||||
validate: func(req *database.CustomSQLRequest) {
|
||||
assert.Equal(t, int64(111), req.DatabaseInfoID)
|
||||
ps := []database.SQLParam{
|
||||
database.SQLParam{Value: "v1_value"},
|
||||
database.SQLParam{Value: "v2_value"},
|
||||
database.SQLParam{Value: "v3_value"},
|
||||
{Value: "v1_value"},
|
||||
{Value: "v2_value"},
|
||||
{Value: "v3_value"},
|
||||
}
|
||||
assert.Equal(t, ps, req.Params)
|
||||
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
|
||||
@@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
cfg := &CustomSQLConfig{
|
||||
DatabaseInfoID: 111,
|
||||
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
|
||||
CustomSQLExecutor: mockDatabaseOperator,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
DatabaseInfoID: 111,
|
||||
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
|
||||
}
|
||||
|
||||
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
cl := &CustomSQL{
|
||||
config: cfg,
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
ret, err := cl.Execute(t.Context(), map[string]any{
|
||||
ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{
|
||||
"v1": "v1_value",
|
||||
"v2": "v2_value",
|
||||
"v3": "v3_value",
|
||||
|
||||
@@ -20,61 +20,102 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type DeleteConfig struct {
|
||||
DatabaseInfoID int64
|
||||
ClauseGroup *database.ClauseGroup
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
|
||||
Deleter database.DatabaseOperator
|
||||
}
|
||||
type Delete struct {
|
||||
config *DeleteConfig
|
||||
}
|
||||
|
||||
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseDelete,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.DatabaseInfoID = dsID
|
||||
|
||||
deleteParam := n.Data.Inputs.DeleteParam
|
||||
|
||||
clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if d.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.ClauseGroup == nil {
|
||||
if d.ClauseGroup == nil {
|
||||
return nil, errors.New("clauseGroup is required")
|
||||
}
|
||||
if cfg.Deleter == nil {
|
||||
return nil, errors.New("deleter is required")
|
||||
}
|
||||
|
||||
return &Delete{
|
||||
config: cfg,
|
||||
databaseInfoID: d.DatabaseInfoID,
|
||||
clauseGroup: d.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
deleter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
|
||||
type Delete struct {
|
||||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
deleter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request := &database.DeleteRequest{
|
||||
DatabaseInfoID: d.config.DatabaseInfoID,
|
||||
DatabaseInfoID: d.databaseInfoID,
|
||||
ConditionGroup: conditionGroup,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := d.config.Deleter.Delete(ctx, request)
|
||||
response, err := d.deleter.Delete(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(d.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(d.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any,
|
||||
}
|
||||
|
||||
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri
|
||||
}
|
||||
|
||||
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
databaseID := d.config.DatabaseInfoID
|
||||
databaseID := d.databaseInfoID
|
||||
result := make(map[string]any)
|
||||
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
|
||||
@@ -20,54 +20,84 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type InsertConfig struct {
|
||||
DatabaseInfoID int64
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
Inserter database.DatabaseOperator
|
||||
}
|
||||
|
||||
type Insert struct {
|
||||
config *InsertConfig
|
||||
}
|
||||
|
||||
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseInsert,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.DatabaseInfoID = dsID
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if i.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Inserter == nil {
|
||||
return nil, errors.New("inserter is required")
|
||||
}
|
||||
return &Insert{
|
||||
config: cfg,
|
||||
databaseInfoID: i.DatabaseInfoID,
|
||||
outputTypes: ns.OutputTypes,
|
||||
inserter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Insert struct {
|
||||
databaseInfoID int64
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
inserter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
fields := parseToInput(input)
|
||||
req := &database.InsertRequest{
|
||||
DatabaseInfoID: is.config.DatabaseInfoID,
|
||||
DatabaseInfoID: is.databaseInfoID,
|
||||
Fields: fields,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := is.config.Inserter.Insert(ctx, req)
|
||||
response, err := is.inserter.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(is.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(is.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]
|
||||
}
|
||||
|
||||
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
databaseID := is.config.DatabaseInfoID
|
||||
databaseID := is.databaseInfoID
|
||||
fs := parseToInput(input)
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
|
||||
@@ -20,68 +20,137 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type QueryConfig struct {
|
||||
DatabaseInfoID int64
|
||||
QueryFields []string
|
||||
OrderClauses []*database.OrderClause
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
ClauseGroup *database.ClauseGroup
|
||||
Limit int64
|
||||
Op database.DatabaseOperator
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
config *QueryConfig
|
||||
}
|
||||
|
||||
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseQuery,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: q,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q.DatabaseInfoID = dsID
|
||||
|
||||
selectParam := n.Data.Inputs.SelectParam
|
||||
q.Limit = selectParam.Limit
|
||||
|
||||
queryFields := make([]string, 0)
|
||||
for _, v := range selectParam.FieldList {
|
||||
queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10))
|
||||
}
|
||||
q.QueryFields = queryFields
|
||||
|
||||
orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList))
|
||||
for _, o := range selectParam.OrderByList {
|
||||
orderClauses = append(orderClauses, &database.OrderClause{
|
||||
FieldID: strconv.FormatInt(o.FieldID, 10),
|
||||
IsAsc: o.IsAsc,
|
||||
})
|
||||
}
|
||||
q.OrderClauses = orderClauses
|
||||
|
||||
clauseGroup := &database.ClauseGroup{}
|
||||
|
||||
if selectParam.Condition != nil {
|
||||
clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
q.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if q.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Limit == 0 {
|
||||
if q.Limit == 0 {
|
||||
return nil, errors.New("limit is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Op == nil {
|
||||
return nil, errors.New("op is required")
|
||||
}
|
||||
|
||||
return &Query{config: cfg}, nil
|
||||
|
||||
return &Query{
|
||||
databaseInfoID: q.DatabaseInfoID,
|
||||
queryFields: q.QueryFields,
|
||||
orderClauses: q.OrderClauses,
|
||||
outputTypes: ns.OutputTypes,
|
||||
clauseGroup: q.ClauseGroup,
|
||||
limit: q.Limit,
|
||||
op: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
|
||||
type Query struct {
|
||||
databaseInfoID int64
|
||||
queryFields []string
|
||||
orderClauses []*database.OrderClause
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
clauseGroup *database.ClauseGroup
|
||||
limit int64
|
||||
op database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &database.QueryRequest{
|
||||
DatabaseInfoID: ds.config.DatabaseInfoID,
|
||||
OrderClauses: ds.config.OrderClauses,
|
||||
SelectFields: ds.config.QueryFields,
|
||||
Limit: ds.config.Limit,
|
||||
DatabaseInfoID: ds.databaseInfoID,
|
||||
OrderClauses: ds.orderClauses,
|
||||
SelectFields: ds.queryFields,
|
||||
Limit: ds.limit,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
req.ConditionGroup = conditionGroup
|
||||
|
||||
response, err := ds.config.Op.Query(ctx, req)
|
||||
response, err := ds.op.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(ds.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(ds.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool {
|
||||
}
|
||||
|
||||
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
|
||||
return ds.toDatabaseQueryCallbackInput(conditionGroup)
|
||||
}
|
||||
|
||||
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
databaseID := config.DatabaseInfoID
|
||||
databaseID := ds.databaseInfoID
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["selectParam"] = map[string]any{}
|
||||
|
||||
@@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
||||
FieldID string `json:"fieldId"`
|
||||
IsDistinct bool `json:"isDistinct"`
|
||||
}
|
||||
fieldList := make([]Field, 0, len(config.QueryFields))
|
||||
for _, f := range config.QueryFields {
|
||||
fieldList := make([]Field, 0, len(ds.queryFields))
|
||||
for _, f := range ds.queryFields {
|
||||
fieldList = append(fieldList, Field{FieldID: f})
|
||||
}
|
||||
type Order struct {
|
||||
@@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
||||
}
|
||||
|
||||
OrderList := make([]Order, 0)
|
||||
for _, c := range config.OrderClauses {
|
||||
for _, c := range ds.orderClauses {
|
||||
OrderList = append(OrderList, Order{
|
||||
FieldID: c.FieldID,
|
||||
IsAsc: c.IsAsc,
|
||||
@@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
||||
result["selectParam"] = map[string]any{
|
||||
"condition": condition,
|
||||
"fieldList": fieldList,
|
||||
"limit": config.Limit,
|
||||
"limit": ds.limit,
|
||||
"orderByList": OrderList,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
||||
type ConditionItem struct {
|
||||
@@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) {
|
||||
return "AND", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown clause relation %v", rel)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type mockDsSelect struct {
|
||||
@@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
@@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]interface{}{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
|
||||
@@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
@@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
"__condition_right_1": 2,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
@@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
@@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, map[string]any{
|
||||
@@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
"v3": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
@@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
"v3": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{"__condition_right_0": 1}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
@@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeNumber},
|
||||
"v3": {Type: vo.DataTypeBoolean},
|
||||
"v4": {Type: vo.DataTypeBoolean},
|
||||
"v5": {Type: vo.DataTypeTime},
|
||||
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
|
||||
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
|
||||
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
@@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeNumber},
|
||||
"v3": {Type: vo.DataTypeBoolean},
|
||||
"v4": {Type: vo.DataTypeBoolean},
|
||||
"v5": {Type: vo.DataTypeTime},
|
||||
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
|
||||
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
|
||||
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
object := result["outputList"].([]any)[0].(database.Object)
|
||||
|
||||
@@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) {
|
||||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
@@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) {
|
||||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
|
||||
"v1": "1",
|
||||
|
||||
@@ -20,47 +20,93 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type UpdateConfig struct {
|
||||
DatabaseInfoID int64
|
||||
ClauseGroup *database.ClauseGroup
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
Updater database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseUpdate,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: u,
|
||||
}
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.DatabaseInfoID = dsID
|
||||
|
||||
updateParam := n.Data.Inputs.UpdateParam
|
||||
if updateParam == nil {
|
||||
return nil, fmt.Errorf("update param is requird")
|
||||
}
|
||||
clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if u.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if u.ClauseGroup == nil {
|
||||
return nil, errors.New("clause group is required and greater than 0")
|
||||
}
|
||||
|
||||
return &Update{
|
||||
databaseInfoID: u.DatabaseInfoID,
|
||||
clauseGroup: u.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
updater: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Update struct {
|
||||
config *UpdateConfig
|
||||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
updater database.DatabaseOperator
|
||||
}
|
||||
type UpdateInventory struct {
|
||||
|
||||
type updateInventory struct {
|
||||
ConditionGroup *database.ConditionGroup
|
||||
Fields map[string]any
|
||||
}
|
||||
|
||||
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.ClauseGroup == nil {
|
||||
return nil, errors.New("clause group is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Updater == nil {
|
||||
return nil, errors.New("updater is required")
|
||||
}
|
||||
|
||||
return &Update{config: cfg}, nil
|
||||
}
|
||||
|
||||
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
|
||||
func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
|
||||
}
|
||||
|
||||
req := &database.UpdateRequest{
|
||||
DatabaseInfoID: u.config.DatabaseInfoID,
|
||||
DatabaseInfoID: u.databaseInfoID,
|
||||
ConditionGroup: inventory.ConditionGroup,
|
||||
Fields: fields,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := u.config.Updater.Update(ctx, req)
|
||||
response, err := u.updater.Update(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(u.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(u.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
|
||||
}
|
||||
|
||||
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
|
||||
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return u.toDatabaseUpdateCallbackInput(inventory)
|
||||
}
|
||||
|
||||
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
|
||||
databaseID := u.config.DatabaseInfoID
|
||||
func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) {
|
||||
databaseID := u.databaseInfoID
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["updateParam"] = map[string]any{}
|
||||
@@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[
|
||||
"condition": condition,
|
||||
"fieldInfo": fieldInfo,
|
||||
}
|
||||
return result, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ package emitter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@@ -26,28 +25,77 @@ import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type OutputEmitter struct {
|
||||
cfg *Config
|
||||
Template string
|
||||
FullSources map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Template string
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
Template string
|
||||
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
content := n.Data.Inputs.Content
|
||||
streamingOutput := n.Data.Inputs.StreamingOutput
|
||||
|
||||
if streamingOutput {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: true,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if content != nil {
|
||||
if content.Type != vo.VariableTypeString {
|
||||
return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
|
||||
}
|
||||
|
||||
if content.Value.Type != vo.BlockInputValueTypeLiteral {
|
||||
return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
if content.Value.Content == nil {
|
||||
c.Template = ""
|
||||
} else {
|
||||
template, ok := content.Value.Content.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
|
||||
}
|
||||
|
||||
c.Template = template
|
||||
}
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
return &OutputEmitter{
|
||||
cfg: cfg,
|
||||
Template: c.Template,
|
||||
FullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -59,10 +107,10 @@ type cachedVal struct {
|
||||
|
||||
type cacheStore struct {
|
||||
store map[string]*cachedVal
|
||||
infos map[string]*nodes.SourceInfo
|
||||
infos map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
|
||||
func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
|
||||
return &cacheStore{
|
||||
store: make(map[string]*cachedVal),
|
||||
infos: infos,
|
||||
@@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
||||
}
|
||||
|
||||
if !sInfo.IsIntermediate { // this is not an intermediate object container
|
||||
isStream := sInfo.FieldType == nodes.FieldIsStream
|
||||
isStream := sInfo.FieldType == schema2.FieldIsStream
|
||||
if !isStream {
|
||||
_, ok := c.store[k]
|
||||
if !ok {
|
||||
@@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
||||
func (c *cacheStore) finished(k string) bool {
|
||||
cached, ok := c.store[k]
|
||||
if !ok {
|
||||
return c.infos[k].FieldType == nodes.FieldSkipped
|
||||
return c.infos[k].FieldType == schema2.FieldSkipped
|
||||
}
|
||||
|
||||
if cached.finished {
|
||||
@@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
|
||||
actualPath []string,
|
||||
) {
|
||||
rootCached, ok := c.store[part.Root]
|
||||
@@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
|
||||
hasErr bool, partFinished bool) {
|
||||
cachedRoot, subCache, sourceInfo, _ := c.find(part)
|
||||
if cachedRoot != nil && subCache != nil {
|
||||
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
||||
if hasErr {
|
||||
return true, false
|
||||
@@ -315,14 +363,14 @@ func merge(a, b any) any {
|
||||
|
||||
const outputKey = "output"
|
||||
|
||||
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, sw := schema.Pipe[map[string]any](0)
|
||||
parts := nodes.ParseTemplate(e.cfg.Template)
|
||||
parts := nodes.ParseTemplate(e.Template)
|
||||
safego.Go(ctx, func() {
|
||||
hasErr := false
|
||||
defer func() {
|
||||
@@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
||||
shouldChangePart = true
|
||||
}
|
||||
} else {
|
||||
if sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
currentV := v
|
||||
for i := 0; i < len(actualPath)-1; i++ {
|
||||
currentM, ok := currentV.(map[string]any)
|
||||
@@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -20,41 +20,74 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DefaultValues map[string]any
|
||||
OutputTypes map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
cfg *Config
|
||||
defaultValues map[string]any
|
||||
}
|
||||
|
||||
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
|
||||
|
||||
if n.ID != entity.EntryNodeKey {
|
||||
return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Name: n.Data.Meta.Title,
|
||||
Type: entity.NodeTypeEntry,
|
||||
}
|
||||
|
||||
defaultValues := make(map[string]any, len(n.Data.Outputs))
|
||||
for _, v := range n.Data.Outputs {
|
||||
variable, err := vo.ParseVariable(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if variable.DefaultValue != nil {
|
||||
defaultValues[variable.Name] = variable.DefaultValue
|
||||
}
|
||||
}
|
||||
|
||||
c.DefaultValues = defaultValues
|
||||
ns.Configs = c
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
cfg: cfg,
|
||||
defaultValues: defaultValues,
|
||||
outputTypes: ns.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
defaultValues map[string]any
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
||||
for k, v := range e.defaultValues {
|
||||
if val, ok := in[k]; ok {
|
||||
tInfo := e.cfg.OutputTypes[k]
|
||||
tInfo := e.outputTypes[k]
|
||||
switch tInfo.Type {
|
||||
case vo.DataTypeString:
|
||||
if len(val.(string)) == 0 {
|
||||
|
||||
113
backend/domain/workflow/internal/nodes/exit/exit.go
Normal file
113
backend/domain/workflow/internal/nodes/exit/exit.go
Normal file
@@ -0,0 +1,113 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package exit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Template string
|
||||
TerminatePlan vo.TerminatePlan
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if n.ID != entity.ExitNodeKey {
|
||||
return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
var (
|
||||
content *vo.BlockInput
|
||||
streamingOutput bool
|
||||
)
|
||||
if n.Data.Inputs.OutputEmitter != nil {
|
||||
content = n.Data.Inputs.Content
|
||||
streamingOutput = n.Data.Inputs.StreamingOutput
|
||||
}
|
||||
|
||||
if streamingOutput {
|
||||
ns.StreamConfigs = &schema.StreamConfig{
|
||||
RequireStreamingInput: true,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema.StreamConfig{
|
||||
RequireStreamingInput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if content != nil {
|
||||
if content.Type != vo.VariableTypeString {
|
||||
return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
|
||||
}
|
||||
|
||||
if content.Value.Type != vo.BlockInputValueTypeLiteral {
|
||||
return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
c.Template = content.Value.Content.(string)
|
||||
}
|
||||
|
||||
if n.Data.Inputs.TerminatePlan == nil {
|
||||
return nil, fmt.Errorf("exit node requires a TerminatePlan")
|
||||
}
|
||||
c.TerminatePlan = *n.Data.Inputs.TerminatePlan
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.TerminatePlan == vo.ReturnVariables {
|
||||
return &Exit{}, nil
|
||||
}
|
||||
|
||||
return &emitter.OutputEmitter{
|
||||
Template: c.Template,
|
||||
FullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Exit struct{}
|
||||
|
||||
func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
if in == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
340
backend/domain/workflow/internal/nodes/httprequester/adapt.go
Normal file
340
backend/domain/workflow/internal/nodes/httprequester/adapt.go
Normal file
@@ -0,0 +1,340 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package httprequester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
)
|
||||
|
||||
var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`)
|
||||
|
||||
func extractBracesContent(s string) []string {
|
||||
matches := extractBracesRegexp.FindAllStringSubmatch(s, -1)
|
||||
var result []string
|
||||
for _, match := range matches {
|
||||
if len(match) >= 2 {
|
||||
result = append(result, match[1])
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type ImplicitNodeDependency struct {
|
||||
NodeID string
|
||||
FieldPath compose.FieldPath
|
||||
TypeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) {
|
||||
dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes))
|
||||
url := node.Data.Inputs.APIInfo.URL
|
||||
urlVars := extractBracesContent(url)
|
||||
hasReferred := make(map[string]bool)
|
||||
extractDependenciesFromVars := func(vars []string) error {
|
||||
for _, v := range vars {
|
||||
if strings.HasPrefix(v, "block_output_") {
|
||||
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
|
||||
if len(paths) < 2 {
|
||||
return fmt.Errorf("invalid block_output_ variable: %s", v)
|
||||
}
|
||||
if hasReferred[v] {
|
||||
continue
|
||||
}
|
||||
hasReferred[v] = true
|
||||
dependencies = append(dependencies, &ImplicitNodeDependency{
|
||||
NodeID: paths[0],
|
||||
FieldPath: paths[1:],
|
||||
})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := extractDependenciesFromVars(urlVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) {
|
||||
jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
|
||||
err = extractDependenciesFromVars(jsonVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) {
|
||||
rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
|
||||
err = extractDependenciesFromVars(rawTextVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node
|
||||
nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node {
|
||||
for i := range nodes {
|
||||
if nodes[i].ID == nodeID {
|
||||
return nodes[i]
|
||||
}
|
||||
if len(nodes[i].Blocks) > 0 {
|
||||
if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
for _, ds := range dependencies {
|
||||
fNode := nodeFinder(canvas.Nodes, ds.NodeID)
|
||||
if fNode == nil {
|
||||
continue
|
||||
}
|
||||
tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs))
|
||||
for _, vAny := range fNode.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo, err := convert.CanvasVariableToTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfoMap[v.Name] = tInfo
|
||||
}
|
||||
tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath)
|
||||
}
|
||||
ds.TypeInfo = tInfo
|
||||
}
|
||||
|
||||
return dependencies, nil
|
||||
}
|
||||
|
||||
func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) {
|
||||
if len(properties) == 0 {
|
||||
if tInfo, ok := tInfoMap[root]; ok {
|
||||
return tInfo, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
tInfo, ok := tInfoMap[root]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties)
|
||||
}
|
||||
|
||||
var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`)
|
||||
|
||||
func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) {
|
||||
inputs := n.Data.Inputs
|
||||
implicitPathVars := make(map[string]bool)
|
||||
addImplicitVarsSources := func(prefix string, vars []string) error {
|
||||
for _, v := range vars {
|
||||
if strings.HasPrefix(v, "block_output_") {
|
||||
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
|
||||
if len(paths) < 2 {
|
||||
return fmt.Errorf("invalid implicit var : %s", v)
|
||||
}
|
||||
for _, dep := range implicitNodeDependencies {
|
||||
if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") {
|
||||
pathValue := prefix + crypto.MD5HexValue(v)
|
||||
if _, visited := implicitPathVars[pathValue]; visited {
|
||||
continue
|
||||
}
|
||||
implicitPathVars[pathValue] = true
|
||||
ns.SetInputType(pathValue, dep.TypeInfo)
|
||||
ns.AddInputSource(&vo.FieldInfo{
|
||||
Path: []string{pathValue},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: vo.NodeKey(dep.NodeID),
|
||||
FromPath: dep.FieldPath,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(v, "global_variable_") {
|
||||
matches := globalVariableRegex.FindStringSubmatch(v)
|
||||
if len(matches) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var varType vo.GlobalVarType
|
||||
if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) {
|
||||
varType = vo.GlobalAPP
|
||||
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) {
|
||||
varType = vo.GlobalUser
|
||||
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) {
|
||||
varType = vo.GlobalSystem
|
||||
} else {
|
||||
return fmt.Errorf("invalid global variable type: %s", v)
|
||||
}
|
||||
|
||||
source := vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
VariableType: &varType,
|
||||
FromPath: []string{matches[1]},
|
||||
},
|
||||
}
|
||||
|
||||
ns.AddInputSource(&vo.FieldInfo{
|
||||
Path: []string{prefix + crypto.MD5HexValue(v)},
|
||||
Source: source,
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
urlVars := extractBracesContent(inputs.APIInfo.URL)
|
||||
err = addImplicitVarsSources("__apiInfo_url_", urlVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if inputs.Auth != nil && inputs.Auth.AuthOpen {
|
||||
authData := inputs.Auth.AuthData
|
||||
const bearerTokenKey = "__auth_authData_bearerTokenData_token"
|
||||
if inputs.Auth.AuthType == "BEARER_AUTH" {
|
||||
bearTokenParam := authData.BearerTokenData[0]
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(bearerTokenKey, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
|
||||
}
|
||||
|
||||
if inputs.Auth.AuthType == "CUSTOM_AUTH" {
|
||||
const (
|
||||
customDataDataKey = "__auth_authData_customData_data_Key"
|
||||
customDataDataValue = "__auth_authData_customData_data_Value"
|
||||
)
|
||||
dataParams := authData.CustomData.Data
|
||||
keyParam := dataParams[0]
|
||||
keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(customDataDataKey, keyTypeInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
|
||||
valueParam := dataParams[1]
|
||||
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(customDataDataValue, valueTypeInfo)
|
||||
sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
}
|
||||
|
||||
switch BodyType(inputs.Body.BodyType) {
|
||||
case BodyTypeFormData:
|
||||
err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeFormURLEncoded:
|
||||
err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeBinary:
|
||||
const fileURLName = "__body_bodyData_binary_fileURL"
|
||||
fileURLInput := inputs.Body.BodyData.Binary.FileURL
|
||||
ns.SetInputType(fileURLName, &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
})
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
case BodyTypeJSON:
|
||||
jsonVars := extractBracesContent(inputs.Body.BodyData.Json)
|
||||
err = addImplicitVarsSources("__body_bodyData_json_", jsonVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeRawText:
|
||||
rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText)
|
||||
err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error {
|
||||
for i := range params {
|
||||
param := params[i]
|
||||
name := param.Name
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldName := prefix + crypto.MD5HexValue(name)
|
||||
ns.SetInputType(fieldName, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{fieldName}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -31,9 +31,14 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@@ -129,7 +134,7 @@ type Request struct {
|
||||
FileURL *string
|
||||
}
|
||||
|
||||
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)
|
||||
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`)
|
||||
|
||||
type MD5FieldMapping struct {
|
||||
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
|
||||
@@ -184,49 +189,188 @@ type Config struct {
|
||||
Timeout time.Duration
|
||||
RetryTimes uint64
|
||||
|
||||
IgnoreException bool
|
||||
DefaultOutput map[string]any
|
||||
|
||||
MD5FieldMapping
|
||||
}
|
||||
|
||||
type HTTPRequester struct {
|
||||
client *http.Client
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
options := nodes.GetAdaptOptions(opts...)
|
||||
if options.Canvas == nil {
|
||||
return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node")
|
||||
}
|
||||
|
||||
if len(cfg.Method) == 0 {
|
||||
implicitDeps, err := extractImplicitDependency(n, options.Canvas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeHTTPRequester,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
md5FieldMapping := &MD5FieldMapping{}
|
||||
|
||||
method := inputs.APIInfo.Method
|
||||
c.Method = method
|
||||
reqURL := inputs.APIInfo.URL
|
||||
c.URLConfig = URLConfig{
|
||||
Tpl: strings.TrimSpace(reqURL),
|
||||
}
|
||||
|
||||
urlVars := extractBracesContent(reqURL)
|
||||
md5FieldMapping.SetURLFields(urlVars...)
|
||||
|
||||
md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string {
|
||||
return a.Name
|
||||
})...)
|
||||
|
||||
md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string {
|
||||
return a.Name
|
||||
})...)
|
||||
|
||||
if inputs.Auth != nil && inputs.Auth.AuthOpen {
|
||||
auth := &AuthenticationConfig{}
|
||||
ty, err := convertAuthType(inputs.Auth.AuthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth.Type = ty
|
||||
location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth.Location = location
|
||||
|
||||
c.AuthConfig = auth
|
||||
}
|
||||
|
||||
bodyConfig := BodyConfig{}
|
||||
|
||||
bodyConfig.BodyType = BodyType(inputs.Body.BodyType)
|
||||
switch BodyType(inputs.Body.BodyType) {
|
||||
case BodyTypeJSON:
|
||||
jsonTpl := inputs.Body.BodyData.Json
|
||||
bodyConfig.TextJsonConfig = &TextJsonConfig{
|
||||
Tpl: jsonTpl,
|
||||
}
|
||||
jsonVars := extractBracesContent(jsonTpl)
|
||||
md5FieldMapping.SetBodyFields(jsonVars...)
|
||||
case BodyTypeFormData:
|
||||
bodyConfig.FormDataConfig = &FormDataConfig{
|
||||
FileTypeMapping: map[string]bool{},
|
||||
}
|
||||
formDataVars := make([]string, 0)
|
||||
for i := range inputs.Body.BodyData.FormData.Data {
|
||||
p := inputs.Body.BodyData.FormData.Data[i]
|
||||
formDataVars = append(formDataVars, p.Name)
|
||||
if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime {
|
||||
bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
md5FieldMapping.SetBodyFields(formDataVars...)
|
||||
case BodyTypeRawText:
|
||||
TextTpl := inputs.Body.BodyData.RawText
|
||||
bodyConfig.TextPlainConfig = &TextPlainConfig{
|
||||
Tpl: TextTpl,
|
||||
}
|
||||
textPlainVars := extractBracesContent(TextTpl)
|
||||
md5FieldMapping.SetBodyFields(textPlainVars...)
|
||||
case BodyTypeFormURLEncoded:
|
||||
formURLEncodedVars := make([]string, 0)
|
||||
for _, p := range inputs.Body.BodyData.FormURLEncoded {
|
||||
formURLEncodedVars = append(formURLEncodedVars, p.Name)
|
||||
}
|
||||
md5FieldMapping.SetBodyFields(formURLEncodedVars...)
|
||||
}
|
||||
c.BodyConfig = bodyConfig
|
||||
c.MD5FieldMapping = *md5FieldMapping
|
||||
|
||||
if inputs.Setting != nil {
|
||||
c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second
|
||||
c.RetryTimes = uint64(inputs.Setting.RetryTimes)
|
||||
}
|
||||
|
||||
if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func convertAuthType(auth string) (AuthType, error) {
|
||||
switch auth {
|
||||
case "CUSTOM_AUTH":
|
||||
return Custom, nil
|
||||
case "BEARER_AUTH":
|
||||
return BearToken, nil
|
||||
default:
|
||||
return AuthType(0), fmt.Errorf("invalid auth type")
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocation(l string) (Location, error) {
|
||||
switch l {
|
||||
case "header":
|
||||
return Header, nil
|
||||
case "query":
|
||||
return QueryParam, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid location")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if len(c.Method) == 0 {
|
||||
return nil, fmt.Errorf("method is requried")
|
||||
}
|
||||
|
||||
hg := &HTTPRequester{}
|
||||
hg := &HTTPRequester{
|
||||
urlConfig: c.URLConfig,
|
||||
method: c.Method,
|
||||
retryTimes: c.RetryTimes,
|
||||
authConfig: c.AuthConfig,
|
||||
bodyConfig: c.BodyConfig,
|
||||
md5FieldMapping: c.MD5FieldMapping,
|
||||
}
|
||||
client := http.DefaultClient
|
||||
if cfg.Timeout > 0 {
|
||||
client.Timeout = cfg.Timeout
|
||||
if c.Timeout > 0 {
|
||||
client.Timeout = c.Timeout
|
||||
}
|
||||
|
||||
hg.client = client
|
||||
hg.config = cfg
|
||||
|
||||
return hg, nil
|
||||
}
|
||||
|
||||
type HTTPRequester struct {
|
||||
client *http.Client
|
||||
urlConfig URLConfig
|
||||
authConfig *AuthenticationConfig
|
||||
bodyConfig BodyConfig
|
||||
method string
|
||||
retryTimes uint64
|
||||
md5FieldMapping MD5FieldMapping
|
||||
}
|
||||
|
||||
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
|
||||
var (
|
||||
req = &Request{}
|
||||
method = hg.config.Method
|
||||
retryTimes = hg.config.RetryTimes
|
||||
method = hg.method
|
||||
retryTimes = hg.retryTimes
|
||||
body io.ReadCloser
|
||||
contentType string
|
||||
response *http.Response
|
||||
)
|
||||
|
||||
req, err = hg.config.parserToRequest(input)
|
||||
req, err = hg.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
|
||||
httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
||||
params.Set(key, value)
|
||||
}
|
||||
|
||||
if hg.config.AuthConfig != nil {
|
||||
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
|
||||
if hg.authConfig != nil {
|
||||
httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
||||
u.RawQuery = params.Encode()
|
||||
httpRequest.URL = u
|
||||
|
||||
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
|
||||
body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
var (
|
||||
request = &Request{}
|
||||
config = hg.config
|
||||
)
|
||||
request, err := hg.config.parserToRequest(input)
|
||||
var request = &Request{}
|
||||
|
||||
request, err := hg.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
result["method"] = config.Method
|
||||
result["method"] = hg.method
|
||||
|
||||
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
|
||||
u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
||||
}
|
||||
result["header"] = headers
|
||||
result["auth"] = nil
|
||||
if config.AuthConfig != nil {
|
||||
if config.AuthConfig.Type == Custom {
|
||||
if hg.authConfig != nil {
|
||||
if hg.authConfig.Type == Custom {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"Key": request.Authentication.Key,
|
||||
"Value": request.Authentication.Value,
|
||||
}
|
||||
} else if config.AuthConfig.Type == BearToken {
|
||||
} else if hg.authConfig.Type == BearToken {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"token": request.Authentication.Token,
|
||||
}
|
||||
@@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
||||
}
|
||||
|
||||
result["body"] = nil
|
||||
switch config.BodyConfig.BodyType {
|
||||
switch hg.bodyConfig.BodyType {
|
||||
case BodyTypeJSON:
|
||||
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
|
||||
js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
||||
}
|
||||
result["body"] = ret
|
||||
case BodyTypeRawText:
|
||||
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
|
||||
tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
@@ -569,7 +711,7 @@ const (
|
||||
bodyBinaryFileURLPrefix = "binary_fileURL"
|
||||
)
|
||||
|
||||
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) {
|
||||
request := &Request{
|
||||
URLVars: make(map[string]any),
|
||||
Headers: make(map[string]string),
|
||||
@@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
for key, value := range input {
|
||||
if strings.HasPrefix(key, apiInfoURLPrefix) {
|
||||
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
|
||||
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
|
||||
if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok {
|
||||
if strings.HasPrefix(urlKey, "global_variable_") {
|
||||
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
|
||||
}
|
||||
@@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
}
|
||||
if strings.HasPrefix(key, headersPrefix) {
|
||||
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
|
||||
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
|
||||
if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok {
|
||||
request.Headers[headerKey] = value.(string)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(key, paramsPrefix) {
|
||||
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
|
||||
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
|
||||
if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok {
|
||||
request.Params[paramKey] = value.(string)
|
||||
}
|
||||
}
|
||||
@@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
|
||||
if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
|
||||
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
|
||||
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
|
||||
if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok {
|
||||
if strings.HasPrefix(jsonKey, "global_variable_") {
|
||||
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
|
||||
}
|
||||
@@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
}
|
||||
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
|
||||
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
|
||||
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
|
||||
if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok {
|
||||
request.FormDataVars[formDataKey] = value.(string)
|
||||
}
|
||||
|
||||
@@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
|
||||
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
|
||||
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
|
||||
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
|
||||
if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
|
||||
request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
|
||||
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
|
||||
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
|
||||
if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok {
|
||||
if strings.HasPrefix(rawTextKey, "global_variable_") {
|
||||
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
)
|
||||
|
||||
@@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
|
||||
@@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) {
|
||||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
@@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
@@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) {
|
||||
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
@@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
@@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) {
|
||||
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
@@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
@@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) {
|
||||
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
@@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
@@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) {
|
||||
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
||||
@@ -18,26 +18,167 @@ package intentdetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Intents []string
|
||||
SystemPrompt string
|
||||
IsFastMode bool
|
||||
ChatModel model.BaseChatModel
|
||||
LLMParams *model.LLMParams
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeIntentDetector,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
param := n.Data.Inputs.LLMParam
|
||||
if param == nil {
|
||||
return nil, fmt.Errorf("intent detector node's llmParam is nil")
|
||||
}
|
||||
|
||||
llmParam, ok := param.(vo.IntentDetectorLLMParam)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("llm node's llmParam must be LLMParam, got %v", llmParam)
|
||||
}
|
||||
|
||||
paramBytes, err := sonic.Marshal(param)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var intentDetectorConfig = &vo.IntentDetectorLLMConfig{}
|
||||
|
||||
err = sonic.Unmarshal(paramBytes, &intentDetectorConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelLLMParams := &model.LLMParams{}
|
||||
modelLLMParams.ModelType = int64(intentDetectorConfig.ModelType)
|
||||
modelLLMParams.ModelName = intentDetectorConfig.ModelName
|
||||
modelLLMParams.TopP = intentDetectorConfig.TopP
|
||||
modelLLMParams.Temperature = intentDetectorConfig.Temperature
|
||||
modelLLMParams.MaxTokens = intentDetectorConfig.MaxTokens
|
||||
modelLLMParams.ResponseFormat = model.ResponseFormat(intentDetectorConfig.ResponseFormat)
|
||||
modelLLMParams.SystemPrompt = intentDetectorConfig.SystemPrompt.Value.Content.(string)
|
||||
|
||||
c.LLMParams = modelLLMParams
|
||||
c.SystemPrompt = modelLLMParams.SystemPrompt
|
||||
|
||||
var intents = make([]string, 0, len(n.Data.Inputs.Intents))
|
||||
for _, it := range n.Data.Inputs.Intents {
|
||||
intents = append(intents, it.Name)
|
||||
}
|
||||
c.Intents = intents
|
||||
|
||||
if n.Data.Inputs.Mode == "top_speed" {
|
||||
c.IsFastMode = true
|
||||
}
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if !c.IsFastMode && c.LLMParams == nil {
|
||||
return nil, errors.New("config chat model is required")
|
||||
}
|
||||
|
||||
if len(c.Intents) == 0 {
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
|
||||
m, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chain := compose.NewChain[map[string]any, *schema.Message]()
|
||||
|
||||
spt := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
|
||||
|
||||
intents, err := toIntentString(c.Intents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
|
||||
"intents": intents,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts := prompt.FromMessages(schema.Jinja2,
|
||||
&schema.Message{Content: sptTemplate, Role: schema.System},
|
||||
&schema.Message{Content: "{{query}}", Role: schema.User})
|
||||
|
||||
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(m).Compile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &IntentDetector{
|
||||
isFastMode: c.IsFastMode,
|
||||
systemPrompt: c.SystemPrompt,
|
||||
runner: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
classificationId, ok := nodeOutput[classificationID]
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput)
|
||||
}
|
||||
|
||||
cID64, ok := classificationId.(int64)
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", classificationId)
|
||||
}
|
||||
|
||||
if cID64 == 0 {
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
return cID64 - 1, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string {
|
||||
expects := make([]string, len(n.Data.Inputs.Intents)+1)
|
||||
expects[0] = schema2.PortDefault
|
||||
for i := 0; i < len(n.Data.Inputs.Intents); i++ {
|
||||
expects[i+1] = fmt.Sprintf(schema2.PortBranchFormat, i)
|
||||
}
|
||||
return expects
|
||||
}
|
||||
|
||||
const SystemIntentPrompt = `
|
||||
@@ -95,71 +236,39 @@ Note:
|
||||
##Limit
|
||||
- Please do not reply in text.`
|
||||
|
||||
const classificationID = "classificationId"
|
||||
|
||||
type IntentDetector struct {
|
||||
config *Config
|
||||
runner compose.Runnable[map[string]any, *schema.Message]
|
||||
}
|
||||
|
||||
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
}
|
||||
if !cfg.IsFastMode && cfg.ChatModel == nil {
|
||||
return nil, errors.New("config chat model is required")
|
||||
}
|
||||
|
||||
if len(cfg.Intents) == 0 {
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
chain := compose.NewChain[map[string]any, *schema.Message]()
|
||||
|
||||
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
|
||||
|
||||
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
|
||||
"intents": toIntentString(cfg.Intents),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts := prompt.FromMessages(schema.Jinja2,
|
||||
&schema.Message{Content: sptTemplate, Role: schema.System},
|
||||
&schema.Message{Content: "{{query}}", Role: schema.User})
|
||||
|
||||
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &IntentDetector{
|
||||
config: cfg,
|
||||
runner: r,
|
||||
}, nil
|
||||
isFastMode bool
|
||||
systemPrompt string
|
||||
runner compose.Runnable[map[string]any, *schema.Message]
|
||||
}
|
||||
|
||||
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
|
||||
nodeOutput := make(map[string]any)
|
||||
nodeOutput["classificationId"] = 0
|
||||
if content == "" {
|
||||
return nodeOutput, errors.New("content is empty")
|
||||
return nil, errors.New("intent detector's LLM output content is empty")
|
||||
}
|
||||
|
||||
if id.config.IsFastMode {
|
||||
if id.isFastMode {
|
||||
cid, err := strconv.ParseInt(content, 10, 64)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
return nil, err
|
||||
}
|
||||
nodeOutput["classificationId"] = cid
|
||||
return nodeOutput, nil
|
||||
return map[string]any{
|
||||
classificationID: cid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
leftIndex := strings.Index(content, "{")
|
||||
rightIndex := strings.Index(content, "}")
|
||||
if leftIndex == -1 || rightIndex == -1 {
|
||||
return nodeOutput, errors.New("content is invalid")
|
||||
return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content)
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
|
||||
var nodeOutput map[string]any
|
||||
err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodeOutput, nil
|
||||
@@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
|
||||
|
||||
vars := make(map[string]any)
|
||||
vars["query"] = queryStr
|
||||
if !id.config.IsFastMode {
|
||||
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
|
||||
if !id.isFastMode {
|
||||
ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
|
||||
return id.parseToNodeOut(o.Content)
|
||||
}
|
||||
|
||||
func toIntentString(its []string) string {
|
||||
func toIntentString(its []string) (string, error) {
|
||||
type IntentVariableItem struct {
|
||||
ClassificationID int64 `json:"classificationId"`
|
||||
Content string `json:"content"`
|
||||
@@ -207,6 +316,6 @@ func toIntentString(its []string) string {
|
||||
Content: it,
|
||||
})
|
||||
}
|
||||
itsBytes, _ := json.Marshal(vs)
|
||||
return string(itsBytes)
|
||||
|
||||
return sonic.MarshalString(vs)
|
||||
}
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package intentdetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockChatModel struct {
|
||||
topSeed bool
|
||||
}
|
||||
|
||||
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
||||
if m.topSeed {
|
||||
return &schema.Message{
|
||||
Content: "1",
|
||||
}, nil
|
||||
}
|
||||
return &schema.Message{
|
||||
Content: `{"classificationId":1,"reason":"高兴"}`,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewIntentDetector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("fast mode", func(t *testing.T) {
|
||||
dt, err := NewIntentDetector(ctx, &Config{
|
||||
Intents: []string{"高兴", "悲伤"},
|
||||
IsFastMode: true,
|
||||
ChatModel: &mockChatModel{topSeed: true},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ret, err := dt.Invoke(ctx, map[string]any{
|
||||
"query": "我考了100分",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ret["classificationId"], int64(1))
|
||||
})
|
||||
|
||||
t.Run("full mode", func(t *testing.T) {
|
||||
|
||||
dt, err := NewIntentDetector(ctx, &Config{
|
||||
Intents: []string{"高兴", "悲伤"},
|
||||
IsFastMode: false,
|
||||
ChatModel: &mockChatModel{},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ret, err := dt.Invoke(ctx, map[string]any{
|
||||
"query": "我考了100分",
|
||||
})
|
||||
fmt.Println(err)
|
||||
assert.Nil(t, err)
|
||||
fmt.Println(ret)
|
||||
assert.Equal(t, ret["classificationId"], float64(1))
|
||||
assert.Equal(t, ret["reason"], "高兴")
|
||||
})
|
||||
|
||||
}
|
||||
@@ -20,8 +20,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
@@ -34,32 +37,42 @@ const (
|
||||
warningsKey = "deserialization_warnings"
|
||||
)
|
||||
|
||||
type DeserializationConfig struct {
|
||||
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
|
||||
type DeserializationConfig struct{}
|
||||
|
||||
func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (
|
||||
*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeJsonDeserialization,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
type Deserializer struct {
|
||||
config *DeserializationConfig
|
||||
typeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.OutputFields == nil {
|
||||
return nil, fmt.Errorf("OutputFields is required for deserialization")
|
||||
}
|
||||
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
|
||||
if typeInfo == nil {
|
||||
func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no output field specified in deserialization config")
|
||||
}
|
||||
return &Deserializer{
|
||||
config: cfg,
|
||||
typeInfo: typeInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Deserializer struct {
|
||||
typeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
jsonStrValue := input[InputKeyDeserialization]
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
@@ -31,19 +32,9 @@ import (
|
||||
func TestNewJsonDeserializer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil config
|
||||
_, err := NewJsonDeserializer(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "config required")
|
||||
|
||||
// Test with missing OutputFields config
|
||||
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "OutputFields is required")
|
||||
|
||||
// Test with missing output key in OutputFields
|
||||
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
_, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"testKey": {Type: vo.DataTypeString},
|
||||
},
|
||||
})
|
||||
@@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "no output field specified in deserialization config")
|
||||
|
||||
// Test with valid config
|
||||
validConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
validConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeString},
|
||||
},
|
||||
}
|
||||
processor, err := NewJsonDeserializer(ctx, validConfig)
|
||||
processor, err := (&DeserializationConfig{}).Build(ctx, validConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, processor)
|
||||
}
|
||||
@@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Base type test config
|
||||
baseTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeString},
|
||||
baseTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeString},
|
||||
},
|
||||
}
|
||||
|
||||
// Object type test config
|
||||
objectTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
objectTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"name": {Type: vo.DataTypeString, Required: true},
|
||||
@@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Array type test config
|
||||
arrayTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
arrayTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
@@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
}
|
||||
|
||||
// Nested array object test config
|
||||
nestedArrayConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
nestedArrayConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
@@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
config *DeserializationConfig
|
||||
config *schema.NodeSchema
|
||||
inputJSON string
|
||||
expectedOutput any
|
||||
expectErr bool
|
||||
@@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test integer deserialization",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123`,
|
||||
@@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test boolean deserialization",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeBoolean},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
|
||||
},
|
||||
},
|
||||
inputJSON: `true`,
|
||||
@@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test type mismatch warning",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"not a number"`,
|
||||
@@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"123"`,
|
||||
@@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test float to integer conversion (integer part)",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123.0`,
|
||||
@@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test float to integer conversion (non-integer part)",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123.5`,
|
||||
@@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test boolean to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `true`,
|
||||
@@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test string to boolean conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeBoolean},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
|
||||
},
|
||||
},
|
||||
inputJSON: `"true"`,
|
||||
@@ -252,10 +243,11 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectErr: false,
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion in nested object",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
name: "Test string to integer conversion in nested object",
|
||||
inputJSON: `{"age":"456"}`,
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"age": {Type: vo.DataTypeInteger},
|
||||
@@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
inputJSON: `{"age":"456"}`,
|
||||
expectedOutput: map[string]any{"age": 456},
|
||||
expectErr: false,
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion for array elements",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
@@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string with non-numeric characters to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"123abc"`,
|
||||
@@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test type mismatch in nested object field",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"score": {Type: vo.DataTypeInteger},
|
||||
@@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test partial conversion failure in array elements",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
@@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
processor, err := NewJsonDeserializer(ctx, tt.config)
|
||||
processor, err := (&DeserializationConfig{}).Build(ctx, tt.config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctxWithCache := ctxcache.Init(ctx)
|
||||
input := map[string]any{"input": tt.inputJSON}
|
||||
result, err := processor.Invoke(ctxWithCache, input)
|
||||
result, err := processor.(*Deserializer).Invoke(ctxWithCache, input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -20,7 +20,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@@ -29,28 +33,57 @@ const (
|
||||
OutputKeySerialization = "output"
|
||||
)
|
||||
|
||||
// SerializationConfig is the Config type for NodeTypeJsonSerialization.
|
||||
// Each Node Type should have its own designated Config type,
|
||||
// which should implement NodeAdaptor and NodeBuilder.
|
||||
// NOTE: we didn't define any fields for this type,
|
||||
// because this node is simple, we doesn't need to extract any SPECIFIC piece of info
|
||||
// from frontend Node. In other cases we would need to do it, such as LLM's model configs.
|
||||
type SerializationConfig struct {
|
||||
InputTypes map[string]*vo.TypeInfo
|
||||
// you can define ANY number of fields here,
|
||||
// as long as these fields are SERIALIZABLE and EXPORTED.
|
||||
// to store specific info extracted from frontend node.
|
||||
// e.g.
|
||||
// - LLM model configs
|
||||
// - conditional expressions
|
||||
// - fixed input fields such as MaxBatchSize
|
||||
}
|
||||
|
||||
type JsonSerializer struct {
|
||||
config *SerializationConfig
|
||||
}
|
||||
|
||||
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.InputTypes == nil {
|
||||
return nil, fmt.Errorf("InputTypes is required for serialization")
|
||||
// Adapt provides conversion from Node to NodeSchema.
|
||||
// NOTE: in this specific case, we don't need AdaptOption.
|
||||
func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeJsonSerialization,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: s, // remember to set the Node's Config Type to NodeSchema as well
|
||||
}
|
||||
|
||||
return &JsonSerializer{
|
||||
config: cfg,
|
||||
}, nil
|
||||
// this sets input fields' type and mapping info
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this set output fields' type info
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (
|
||||
any, error) {
|
||||
return &Serializer{}, nil
|
||||
}
|
||||
|
||||
// Serializer is the actual node implementation.
|
||||
type Serializer struct {
|
||||
// here can holds ANY data required for node execution
|
||||
}
|
||||
|
||||
// Invoke implements the InvokableNode interface.
|
||||
func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
// Directly use the input map for serialization
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("input data for serialization cannot be nil")
|
||||
|
||||
@@ -23,44 +23,34 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNewJsonSerialize(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil config
|
||||
_, err := NewJsonSerializer(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "config required")
|
||||
|
||||
// Test with missing InputTypes config
|
||||
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "InputTypes is required")
|
||||
|
||||
// Test with valid config
|
||||
validConfig := &SerializationConfig{
|
||||
s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"testKey": {Type: "string"},
|
||||
},
|
||||
}
|
||||
processor, err := NewJsonSerializer(ctx, validConfig)
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, processor)
|
||||
assert.NotNil(t, s)
|
||||
}
|
||||
|
||||
func TestJsonSerialize_Invoke(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
config := &SerializationConfig{
|
||||
|
||||
processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"stringKey": {Type: "string"},
|
||||
"intKey": {Type: "integer"},
|
||||
"boolKey": {Type: "boolean"},
|
||||
"objKey": {Type: "object"},
|
||||
},
|
||||
}
|
||||
|
||||
processor, err := NewJsonSerializer(ctx, config)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test cases
|
||||
@@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processor.Invoke(ctx, tt.input)
|
||||
result, err := processor.(*Serializer).Invoke(ctx, tt.input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
|
||||
57
backend/domain/workflow/internal/nodes/knowledge/adaptor.go
Normal file
57
backend/domain/workflow/internal/nodes/knowledge/adaptor.go
Normal file
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
func convertParsingType(p string) (knowledge.ParseMode, error) {
|
||||
switch p {
|
||||
case "fast":
|
||||
return knowledge.FastParseMode, nil
|
||||
case "accurate":
|
||||
return knowledge.AccurateParseMode, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid parsingType: %s", p)
|
||||
}
|
||||
}
|
||||
|
||||
func convertChunkType(p string) (knowledge.ChunkType, error) {
|
||||
switch p {
|
||||
case "custom":
|
||||
return knowledge.ChunkTypeCustom, nil
|
||||
case "default":
|
||||
return knowledge.ChunkTypeDefault, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid ChunkType: %s", p)
|
||||
}
|
||||
}
|
||||
func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
|
||||
}
|
||||
}
|
||||
@@ -21,27 +21,45 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type DeleterConfig struct {
|
||||
KnowledgeID int64
|
||||
KnowledgeDeleter knowledge.KnowledgeOperator
|
||||
}
|
||||
type DeleterConfig struct{}
|
||||
|
||||
type KnowledgeDeleter struct {
|
||||
config *DeleterConfig
|
||||
}
|
||||
|
||||
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
|
||||
if cfg.KnowledgeDeleter == nil {
|
||||
return nil, errors.New("knowledge deleter is required")
|
||||
func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeDeleter,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
return &KnowledgeDeleter{
|
||||
config: cfg,
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Deleter{
|
||||
knowledgeDeleter: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Deleter struct {
|
||||
knowledgeDeleter knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
documentID, ok := input["documentID"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("documentID is required and must be a string")
|
||||
@@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma
|
||||
DocumentID: documentID,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
|
||||
response, err := k.knowledgeDeleter.Delete(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -24,7 +24,14 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
@@ -32,30 +39,88 @@ type IndexerConfig struct {
|
||||
KnowledgeID int64
|
||||
ParsingStrategy *knowledge.ParsingStrategy
|
||||
ChunkingStrategy *knowledge.ChunkingStrategy
|
||||
KnowledgeIndexer knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
type KnowledgeIndexer struct {
|
||||
config *IndexerConfig
|
||||
func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeIndexer,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
datasetListInfoParam := inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
if len(datasetIDs) == 0 {
|
||||
return nil, fmt.Errorf("dataset ids is required")
|
||||
}
|
||||
knowledgeID, err := cast.ToInt64E(datasetIDs[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.KnowledgeID = knowledgeID
|
||||
ps := inputs.StrategyParam.ParsingStrategy
|
||||
parseMode, err := convertParsingType(ps.ParsingType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsingStrategy := &knowledge.ParsingStrategy{
|
||||
ParseMode: parseMode,
|
||||
ImageOCR: ps.ImageOcr,
|
||||
ExtractImage: ps.ImageExtraction,
|
||||
ExtractTable: ps.TableExtraction,
|
||||
}
|
||||
i.ParsingStrategy = parsingStrategy
|
||||
|
||||
cs := inputs.StrategyParam.ChunkStrategy
|
||||
chunkType, err := convertChunkType(cs.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chunkingStrategy := &knowledge.ChunkingStrategy{
|
||||
ChunkType: chunkType,
|
||||
Separator: cs.Separator,
|
||||
ChunkSize: cs.MaxToken,
|
||||
Overlap: int64(cs.Overlap * float64(cs.MaxToken)),
|
||||
}
|
||||
i.ChunkingStrategy = chunkingStrategy
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
|
||||
if cfg.ParsingStrategy == nil {
|
||||
func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if i.ParsingStrategy == nil {
|
||||
return nil, errors.New("parsing strategy is required")
|
||||
}
|
||||
if cfg.ChunkingStrategy == nil {
|
||||
if i.ChunkingStrategy == nil {
|
||||
return nil, errors.New("chunking strategy is required")
|
||||
}
|
||||
if cfg.KnowledgeIndexer == nil {
|
||||
return nil, errors.New("knowledge indexer is required")
|
||||
}
|
||||
return &KnowledgeIndexer{
|
||||
config: cfg,
|
||||
return &Indexer{
|
||||
knowledgeID: i.KnowledgeID,
|
||||
parsingStrategy: i.ParsingStrategy,
|
||||
chunkingStrategy: i.ChunkingStrategy,
|
||||
knowledgeIndexer: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Indexer struct {
|
||||
knowledgeID int64
|
||||
parsingStrategy *knowledge.ParsingStrategy
|
||||
chunkingStrategy *knowledge.ChunkingStrategy
|
||||
knowledgeIndexer knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
fileURL, ok := input["knowledge"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("knowledge is required")
|
||||
@@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map
|
||||
}
|
||||
|
||||
req := &knowledge.CreateDocumentRequest{
|
||||
KnowledgeID: k.config.KnowledgeID,
|
||||
ParsingStrategy: k.config.ParsingStrategy,
|
||||
ChunkingStrategy: k.config.ChunkingStrategy,
|
||||
KnowledgeID: k.knowledgeID,
|
||||
ParsingStrategy: k.parsingStrategy,
|
||||
ChunkingStrategy: k.chunkingStrategy,
|
||||
FileURL: fileURL,
|
||||
FileName: fileName,
|
||||
FileExtension: ext,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
|
||||
response, err := k.knowledgeIndexer.Store(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -20,7 +20,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
@@ -29,37 +36,136 @@ const outputList = "outputList"
|
||||
type RetrieveConfig struct {
|
||||
KnowledgeIDs []int64
|
||||
RetrievalStrategy *knowledge.RetrievalStrategy
|
||||
Retriever knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
type KnowledgeRetrieve struct {
|
||||
config *RetrieveConfig
|
||||
func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeRetriever,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: r,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
datasetListInfoParam := inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
knowledgeIDs := make([]int64, 0, len(datasetIDs))
|
||||
for _, id := range datasetIDs {
|
||||
k, err := cast.ToInt64E(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, k)
|
||||
}
|
||||
r.KnowledgeIDs = knowledgeIDs
|
||||
|
||||
retrievalStrategy := &knowledge.RetrievalStrategy{}
|
||||
|
||||
var getDesignatedParamContent = func(name string) (any, bool) {
|
||||
for _, param := range inputs.DatasetParam {
|
||||
if param.Name == name {
|
||||
return param.Input.Value.Content, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("topK"); ok {
|
||||
topK, err := cast.ToInt64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.TopK = &topK
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useRerank"); ok {
|
||||
useRerank, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableRerank = useRerank
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useRewrite"); ok {
|
||||
useRewrite, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableQueryRewrite = useRewrite
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("isPersonalOnly"); ok {
|
||||
isPersonalOnly, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.IsPersonalOnly = isPersonalOnly
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useNl2sql"); ok {
|
||||
useNl2sql, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableNL2SQL = useNl2sql
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("minScore"); ok {
|
||||
minScore, err := cast.ToFloat64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.MinScore = &minScore
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("strategy"); ok {
|
||||
strategy, err := cast.ToInt64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchType, err := convertRetrievalSearchType(strategy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.SearchType = searchType
|
||||
}
|
||||
|
||||
r.RetrievalStrategy = retrievalStrategy
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if len(r.KnowledgeIDs) == 0 {
|
||||
return nil, errors.New("knowledge ids are required")
|
||||
}
|
||||
|
||||
if cfg.Retriever == nil {
|
||||
return nil, errors.New("retriever is required")
|
||||
}
|
||||
|
||||
if len(cfg.KnowledgeIDs) == 0 {
|
||||
return nil, errors.New("knowledgeI ids is required")
|
||||
}
|
||||
|
||||
if cfg.RetrievalStrategy == nil {
|
||||
if r.RetrievalStrategy == nil {
|
||||
return nil, errors.New("retrieval strategy is required")
|
||||
}
|
||||
|
||||
return &KnowledgeRetrieve{
|
||||
config: cfg,
|
||||
return &Retrieve{
|
||||
knowledgeIDs: r.KnowledgeIDs,
|
||||
retrievalStrategy: r.RetrievalStrategy,
|
||||
retriever: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Retrieve struct {
|
||||
knowledgeIDs []int64
|
||||
retrievalStrategy *knowledge.RetrievalStrategy
|
||||
retriever knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
query, ok := input["Query"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("capital query key is required")
|
||||
@@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any)
|
||||
|
||||
req := &knowledge.RetrieveRequest{
|
||||
Query: query,
|
||||
KnowledgeIDs: kr.config.KnowledgeIDs,
|
||||
RetrievalStrategy: kr.config.RetrievalStrategy,
|
||||
KnowledgeIDs: kr.knowledgeIDs,
|
||||
RetrievalStrategy: kr.retrievalStrategy,
|
||||
}
|
||||
|
||||
response, err := kr.config.Retriever.Retrieve(ctx, req)
|
||||
response, err := kr.retriever.Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -34,13 +34,20 @@ import (
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
@@ -143,126 +150,408 @@ const (
|
||||
)
|
||||
|
||||
type RetrievalStrategy struct {
|
||||
RetrievalStrategy *crossknowledge.RetrievalStrategy
|
||||
RetrievalStrategy *knowledge.RetrievalStrategy
|
||||
NoReCallReplyMode NoReCallReplyMode
|
||||
NoReCallReplyCustomizePrompt string
|
||||
}
|
||||
|
||||
type KnowledgeRecallConfig struct {
|
||||
ChatModel model.BaseChatModel
|
||||
Retriever crossknowledge.KnowledgeOperator
|
||||
Retriever knowledge.KnowledgeOperator
|
||||
RetrievalStrategy *RetrievalStrategy
|
||||
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
|
||||
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ChatModel ModelWithInfo
|
||||
Tools []tool.BaseTool
|
||||
SystemPrompt string
|
||||
UserPrompt string
|
||||
OutputFormat Format
|
||||
InputFields map[string]*vo.TypeInfo
|
||||
OutputFields map[string]*vo.TypeInfo
|
||||
ToolsReturnDirectly map[string]bool
|
||||
KnowledgeRecallConfig *KnowledgeRecallConfig
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
SystemPrompt string
|
||||
UserPrompt string
|
||||
OutputFormat Format
|
||||
LLMParams *crossmodel.LLMParams
|
||||
FCParam *vo.FCParam
|
||||
BackupLLMParams *crossmodel.LLMParams
|
||||
}
|
||||
|
||||
type LLM struct {
|
||||
r compose.Runnable[map[string]any, map[string]any]
|
||||
outputFormat Format
|
||||
outputFields map[string]*vo.TypeInfo
|
||||
canStream bool
|
||||
requireCheckpoint bool
|
||||
fullSources map[string]*nodes.SourceInfo
|
||||
}
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeLLM,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
const (
|
||||
rawOutputKey = "llm_raw_output_%s"
|
||||
warningKey = "llm_warning_%s"
|
||||
)
|
||||
param := n.Data.Inputs.LLMParam
|
||||
if param == nil {
|
||||
return nil, fmt.Errorf("llm node's llmParam is nil")
|
||||
}
|
||||
|
||||
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
data = nodes.ExtractJSONString(data)
|
||||
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertedLLMParam, err := llmParamsToLLMParam(llmParam)
|
||||
if err != nil {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
|
||||
ctxcache.Store(ctx, rawOutputK, data)
|
||||
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
|
||||
c.LLMParams = convertedLLMParam
|
||||
c.SystemPrompt = convertedLLMParam.SystemPrompt
|
||||
c.UserPrompt = convertedLLMParam.Prompt
|
||||
|
||||
var resFormat Format
|
||||
switch convertedLLMParam.ResponseFormat {
|
||||
case crossmodel.ResponseFormatText:
|
||||
resFormat = FormatText
|
||||
case crossmodel.ResponseFormatMarkdown:
|
||||
resFormat = FormatMarkdown
|
||||
case crossmodel.ResponseFormatJSON:
|
||||
resFormat = FormatJSON
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
c.OutputFormat = resFormat
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resFormat == FormatJSON {
|
||||
if len(ns.OutputTypes) == 1 {
|
||||
for _, v := range ns.OutputTypes {
|
||||
if v.Type == vo.DataTypeString {
|
||||
resFormat = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if len(ns.OutputTypes) == 2 {
|
||||
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
|
||||
resFormat = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resFormat == FormatJSON {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
CanGeneratesStream: false,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
CanGeneratesStream: true,
|
||||
}
|
||||
}
|
||||
|
||||
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
|
||||
c.FCParam = n.Data.Inputs.FCParam
|
||||
}
|
||||
|
||||
if se := n.Data.Inputs.SettingOnError; se != nil {
|
||||
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
|
||||
var backupLLMParam vo.SimpleLLMParam
|
||||
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.BackupLLMParams = backupModel
|
||||
}
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
for _, param := range params {
|
||||
switch param.Name {
|
||||
case "temperature":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
floatVal, err := strconv.ParseFloat(strVal, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Temperature = &floatVal
|
||||
case "maxTokens":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
intVal, err := strconv.Atoi(strVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.MaxTokens = intVal
|
||||
case "responseFormat":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
|
||||
case "modleName":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.ModelName = strVal
|
||||
case "modelType":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ModelType = int64Val
|
||||
case "prompt":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.Prompt = strVal
|
||||
case "enableChatHistory":
|
||||
boolVar := param.Input.Value.Content.(bool)
|
||||
p.EnableChatHistory = boolVar
|
||||
case "systemPrompt":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.SystemPrompt = strVal
|
||||
case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
|
||||
// do nothing
|
||||
case "topP":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
floatVar, err := strconv.ParseFloat(strVal, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.TopP = &floatVar
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
p.ModelName = params.ModelName
|
||||
p.ModelType = params.ModelType
|
||||
p.Temperature = ¶ms.Temperature
|
||||
p.MaxTokens = params.MaxTokens
|
||||
p.TopP = ¶ms.TopP
|
||||
p.ResponseFormat = params.ResponseFormat
|
||||
p.SystemPrompt = params.SystemPrompt
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getReasoningContent(message *schema.Message) string {
|
||||
return message.ReasoningContent
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
nested []nodes.NestedWorkflowOption
|
||||
toolWorkflowSW *schema.StreamWriter[*entity.Message]
|
||||
}
|
||||
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
var (
|
||||
err error
|
||||
chatModel, fallbackM model.BaseChatModel
|
||||
info, fallbackI *modelmgr.Model
|
||||
modelWithInfo ModelWithInfo
|
||||
tools []tool.BaseTool
|
||||
toolsReturnDirectly map[string]bool
|
||||
knowledgeRecallConfig *KnowledgeRecallConfig
|
||||
)
|
||||
|
||||
type Option func(o *Options)
|
||||
|
||||
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
|
||||
return func(o *Options) {
|
||||
o.nested = append(o.nested, nested...)
|
||||
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
|
||||
return func(o *Options) {
|
||||
o.toolWorkflowSW = sw
|
||||
exceptionConf := ns.ExceptionConfigs
|
||||
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
|
||||
backupModelParams := c.BackupLLMParams
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type llmState = map[string]any
|
||||
if fallbackM == nil {
|
||||
modelWithInfo = NewModel(chatModel, info)
|
||||
} else {
|
||||
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
|
||||
}
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
fcParams := c.FCParam
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
|
||||
workflowToolConfig := vo.WorkflowToolConfig{}
|
||||
if wf.FCSetting != nil {
|
||||
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
|
||||
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, workflowToolConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, wfTool)
|
||||
if wfTool.TerminatePlan() == vo.UseAnswerContent {
|
||||
if toolsReturnDirectly == nil {
|
||||
toolsReturnDirectly = make(map[string]bool)
|
||||
}
|
||||
toolInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolsReturnDirectly[toolInfo.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.PluginFCParam != nil {
|
||||
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
|
||||
var (
|
||||
requestParameters []*workflow3.APIParameter
|
||||
responseParameters []*workflow3.APIParameter
|
||||
)
|
||||
if p.FCSetting != nil {
|
||||
requestParameters = p.FCSetting.RequestParameters
|
||||
responseParameters = p.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
if req, ok := pluginToolsInvokableReq[pid]; ok {
|
||||
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
}
|
||||
} else {
|
||||
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
|
||||
PluginEntity: plugin.Entity{
|
||||
PluginID: pid,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
|
||||
toolID: {
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
},
|
||||
},
|
||||
IsDraft: p.IsDraft,
|
||||
}
|
||||
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
|
||||
}
|
||||
}
|
||||
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
|
||||
for _, req := range pluginToolsInvokableReq {
|
||||
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range toolMap {
|
||||
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
|
||||
}
|
||||
}
|
||||
if len(inInvokableTools) > 0 {
|
||||
tools = append(tools, inInvokableTools...)
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
|
||||
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
|
||||
if kwChatModel == nil {
|
||||
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
knowledgeOperator := knowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
knowledgeRecallConfig = &KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := toRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
|
||||
RetrievalStrategy: &knowledge.RetrievalStrategy{
|
||||
TopK: ptr.Of(setting.TopK),
|
||||
MinScore: ptr.Of(setting.MinScore),
|
||||
SearchType: searchType,
|
||||
EnableNL2SQL: setting.UseNL2SQL,
|
||||
EnableQueryRewrite: setting.UseRewrite,
|
||||
EnableRerank: setting.UseRerank,
|
||||
},
|
||||
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
|
||||
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
|
||||
}
|
||||
|
||||
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
|
||||
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
|
||||
kid, err := strconv.ParseInt(kw.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
|
||||
&knowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
|
||||
}
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
|
||||
return llmState{}
|
||||
}))
|
||||
|
||||
var (
|
||||
hasReasoning bool
|
||||
canStream = true
|
||||
)
|
||||
var hasReasoning bool
|
||||
|
||||
format := cfg.OutputFormat
|
||||
format := c.OutputFormat
|
||||
if format == FormatJSON {
|
||||
if len(cfg.OutputFields) == 1 {
|
||||
for _, v := range cfg.OutputFields {
|
||||
if len(ns.OutputTypes) == 1 {
|
||||
for _, v := range ns.OutputTypes {
|
||||
if v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if len(cfg.OutputFields) == 2 {
|
||||
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
|
||||
for k, v := range cfg.OutputFields {
|
||||
} else if len(ns.OutputTypes) == 2 {
|
||||
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
@@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
}
|
||||
}
|
||||
|
||||
userPrompt := cfg.UserPrompt
|
||||
userPrompt := c.UserPrompt
|
||||
switch format {
|
||||
case FormatJSON:
|
||||
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
|
||||
jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
case FormatText:
|
||||
}
|
||||
|
||||
if cfg.KnowledgeRecallConfig != nil {
|
||||
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
|
||||
if knowledgeRecallConfig != nil {
|
||||
err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
|
||||
|
||||
inputs := maps.Clone(cfg.InputFields)
|
||||
inputs := maps.Clone(ns.InputTypes)
|
||||
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
}
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
|
||||
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
template := newPrompts(sp, up, modelWithInfo)
|
||||
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template,
|
||||
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
|
||||
@@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
|
||||
|
||||
} else {
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
|
||||
template := newPrompts(sp, up, modelWithInfo)
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template)
|
||||
|
||||
_ = g.AddEdge(compose.START, templateNodeKey)
|
||||
}
|
||||
|
||||
if len(cfg.Tools) > 0 {
|
||||
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
|
||||
if len(tools) > 0 {
|
||||
m, ok := modelWithInfo.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
reactConfig := react.AgentConfig{
|
||||
ToolCallingModel: m,
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
|
||||
ModelNodeName: agentModelName,
|
||||
}
|
||||
|
||||
if len(cfg.ToolsReturnDirectly) > 0 {
|
||||
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
|
||||
for k := range cfg.ToolsReturnDirectly {
|
||||
if len(toolsReturnDirectly) > 0 {
|
||||
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
|
||||
for k := range toolsReturnDirectly {
|
||||
reactConfig.ToolReturnDirectly[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
@@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
|
||||
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
|
||||
} else {
|
||||
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
|
||||
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
|
||||
}
|
||||
|
||||
_ = g.AddEdge(templateNodeKey, llmNodeKey)
|
||||
|
||||
if format == FormatJSON {
|
||||
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
|
||||
return jsonParse(ctx, msg.Content, cfg.OutputFields)
|
||||
return jsonParse(ctx, msg.Content, ns.OutputTypes)
|
||||
}
|
||||
|
||||
convertNode := compose.InvokableLambda(iConvert)
|
||||
|
||||
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
|
||||
|
||||
canStream = false
|
||||
} else {
|
||||
var outputKey string
|
||||
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
|
||||
if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
for k, v := range cfg.OutputFields {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if v.Type != vo.DataTypeString {
|
||||
panic("impossible")
|
||||
}
|
||||
@@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
_ = g.AddEdge(outputConvertNodeKey, compose.END)
|
||||
|
||||
requireCheckpoint := false
|
||||
if len(cfg.Tools) > 0 {
|
||||
if len(tools) > 0 {
|
||||
requireCheckpoint = true
|
||||
}
|
||||
|
||||
var opts []compose.GraphCompileOption
|
||||
var compileOpts []compose.GraphCompileOption
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
|
||||
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
|
||||
}
|
||||
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
|
||||
compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
|
||||
|
||||
r, err := g.Compile(ctx, opts...)
|
||||
r, err := g.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
llm := &LLM{
|
||||
r: r,
|
||||
outputFormat: format,
|
||||
canStream: canStream,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
fullSources: cfg.FullSources,
|
||||
fullSources: ns.FullSources,
|
||||
}
|
||||
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
if c.FCParam != nil {
|
||||
if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) != 1 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
outputs := ns.OutputTypes
|
||||
if len(outputs) != 1 && len(outputs) != 2 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
var outputKey string
|
||||
for key, output := range outputs {
|
||||
if output.Type != vo.DataTypeString {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if key != ReasoningOutputKey {
|
||||
if len(outputKey) > 0 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
outputKey = key
|
||||
}
|
||||
}
|
||||
|
||||
field := path[0]
|
||||
if field == ReasoningOutputKey || field == outputKey {
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
type LLM struct {
|
||||
r compose.Runnable[map[string]any, map[string]any]
|
||||
outputFormat Format
|
||||
requireCheckpoint bool
|
||||
fullSources map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
const (
|
||||
rawOutputKey = "llm_raw_output_%s"
|
||||
warningKey = "llm_warning_%s"
|
||||
)
|
||||
|
||||
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
data = nodes.ExtractJSONString(data)
|
||||
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
|
||||
ctxcache.Store(ctx, rawOutputK, data)
|
||||
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type llmOptions struct {
|
||||
toolWorkflowSW *schema.StreamWriter[*entity.Message]
|
||||
}
|
||||
|
||||
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
|
||||
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
|
||||
o.toolWorkflowSW = sw
|
||||
})
|
||||
}
|
||||
|
||||
type llmState = map[string]any
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
resumingEvent = c.NodeCtx.ResumingEvent
|
||||
@@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
|
||||
}
|
||||
|
||||
llmOpts := &Options{}
|
||||
for _, opt := range opts {
|
||||
opt(llmOpts)
|
||||
}
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
nestedOpts := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range llmOpts.nested {
|
||||
opt(nestedOpts)
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
|
||||
composeOpts = append(composeOpts, options.GetOptsForNested()...)
|
||||
|
||||
if resumingEvent != nil {
|
||||
var (
|
||||
@@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
|
||||
}
|
||||
|
||||
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
|
||||
if llmOpts.toolWorkflowSW != nil {
|
||||
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
|
||||
composeOpts = append(composeOpts, toolMsgOpt)
|
||||
@@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
|
||||
return compose.NewInterruptAndRerunErr(ie)
|
||||
}
|
||||
|
||||
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
|
||||
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
|
||||
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
|
||||
|
||||
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
|
||||
modelPredictionIDs := strings.Split(input.Content, ",")
|
||||
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
|
||||
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
|
||||
return strconv.Itoa(int(e.ID)), e.ID
|
||||
})
|
||||
recallKnowledgeIDs := make([]int64, 0)
|
||||
@@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
|
||||
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
|
||||
Query: userPrompt,
|
||||
KnowledgeIDs: recallKnowledgeIDs,
|
||||
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
@@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
|
||||
}
|
||||
|
||||
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
|
||||
sources map[string]*nodes.SourceInfo,
|
||||
sources map[string]*schema2.SourceInfo,
|
||||
supportedModals map[modelmgr.Modal]bool,
|
||||
) (*schema.Message, error) {
|
||||
if !pl.hasMultiModal || len(supportedModals) == 0 {
|
||||
@@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
|
||||
}
|
||||
sk := fmt.Sprintf(sourceKey, nodeKey)
|
||||
|
||||
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
|
||||
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package loop
|
||||
package _break
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -22,21 +22,36 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Break struct {
|
||||
parentIntermediateStore variable.Store
|
||||
}
|
||||
|
||||
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
|
||||
type Config struct{}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
return &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeBreak,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Break{
|
||||
parentIntermediateStore: store,
|
||||
parentIntermediateStore: &nodes.ParentIntermediateStore{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
const BreakKey = "$break"
|
||||
|
||||
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
|
||||
func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) {
|
||||
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package _continue
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Continue struct{}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
return &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeContinue,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Continue{}, nil
|
||||
}
|
||||
|
||||
func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
return in, nil
|
||||
}
|
||||
@@ -27,53 +27,150 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type Loop struct {
|
||||
config *Config
|
||||
outputs map[string]*vo.FieldSource
|
||||
outputVars map[string]string
|
||||
inner compose.Runnable[map[string]any, map[string]any]
|
||||
nodeKey vo.NodeKey
|
||||
|
||||
loopType Type
|
||||
inputArrays []string
|
||||
intermediateVars map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
LoopNodeKey vo.NodeKey
|
||||
LoopType Type
|
||||
InputArrays []string
|
||||
IntermediateVars map[string]*vo.TypeInfo
|
||||
Outputs []*vo.FieldInfo
|
||||
|
||||
Inner compose.Runnable[map[string]any, map[string]any]
|
||||
}
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
ByArray Type = "by_array"
|
||||
ByIteration Type = "by_iteration"
|
||||
Infinite Type = "infinite"
|
||||
)
|
||||
|
||||
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if conf.LoopType == ByArray {
|
||||
if len(conf.InputArrays) == 0 {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeLoop,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
loopType, err := toLoopType(n.Data.Inputs.LoopType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.LoopType = loopType
|
||||
|
||||
intermediateVars := make(map[string]*vo.TypeInfo)
|
||||
for _, param := range n.Data.Inputs.VariableParameters {
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
intermediateVars[param.Name] = tInfo
|
||||
|
||||
ns.SetInputType(param.Name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
c.IntermediateVars = intermediateVars
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, fieldInfo := range ns.OutputSources {
|
||||
if fieldInfo.Source.Ref != nil {
|
||||
if len(fieldInfo.Source.Ref.FromPath) == 1 {
|
||||
if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
|
||||
fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loopCount := n.Data.Inputs.LoopCount
|
||||
if loopCount != nil {
|
||||
typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(Count, typeInfo)
|
||||
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
for key, tInfo := range ns.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
|
||||
continue
|
||||
}
|
||||
|
||||
c.InputArrays = append(c.InputArrays, key)
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func toLoopType(l vo.LoopType) (Type, error) {
|
||||
switch l {
|
||||
case vo.LoopTypeArray:
|
||||
return ByArray, nil
|
||||
case vo.LoopTypeCount:
|
||||
return ByIteration, nil
|
||||
case vo.LoopTypeInfinite:
|
||||
return Infinite, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported loop type: %s", l)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
if c.LoopType == ByArray {
|
||||
if len(c.InputArrays) == 0 {
|
||||
return nil, errors.New("input arrays is empty when loop type is ByArray")
|
||||
}
|
||||
}
|
||||
|
||||
loop := &Loop{
|
||||
config: conf,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
outputVars: make(map[string]string),
|
||||
options := schema.GetBuildOptions(opts...)
|
||||
if options.Inner == nil {
|
||||
return nil, errors.New("inner workflow is required for Loop Node")
|
||||
}
|
||||
|
||||
for _, info := range conf.Outputs {
|
||||
loop := &Loop{
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
outputVars: make(map[string]string),
|
||||
inputArrays: c.InputArrays,
|
||||
nodeKey: ns.Key,
|
||||
intermediateVars: c.IntermediateVars,
|
||||
inner: options.Inner,
|
||||
loopType: c.LoopType,
|
||||
}
|
||||
|
||||
for _, info := range ns.OutputSources {
|
||||
if len(info.Path) != 1 {
|
||||
return nil, fmt.Errorf("invalid output path: %s", info.Path)
|
||||
}
|
||||
@@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
||||
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
|
||||
if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
|
||||
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
|
||||
}
|
||||
|
||||
@@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
||||
return loop, nil
|
||||
}
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
ByArray Type = "by_array"
|
||||
ByIteration Type = "by_iteration"
|
||||
Infinite Type = "infinite"
|
||||
)
|
||||
|
||||
const (
|
||||
Count = "loopCount"
|
||||
)
|
||||
|
||||
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
|
||||
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
|
||||
out map[string]any, err error) {
|
||||
maxIter, err := l.getMaxIter(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrays := make(map[string][]any, len(l.config.InputArrays))
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
arrays := make(map[string][]any, len(l.inputArrays))
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
@@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
arrays[arrayKey] = a.([]any)
|
||||
}
|
||||
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
var (
|
||||
existingCState *nodes.NestedWorkflowState
|
||||
@@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
)
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
|
||||
var e error
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
@@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
for k := range existingCState.IntermediateVars {
|
||||
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
|
||||
}
|
||||
intermediateVars[BreakKey] = &hasBreak
|
||||
intermediateVars[_break.BreakKey] = &hasBreak
|
||||
} else {
|
||||
output = make(map[string]any, len(l.outputs))
|
||||
for k := range l.outputs {
|
||||
output[k] = make([]any, 0)
|
||||
}
|
||||
|
||||
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
|
||||
for varKey := range l.config.IntermediateVars {
|
||||
intermediateVars = make(map[string]*any, len(l.intermediateVars))
|
||||
for varKey := range l.intermediateVars {
|
||||
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
|
||||
@@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
|
||||
intermediateVars[varKey] = &v
|
||||
}
|
||||
intermediateVars[BreakKey] = &hasBreak
|
||||
intermediateVars[_break.BreakKey] = &hasBreak
|
||||
}
|
||||
|
||||
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)
|
||||
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)
|
||||
|
||||
getIthInput := func(i int) (map[string]any, map[string]any, error) {
|
||||
input := make(map[string]any)
|
||||
@@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
input[k] = v
|
||||
}
|
||||
|
||||
input[string(l.config.LoopNodeKey)+"#index"] = int64(i)
|
||||
input[string(l.nodeKey)+"#index"] = int64(i)
|
||||
|
||||
items := make(map[string]any)
|
||||
for arrayKey := range arrays {
|
||||
ele := arrays[arrayKey][i]
|
||||
items[arrayKey] = ele
|
||||
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey
|
||||
currentKey := string(l.nodeKey) + "#" + arrayKey
|
||||
|
||||
// Recursively expand map[string]any elements
|
||||
var expand func(prefix string, val interface{})
|
||||
@@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
}
|
||||
}
|
||||
|
||||
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
|
||||
taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
|
||||
if err != nil {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
@@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
|
||||
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: l.config.LoopNodeKey,
|
||||
NodeKey: l.nodeKey,
|
||||
NodeType: entity.NodeTypeLoop,
|
||||
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
|
||||
}
|
||||
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
|
||||
return setter.SetInterruptEvent(l.nodeKey, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("save interruptEvent in state within loop: ", iEvent)
|
||||
fmt.Println("save composite info in state within loop: ", compState)
|
||||
|
||||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
|
||||
return setter.SaveNestedWorkflowState(l.nodeKey, compState)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
}
|
||||
|
||||
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
|
||||
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
|
||||
panic("impossible")
|
||||
panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
|
||||
}
|
||||
|
||||
for outputVarKey, intermediateVarKey := range l.outputVars {
|
||||
@@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
||||
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
|
||||
maxIter := math.MaxInt
|
||||
|
||||
switch l.config.LoopType {
|
||||
switch l.loopType {
|
||||
case ByArray:
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
@@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) {
|
||||
maxIter = int(iter.(int64))
|
||||
case Infinite:
|
||||
default:
|
||||
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
|
||||
return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
|
||||
}
|
||||
|
||||
return maxIter, nil
|
||||
@@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any {
|
||||
}
|
||||
|
||||
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(l.config.InputArrays))
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
trimmed := make(map[string]any, len(l.inputArrays))
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type NestedWorkflowOptions struct {
|
||||
optsForNested []compose.Option
|
||||
toResumeIndexes map[int]compose.StateModifier
|
||||
optsForIndexed map[int][]compose.Option
|
||||
}
|
||||
|
||||
type NestedWorkflowOption func(*NestedWorkflowOptions)
|
||||
|
||||
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
o.optsForNested = append(o.optsForNested, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
|
||||
return c.optsForNested
|
||||
}
|
||||
|
||||
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.toResumeIndexes == nil {
|
||||
o.toResumeIndexes = map[int]compose.StateModifier{}
|
||||
}
|
||||
|
||||
o.toResumeIndexes[i] = m
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
|
||||
return c.toResumeIndexes
|
||||
}
|
||||
|
||||
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.optsForIndexed == nil {
|
||||
o.optsForIndexed = map[int][]compose.Option{}
|
||||
}
|
||||
o.optsForIndexed[index] = opts
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
|
||||
if c.optsForIndexed == nil {
|
||||
return nil
|
||||
}
|
||||
return c.optsForIndexed[index]
|
||||
}
|
||||
|
||||
type NestedWorkflowState struct {
|
||||
Index2Done map[int]bool `json:"index_2_done,omitempty"`
|
||||
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
|
||||
FullOutput map[string]any `json:"full_output,omitempty"`
|
||||
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowState) String() string {
|
||||
s, _ := sonic.MarshalIndent(c, "", " ")
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type NestedWorkflowAware interface {
|
||||
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
|
||||
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
|
||||
InterruptEventStore
|
||||
}
|
||||
194
backend/domain/workflow/internal/nodes/node.go
Normal file
194
backend/domain/workflow/internal/nodes/node.go
Normal file
@@ -0,0 +1,194 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
einoschema "github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
// InvokableNode is a basic workflow node that can Invoke.
|
||||
// Invoke accepts non-streaming input and returns non-streaming output.
|
||||
// It does not accept any options.
|
||||
// Most nodes implement this, such as NodeTypePlugin.
|
||||
type InvokableNode interface {
|
||||
Invoke(ctx context.Context, input map[string]any) (
|
||||
output map[string]any, err error)
|
||||
}
|
||||
|
||||
// InvokableNodeWOpt is a workflow node that can Invoke.
|
||||
// Invoke accepts non-streaming input and returns non-streaming output.
|
||||
// It can accept NodeOption.
|
||||
// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this.
|
||||
type InvokableNodeWOpt interface {
|
||||
Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// StreamableNode is a workflow node that can Stream.
|
||||
// Stream accepts non-streaming input and returns streaming output.
|
||||
// It does not accept and options
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be streamable plugin for NodeTypePlugin.
|
||||
type StreamableNode interface {
|
||||
Stream(ctx context.Context, in map[string]any) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// StreamableNodeWOpt is a workflow node that can Stream.
|
||||
// Stream accepts non-streaming input and returns streaming output.
|
||||
// It can accept NodeOption.
|
||||
// e.g. NodeTypeLLM implement this.
|
||||
type StreamableNodeWOpt interface {
|
||||
Stream(ctx context.Context, in map[string]any, opts ...NodeOption) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// CollectableNode is a workflow node that can Collect.
|
||||
// Collect accepts streaming input and returns non-streaming output.
|
||||
// It does not accept and options
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be a new condition node that makes decisions
|
||||
// based on streaming input.
|
||||
type CollectableNode interface {
|
||||
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// CollectableNodeWOpt is a workflow node that can Collect.
|
||||
// Collect accepts streaming input and returns non-streaming output.
|
||||
// It accepts NodeOption.
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be a new batch node that accepts streaming input,
|
||||
// process them, and finally returns non-stream aggregation of results.
|
||||
type CollectableNodeWOpt interface {
|
||||
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// TransformableNode is a workflow node that can Transform.
|
||||
// Transform accepts streaming input and returns streaming output.
|
||||
// It does not accept and options
|
||||
// e.g.
|
||||
// NodeTypeVariableAggregator implements TransformableNode.
|
||||
type TransformableNode interface {
|
||||
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// TransformableNodeWOpt is a workflow node that can Transform.
|
||||
// Transform accepts streaming input and returns streaming output.
|
||||
// It accepts NodeOption.
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be an audio processing node that
|
||||
// transforms input audio clips, but within the node is a graph
|
||||
// composed by Eino, and the audio processing node needs to carry
|
||||
// options for this inner graph.
|
||||
type TransformableNodeWOpt interface {
|
||||
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// CallbackInputConverted converts node input to a form better suited for UI.
|
||||
// The converted input will be displayed on canvas when test run,
|
||||
// and will be returned when querying the node's input through OpenAPI.
|
||||
type CallbackInputConverted interface {
|
||||
ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error)
|
||||
}
|
||||
|
||||
// CallbackOutputConverted converts node input to a form better suited for UI.
|
||||
// The converted output will be displayed on canvas when test run,
|
||||
// and will be returned when querying the node's output through OpenAPI.
|
||||
type CallbackOutputConverted interface {
|
||||
ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error)
|
||||
}
|
||||
|
||||
type Initializer interface {
|
||||
Init(ctx context.Context) (context.Context, error)
|
||||
}
|
||||
|
||||
type AdaptOptions struct {
|
||||
Canvas *vo.Canvas
|
||||
}
|
||||
|
||||
type AdaptOption func(*AdaptOptions)
|
||||
|
||||
func WithCanvas(canvas *vo.Canvas) AdaptOption {
|
||||
return func(opts *AdaptOptions) {
|
||||
opts.Canvas = canvas
|
||||
}
|
||||
}
|
||||
|
||||
func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions {
|
||||
options := &AdaptOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
// NodeAdaptor provides conversion from frontend Node to backend NodeSchema.
|
||||
type NodeAdaptor interface {
|
||||
Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) (
|
||||
*schema.NodeSchema, error)
|
||||
}
|
||||
|
||||
// BranchAdaptor provides validation and conversion from frontend port to backend port.
|
||||
type BranchAdaptor interface {
|
||||
ExpectPorts(ctx context.Context, n *vo.Node) []string
|
||||
}
|
||||
|
||||
var (
|
||||
nodeAdaptors = map[entity.NodeType]func() NodeAdaptor{}
|
||||
branchAdaptors = map[entity.NodeType]func() BranchAdaptor{}
|
||||
)
|
||||
|
||||
func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) {
|
||||
nodeAdaptors[et] = f
|
||||
}
|
||||
|
||||
func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) {
|
||||
na, ok := nodeAdaptors[et]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("node type %s not registered", et))
|
||||
}
|
||||
return na(), ok
|
||||
}
|
||||
|
||||
func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) {
|
||||
branchAdaptors[et] = f
|
||||
}
|
||||
|
||||
func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) {
|
||||
na, ok := branchAdaptors[et]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return na(), ok
|
||||
}
|
||||
|
||||
type StreamGenerator interface {
|
||||
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
|
||||
}
|
||||
170
backend/domain/workflow/internal/nodes/option.go
Normal file
170
backend/domain/workflow/internal/nodes/option.go
Normal file
@@ -0,0 +1,170 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type NodeOptions struct {
|
||||
Nested *NestedWorkflowOptions
|
||||
}
|
||||
|
||||
type NestedWorkflowOptions struct {
|
||||
optsForNested []compose.Option
|
||||
toResumeIndexes map[int]compose.StateModifier
|
||||
optsForIndexed map[int][]compose.Option
|
||||
}
|
||||
|
||||
type NodeOption struct {
|
||||
apply func(opts *NodeOptions)
|
||||
|
||||
implSpecificOptFn any
|
||||
}
|
||||
|
||||
type NestedWorkflowOption func(*NestedWorkflowOptions)
|
||||
|
||||
func WithOptsForNested(opts ...compose.Option) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
options.Nested.optsForNested = append(options.Nested.optsForNested, opts...)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetOptsForNested() []compose.Option {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.optsForNested
|
||||
}
|
||||
|
||||
func WithResumeIndex(i int, m compose.StateModifier) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
if options.Nested.toResumeIndexes == nil {
|
||||
options.Nested.toResumeIndexes = map[int]compose.StateModifier{}
|
||||
}
|
||||
|
||||
options.Nested.toResumeIndexes[i] = m
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.toResumeIndexes
|
||||
}
|
||||
|
||||
func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
if options.Nested.optsForIndexed == nil {
|
||||
options.Nested.optsForIndexed = map[int][]compose.Option{}
|
||||
}
|
||||
options.Nested.optsForIndexed[index] = opts
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.optsForIndexed[index]
|
||||
}
|
||||
|
||||
// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
|
||||
func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption {
|
||||
return NodeOption{
|
||||
implSpecificOptFn: optFn,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values.
|
||||
func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions {
|
||||
if base == nil {
|
||||
base = &NodeOptions{}
|
||||
}
|
||||
|
||||
for i := range opts {
|
||||
opt := opts[i]
|
||||
if opt.apply != nil {
|
||||
opt.apply(base)
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values.
|
||||
// e.g.
|
||||
//
|
||||
// myOption := &MyOption{
|
||||
// Field1: "default_value",
|
||||
// }
|
||||
//
|
||||
// myOption := model.GetImplSpecificOptions(myOption, opts...)
|
||||
func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T {
|
||||
if base == nil {
|
||||
base = new(T)
|
||||
}
|
||||
|
||||
for i := range opts {
|
||||
opt := opts[i]
|
||||
if opt.implSpecificOptFn != nil {
|
||||
optFn, ok := opt.implSpecificOptFn.(func(*T))
|
||||
if ok {
|
||||
optFn(base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
type NestedWorkflowState struct {
|
||||
Index2Done map[int]bool `json:"index_2_done,omitempty"`
|
||||
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
|
||||
FullOutput map[string]any `json:"full_output,omitempty"`
|
||||
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowState) String() string {
|
||||
s, _ := sonic.MarshalIndent(c, "", " ")
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type NestedWorkflowAware interface {
|
||||
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
|
||||
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
|
||||
InterruptEventStore
|
||||
}
|
||||
@@ -18,16 +18,21 @@ package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
@@ -35,29 +40,76 @@ type Config struct {
|
||||
PluginID int64
|
||||
ToolID int64
|
||||
PluginVersion string
|
||||
}
|
||||
|
||||
PluginService plugin.Service
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypePlugin,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
|
||||
return e.Name, e
|
||||
})
|
||||
|
||||
ps, ok := apiParams["pluginID"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin id param is not found")
|
||||
}
|
||||
|
||||
pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
|
||||
|
||||
c.PluginID = pID
|
||||
|
||||
ps, ok = apiParams["apiID"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin id param is not found")
|
||||
}
|
||||
|
||||
tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.ToolID = tID
|
||||
|
||||
ps, ok = apiParams["pluginVersion"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin version param is not found")
|
||||
}
|
||||
version := ps.Input.Value.Content.(string)
|
||||
|
||||
c.PluginVersion = version
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Plugin{
|
||||
pluginID: c.PluginID,
|
||||
toolID: c.ToolID,
|
||||
pluginVersion: c.PluginVersion,
|
||||
pluginService: plugin.GetPluginService(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
config *Config
|
||||
}
|
||||
pluginID int64
|
||||
toolID int64
|
||||
pluginVersion string
|
||||
|
||||
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
if cfg.PluginID == 0 {
|
||||
return nil, errors.New("plugin id is required")
|
||||
}
|
||||
if cfg.ToolID == 0 {
|
||||
return nil, errors.New("tool id is required")
|
||||
}
|
||||
if cfg.PluginService == nil {
|
||||
return nil, errors.New("tool service is required")
|
||||
}
|
||||
|
||||
return &Plugin{config: cfg}, nil
|
||||
pluginService plugin.Service
|
||||
}
|
||||
|
||||
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
|
||||
@@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map
|
||||
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
|
||||
exeCfg = ctxExeCfg.ExeCfg
|
||||
}
|
||||
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
|
||||
PluginID: p.config.PluginID,
|
||||
PluginVersion: ptr.Of(p.config.PluginVersion),
|
||||
}, p.config.ToolID, exeCfg)
|
||||
result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
|
||||
PluginID: p.pluginID,
|
||||
PluginVersion: ptr.Of(p.pluginVersion),
|
||||
}, p.toolID, exeCfg)
|
||||
if err != nil {
|
||||
if extra, ok := compose.IsInterruptRerunError(err); ok {
|
||||
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now
|
||||
|
||||
@@ -29,9 +29,12 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
@@ -39,8 +42,21 @@ import (
|
||||
)
|
||||
|
||||
type QuestionAnswer struct {
|
||||
config *Config
|
||||
model model.BaseChatModel
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
|
||||
questionTpl string
|
||||
answerType AnswerType
|
||||
|
||||
choiceType ChoiceType
|
||||
fixedChoices []string
|
||||
|
||||
needExtractFromAnswer bool
|
||||
additionalSystemPromptTpl string
|
||||
maxAnswerCount int
|
||||
|
||||
nodeKey vo.NodeKey
|
||||
outputFields map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -51,15 +67,249 @@ type Config struct {
|
||||
FixedChoices []string
|
||||
|
||||
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
|
||||
Model model.BaseChatModel
|
||||
LLMParams *crossmodel.LLMParams
|
||||
|
||||
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
|
||||
ExtractFromAnswer bool
|
||||
AdditionalSystemPromptTpl string
|
||||
MaxAnswerCount int
|
||||
OutputFields map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
NodeKey vo.NodeKey
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
qaConf := n.Data.Inputs.QA
|
||||
if qaConf == nil {
|
||||
return nil, fmt.Errorf("qa config is nil")
|
||||
}
|
||||
c.QuestionTpl = qaConf.Question
|
||||
|
||||
var llmParams *crossmodel.LLMParams
|
||||
if n.Data.Inputs.LLMParam != nil {
|
||||
llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var qaLLMParams vo.SimpleLLMParam
|
||||
err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
llmParams, err = convertLLMParams(qaLLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.LLMParams = llmParams
|
||||
}
|
||||
|
||||
answerType, err := convertAnswerType(qaConf.AnswerType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.AnswerType = answerType
|
||||
|
||||
var choiceType ChoiceType
|
||||
if len(qaConf.OptionType) > 0 {
|
||||
choiceType, err = convertChoiceType(qaConf.OptionType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ChoiceType = choiceType
|
||||
}
|
||||
|
||||
if answerType == AnswerByChoices {
|
||||
switch choiceType {
|
||||
case FixedChoices:
|
||||
var options []string
|
||||
for _, option := range qaConf.Options {
|
||||
options = append(options, option.Name)
|
||||
}
|
||||
c.FixedChoices = options
|
||||
case DynamicChoices:
|
||||
inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(inputSources...)
|
||||
|
||||
inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(DynamicChoicesKey, inputTypes)
|
||||
default:
|
||||
return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
|
||||
}
|
||||
} else if answerType == AnswerDirectly {
|
||||
c.ExtractFromAnswer = qaConf.ExtractOutput
|
||||
if qaConf.ExtractOutput {
|
||||
if llmParams == nil {
|
||||
return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
|
||||
}
|
||||
c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
|
||||
c.MaxAnswerCount = qaConf.Limit
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
p.ModelName = params.ModelName
|
||||
p.ModelType = params.ModelType
|
||||
p.Temperature = ¶ms.Temperature
|
||||
p.MaxTokens = params.MaxTokens
|
||||
p.TopP = ¶ms.TopP
|
||||
p.ResponseFormat = params.ResponseFormat
|
||||
p.SystemPrompt = params.SystemPrompt
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
|
||||
switch t {
|
||||
case vo.QAAnswerTypeOption:
|
||||
return AnswerByChoices, nil
|
||||
case vo.QAAnswerTypeText:
|
||||
return AnswerDirectly, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid QAAnswerType: %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
|
||||
switch t {
|
||||
case vo.QAOptionTypeStatic:
|
||||
return FixedChoices, nil
|
||||
case vo.QAOptionTypeDynamic:
|
||||
return DynamicChoices, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid QAOptionType: %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if c.AnswerType == AnswerDirectly {
|
||||
if c.ExtractFromAnswer {
|
||||
if c.LLMParams == nil {
|
||||
return nil, errors.New("model is required when extract from answer")
|
||||
}
|
||||
if len(ns.OutputTypes) == 0 {
|
||||
return nil, errors.New("output fields is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else if c.AnswerType == AnswerByChoices {
|
||||
if c.ChoiceType == FixedChoices {
|
||||
if len(c.FixedChoices) == 0 {
|
||||
return nil, errors.New("fixed choices is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
|
||||
}
|
||||
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for question answer")
|
||||
}
|
||||
|
||||
var (
|
||||
m model.BaseChatModel
|
||||
err error
|
||||
)
|
||||
if c.LLMParams != nil {
|
||||
m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &QuestionAnswer{
|
||||
model: m,
|
||||
nodeMeta: *nodeMeta,
|
||||
questionTpl: c.QuestionTpl,
|
||||
answerType: c.AnswerType,
|
||||
choiceType: c.ChoiceType,
|
||||
fixedChoices: c.FixedChoices,
|
||||
needExtractFromAnswer: c.ExtractFromAnswer,
|
||||
additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
|
||||
maxAnswerCount: c.MaxAnswerCount,
|
||||
nodeKey: ns.Key,
|
||||
outputFields: ns.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
if c.AnswerType != AnswerByChoices {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
optionID, ok := nodeOutput[OptionIDKey]
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
|
||||
}
|
||||
|
||||
if c.ChoiceType == DynamicChoices {
|
||||
if optionID.(string) == "other" {
|
||||
return -1, true, nil
|
||||
} else {
|
||||
return 0, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if optionID.(string) == "other" {
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
optionIDInt, ok := AlphabetToInt(optionID.(string))
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID)
|
||||
}
|
||||
|
||||
return optionIDInt, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
|
||||
if n.Data.Inputs.QA.AnswerType != vo.QAAnswerTypeOption {
|
||||
return expects
|
||||
}
|
||||
|
||||
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
|
||||
for index := range n.Data.Inputs.QA.Options {
|
||||
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
|
||||
}
|
||||
|
||||
expects = append(expects, schema2.PortDefault)
|
||||
return expects
|
||||
}
|
||||
|
||||
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
|
||||
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
|
||||
expects = append(expects, schema2.PortDefault)
|
||||
}
|
||||
|
||||
return expects
|
||||
}
|
||||
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type AnswerType string
|
||||
@@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
|
||||
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
|
||||
)
|
||||
|
||||
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if conf.AnswerType == AnswerDirectly {
|
||||
if conf.ExtractFromAnswer {
|
||||
if conf.Model == nil {
|
||||
return nil, errors.New("model is required when extract from answer")
|
||||
}
|
||||
if len(conf.OutputFields) == 0 {
|
||||
return nil, errors.New("output fields is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else if conf.AnswerType == AnswerByChoices {
|
||||
if conf.ChoiceType == FixedChoices {
|
||||
if len(conf.FixedChoices) == 0 {
|
||||
return nil, errors.New("fixed choices is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
|
||||
}
|
||||
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for question answer")
|
||||
}
|
||||
|
||||
return &QuestionAnswer{
|
||||
config: conf,
|
||||
nodeMeta: *nodeMeta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Question struct {
|
||||
Question string
|
||||
Choices []string
|
||||
@@ -182,10 +397,10 @@ type message struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
|
||||
// Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
|
||||
// input: the references by input fields, as well as the dynamic choices array if needed.
|
||||
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
|
||||
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
var (
|
||||
questions []*Question
|
||||
answers []string
|
||||
@@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
||||
out[QuestionsKey] = questions
|
||||
out[AnswersKey] = answers
|
||||
|
||||
switch q.config.AnswerType {
|
||||
switch q.answerType {
|
||||
case AnswerDirectly:
|
||||
if isFirst { // first execution, ask the question
|
||||
// format the question. Which is common to all use cases
|
||||
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
|
||||
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
||||
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
|
||||
}
|
||||
|
||||
if q.config.ExtractFromAnswer {
|
||||
if q.needExtractFromAnswer {
|
||||
return q.extractFromAnswer(ctx, in, questions, answers)
|
||||
}
|
||||
|
||||
@@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
||||
}
|
||||
|
||||
// format the question. Which is common to all use cases
|
||||
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
|
||||
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var formattedChoices []string
|
||||
switch q.config.ChoiceType {
|
||||
switch q.choiceType {
|
||||
case FixedChoices:
|
||||
for _, choice := range q.config.FixedChoices {
|
||||
for _, choice := range q.fixedChoices {
|
||||
formattedChoice, err := nodes.TemplateRender(choice, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
||||
formattedChoices = append(formattedChoices, c)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
|
||||
return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
|
||||
}
|
||||
|
||||
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
|
||||
return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
|
||||
fieldInfo := "FieldInfo"
|
||||
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
|
||||
s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
||||
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
|
||||
|
||||
var requiredFields []string
|
||||
for fName, tInfo := range q.config.OutputFields {
|
||||
for fName, tInfo := range q.outputFields {
|
||||
if tInfo.Required {
|
||||
requiredFields = append(requiredFields, fName)
|
||||
}
|
||||
}
|
||||
|
||||
var formattedAdditionalPrompt string
|
||||
if len(q.config.AdditionalSystemPromptTpl) > 0 {
|
||||
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
|
||||
if len(q.additionalSystemPromptTpl) > 0 {
|
||||
additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
||||
messages = append(messages, schema.UserMessage(answer))
|
||||
}
|
||||
|
||||
out, err := q.config.Model.Generate(ctx, messages)
|
||||
out, err := q.model.Generate(ctx, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
||||
if ok {
|
||||
nextQuestionStr, ok := nextQuestion.(string)
|
||||
if ok && len(nextQuestionStr) > 0 {
|
||||
if len(answers) >= q.config.MaxAnswerCount {
|
||||
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
|
||||
if len(answers) >= q.maxAnswerCount {
|
||||
return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
|
||||
}
|
||||
|
||||
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
|
||||
@@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
||||
return nil, fmt.Errorf("field %s not found", fieldInfo)
|
||||
}
|
||||
|
||||
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
|
||||
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
|
||||
schema.UserMessage(answer),
|
||||
}
|
||||
|
||||
out, err := q.config.Model.Generate(ctx, messages)
|
||||
out, err := q.model.Generate(ctx, messages)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
@@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
|
||||
|
||||
event := &entity.InterruptEvent{
|
||||
ID: eventID,
|
||||
NodeKey: q.config.NodeKey,
|
||||
NodeKey: q.nodeKey,
|
||||
NodeType: entity.NodeTypeQuestionAnswer,
|
||||
NodeTitle: q.nodeMeta.Name,
|
||||
NodeIcon: q.nodeMeta.IconURL,
|
||||
@@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
|
||||
setter.AddQuestion(q.config.NodeKey, &Question{
|
||||
setter.AddQuestion(q.nodeKey, &Question{
|
||||
Question: newQuestion,
|
||||
Choices: choices,
|
||||
})
|
||||
@@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func AlphabetToInt(str string) (int, bool) {
|
||||
func AlphabetToInt(str string) (int64, bool) {
|
||||
if len(str) != 1 {
|
||||
return 0, false
|
||||
}
|
||||
char := rune(str[0])
|
||||
char = unicode.ToUpper(char)
|
||||
if char >= 'A' && char <= 'Z' {
|
||||
return int(char - 'A'), true
|
||||
return int64(char - 'A'), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
@@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
||||
for i := 0; i < len(oldQuestions); i++ {
|
||||
oldQuestion := oldQuestions[i]
|
||||
oldAnswer := oldAnswers[i]
|
||||
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
|
||||
contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
|
||||
questionMsg := &message{
|
||||
Type: "question",
|
||||
ContentType: contentType,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
|
||||
}
|
||||
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
if q.answerType == AnswerByChoices {
|
||||
questionMsg.Content = optionContent{
|
||||
Options: conv(oldQuestion.Choices),
|
||||
Question: oldQuestion.Question,
|
||||
@@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
||||
Type: "answer",
|
||||
ContentType: contentType,
|
||||
Content: oldAnswer,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
|
||||
}
|
||||
|
||||
history = append(history, questionMsg, answerMsg)
|
||||
}
|
||||
|
||||
if newQuestion != nil {
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
if q.answerType == AnswerByChoices {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "option",
|
||||
@@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
||||
Options: conv(choices),
|
||||
Question: *newQuestion,
|
||||
},
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
} else {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "text",
|
||||
Content: *newQuestion,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,8 +27,10 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@@ -37,19 +39,27 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
OutputTypes map[string]*vo.TypeInfo
|
||||
NodeKey vo.NodeKey
|
||||
OutputSchema string
|
||||
}
|
||||
|
||||
type InputReceiver struct {
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
interruptData string
|
||||
nodeKey vo.NodeKey
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
c.OutputSchema = n.Data.Inputs.OutputSchema
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeInputReceiver,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for input receiver")
|
||||
@@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
||||
|
||||
interruptData := map[string]string{
|
||||
"content_type": "form_schema",
|
||||
"content": cfg.OutputSchema,
|
||||
"content": c.OutputSchema,
|
||||
}
|
||||
|
||||
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
|
||||
@@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
||||
}
|
||||
|
||||
return &InputReceiver{
|
||||
outputTypes: cfg.OutputTypes,
|
||||
outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution
|
||||
nodeMeta: *nodeMeta,
|
||||
nodeKey: cfg.NodeKey,
|
||||
nodeKey: ns.Key,
|
||||
interruptData: interruptDataStr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type InputReceiver struct {
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
interruptData string
|
||||
nodeKey vo.NodeKey
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
}
|
||||
|
||||
const (
|
||||
ReceivedDataKey = "$received_data"
|
||||
receiverWarningKey = "receiver_warning_%d_%s"
|
||||
|
||||
190
backend/domain/workflow/internal/nodes/selector/callbacks.go
Normal file
190
backend/domain/workflow/internal/nodes/selector/callbacks.go
Normal file
@@ -0,0 +1,190 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type selectorCallbackField struct {
|
||||
Key string `json:"key"`
|
||||
Type vo.DataType `json:"type"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type selectorCondition struct {
|
||||
Left selectorCallbackField `json:"left"`
|
||||
Operator vo.OperatorType `json:"operator"`
|
||||
Right *selectorCallbackField `json:"right,omitempty"`
|
||||
}
|
||||
|
||||
type selectorBranch struct {
|
||||
Conditions []*selectorCondition `json:"conditions"`
|
||||
Logic vo.LogicType `json:"logic"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
count := len(s.clauses)
|
||||
|
||||
output := make([]*selectorBranch, count)
|
||||
|
||||
for _, source := range s.ns.InputSources {
|
||||
targetPath := source.Path
|
||||
if len(targetPath) == 2 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: []*selectorCondition{
|
||||
{
|
||||
Operator: s.clauses[index].Single.ToCanvasOperatorType(),
|
||||
},
|
||||
},
|
||||
Logic: ClauseRelationAND.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[1] == LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
|
||||
}
|
||||
parentNode := s.ws.GetNode(parentNodeKey)
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[1] == RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[0].Right = &selectorCallbackField{
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
} else if len(targetPath) == 3 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multi := s.clauses[index].Multi
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: make([]*selectorCondition, len(multi.Clauses)),
|
||||
Logic: multi.Relation.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
clauseIndexStr := targetPath[1]
|
||||
clauseIndex, err := strconv.Atoi(clauseIndexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clause := multi.Clauses[clauseIndex]
|
||||
|
||||
if output[index].Conditions[clauseIndex] == nil {
|
||||
output[index].Conditions[clauseIndex] = &selectorCondition{
|
||||
Operator: clause.ToCanvasOperatorType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[2] == LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
|
||||
}
|
||||
parentNode := s.ws.GetNode(parentNodeKey)
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[2] == RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"branches": output}, nil
|
||||
}
|
||||
@@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
|
||||
panic(fmt.Sprintf("unknown operator: %+v", o))
|
||||
}
|
||||
}
|
||||
|
||||
func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) {
|
||||
switch o {
|
||||
case vo.Equal:
|
||||
return OperatorEqual, nil
|
||||
case vo.NotEqual:
|
||||
return OperatorNotEqual, nil
|
||||
case vo.LengthGreaterThan:
|
||||
return OperatorLengthGreater, nil
|
||||
case vo.LengthGreaterThanEqual:
|
||||
return OperatorLengthGreaterOrEqual, nil
|
||||
case vo.LengthLessThan:
|
||||
return OperatorLengthLesser, nil
|
||||
case vo.LengthLessThanEqual:
|
||||
return OperatorLengthLesserOrEqual, nil
|
||||
case vo.Contain:
|
||||
if leftType.Type == vo.DataTypeObject {
|
||||
return OperatorContainKey, nil
|
||||
}
|
||||
return OperatorContain, nil
|
||||
case vo.NotContain:
|
||||
if leftType.Type == vo.DataTypeObject {
|
||||
return OperatorNotContainKey, nil
|
||||
}
|
||||
return OperatorNotContain, nil
|
||||
case vo.Empty:
|
||||
return OperatorEmpty, nil
|
||||
case vo.NotEmpty:
|
||||
return OperatorNotEmpty, nil
|
||||
case vo.True:
|
||||
return OperatorIsTrue, nil
|
||||
case vo.False:
|
||||
return OperatorIsFalse, nil
|
||||
case vo.GreaterThan:
|
||||
return OperatorGreater, nil
|
||||
case vo.GreaterThanEqual:
|
||||
return OperatorGreaterOrEqual, nil
|
||||
case vo.LessThan:
|
||||
return OperatorLesser, nil
|
||||
case vo.LessThanEqual:
|
||||
return OperatorLesserOrEqual, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported operator type: %d", o)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,16 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type ClauseRelation string
|
||||
@@ -29,10 +36,6 @@ const (
|
||||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Clauses []*OneClauseSchema `json:"clauses"`
|
||||
}
|
||||
|
||||
type OneClauseSchema struct {
|
||||
Single *Operator `json:"single,omitempty"`
|
||||
Multi *MultiClauseSchema `json:"multi,omitempty"`
|
||||
@@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType {
|
||||
|
||||
panic(fmt.Sprintf("unknown clause relation: %s", c))
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
clauses := make([]*OneClauseSchema, 0)
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Name: n.Data.Meta.Title,
|
||||
Type: entity.NodeTypeSelector,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
for i, branchCond := range n.Data.Inputs.Branches {
|
||||
inputType := &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{},
|
||||
}
|
||||
|
||||
if len(branchCond.Condition.Conditions) == 1 { // single condition
|
||||
cond := branchCond.Condition.Conditions[0]
|
||||
|
||||
left := cond.Left
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("operator left is nil")
|
||||
}
|
||||
|
||||
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[LeftKey] = leftType
|
||||
|
||||
ns.AddInputSource(leftSources...)
|
||||
|
||||
op, err := ToSelectorOperator(cond.Operator, leftType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cond.Right != nil {
|
||||
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[RightKey] = rightType
|
||||
ns.AddInputSource(rightSources...)
|
||||
}
|
||||
|
||||
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
|
||||
|
||||
clauses = append(clauses, &OneClauseSchema{
|
||||
Single: &op,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var relation ClauseRelation
|
||||
logic := branchCond.Condition.Logic
|
||||
if logic == vo.OR {
|
||||
relation = ClauseRelationOR
|
||||
} else if logic == vo.AND {
|
||||
relation = ClauseRelationAND
|
||||
}
|
||||
|
||||
var ops []*Operator
|
||||
for j, cond := range branchCond.Condition.Conditions {
|
||||
left := cond.Left
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("operator left is nil")
|
||||
}
|
||||
|
||||
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
LeftKey: leftType,
|
||||
},
|
||||
}
|
||||
|
||||
ns.AddInputSource(leftSources...)
|
||||
|
||||
op, err := ToSelectorOperator(cond.Operator, leftType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ops = append(ops, &op)
|
||||
|
||||
if cond.Right != nil {
|
||||
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType
|
||||
ns.AddInputSource(rightSources...)
|
||||
}
|
||||
}
|
||||
|
||||
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
|
||||
|
||||
clauses = append(clauses, &OneClauseSchema{
|
||||
Multi: &MultiClauseSchema{
|
||||
Clauses: ops,
|
||||
Relation: relation,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.Clauses = clauses
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
@@ -23,23 +23,32 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Selector struct {
|
||||
config *Config
|
||||
clauses []*OneClauseSchema
|
||||
ns *schema.NodeSchema
|
||||
ws *schema.WorkflowSchema
|
||||
}
|
||||
|
||||
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
type Config struct {
|
||||
Clauses []*OneClauseSchema `json:"clauses"`
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
ws := schema.GetBuildOptions(opts...).WS
|
||||
if ws == nil {
|
||||
return nil, fmt.Errorf("workflow schema is required")
|
||||
}
|
||||
|
||||
if len(config.Clauses) == 0 {
|
||||
if len(c.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("config clauses are empty")
|
||||
}
|
||||
|
||||
for _, clause := range config.Clauses {
|
||||
for _, clause := range c.Clauses {
|
||||
if clause.Single == nil && clause.Multi == nil {
|
||||
return nil, fmt.Errorf("single clause and multi clause are both nil")
|
||||
}
|
||||
@@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
||||
}
|
||||
|
||||
return &Selector{
|
||||
config: config,
|
||||
clauses: c.Clauses,
|
||||
ns: ns,
|
||||
ws: ws,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
choice := nodeOutput[SelectKey].(int64)
|
||||
if choice < 0 || choice > int64(len(c.Clauses)+1) {
|
||||
return -1, false, fmt.Errorf("selector choice out of range: %d", choice)
|
||||
}
|
||||
|
||||
if choice == int64(len(c.Clauses)) { // default
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
return choice, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string {
|
||||
expects := make([]string, len(n.Data.Inputs.Branches)+1)
|
||||
expects[0] = "false" // default branch
|
||||
if len(n.Data.Inputs.Branches) > 0 {
|
||||
expects[1] = "true" // first condition
|
||||
}
|
||||
|
||||
for i := 1; i < len(n.Data.Inputs.Branches); i++ { // other conditions
|
||||
expects[i+1] = "true_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
return expects
|
||||
}
|
||||
|
||||
type Operants struct {
|
||||
Left any
|
||||
Right any
|
||||
@@ -76,14 +117,14 @@ const (
|
||||
SelectKey = "selected"
|
||||
)
|
||||
|
||||
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.SelectorInputConverter(input)
|
||||
func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.selectorInputConverter(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
predicates := make([]Predicate, 0, len(s.config.Clauses))
|
||||
for i, oneConf := range s.config.Clauses {
|
||||
predicates := make([]Predicate, 0, len(s.clauses))
|
||||
for i, oneConf := range s.clauses {
|
||||
if oneConf.Single != nil {
|
||||
left := in[i].Left
|
||||
right := in[i].Right
|
||||
@@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri
|
||||
}
|
||||
|
||||
if isTrue {
|
||||
return map[string]any{SelectKey: i}, nil
|
||||
return map[string]any{SelectKey: int64(i)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{SelectKey: len(in)}, nil // default choice
|
||||
return map[string]any{SelectKey: int64(len(in))}, nil // default choice
|
||||
}
|
||||
|
||||
func (s *Selector) GetType() string {
|
||||
return "Selector"
|
||||
}
|
||||
|
||||
func (s *Selector) ConditionCount() int {
|
||||
return len(s.config.Clauses)
|
||||
}
|
||||
|
||||
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.config.Clauses
|
||||
func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.clauses
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
@@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er
|
||||
}
|
||||
|
||||
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
count := len(s.config.Clauses)
|
||||
out := output[SelectKey].(int)
|
||||
count := int64(len(s.clauses))
|
||||
out := output[SelectKey].(int64)
|
||||
if out == count {
|
||||
cOutput := map[string]any{"result": "pass to else branch"}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
|
||||
@@ -22,57 +22,27 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
var KeyIsFinished = "\x1FKey is finished\x1F"
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
Streaming Mode = "streaming"
|
||||
NonStreaming Mode = "non-streaming"
|
||||
)
|
||||
|
||||
type FieldStreamType string
|
||||
|
||||
const (
|
||||
FieldIsStream FieldStreamType = "yes" // absolutely a stream
|
||||
FieldNotStream FieldStreamType = "no" // absolutely not a stream
|
||||
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
|
||||
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
|
||||
)
|
||||
|
||||
// SourceInfo contains stream type for a input field source of a node.
|
||||
type SourceInfo struct {
|
||||
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
|
||||
IsIntermediate bool
|
||||
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
|
||||
FieldType FieldStreamType
|
||||
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
|
||||
FromNodeKey vo.NodeKey
|
||||
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
|
||||
FromPath compose.FieldPath
|
||||
TypeInfo *vo.TypeInfo
|
||||
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
|
||||
SubSources map[string]*SourceInfo
|
||||
}
|
||||
|
||||
type DynamicStreamContainer interface {
|
||||
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
|
||||
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
|
||||
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
|
||||
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
|
||||
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error)
|
||||
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error)
|
||||
}
|
||||
|
||||
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
|
||||
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
|
||||
resolved := make(map[string]*SourceInfo, len(sources))
|
||||
func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) {
|
||||
resolved := make(map[string]*schema.SourceInfo, len(sources))
|
||||
|
||||
nodeKey2Skipped := make(map[vo.NodeKey]bool)
|
||||
|
||||
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
|
||||
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
|
||||
resolvedNode := &SourceInfo{
|
||||
var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error)
|
||||
resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) {
|
||||
resolvedNode := &schema.SourceInfo{
|
||||
IsIntermediate: sInfo.IsIntermediate,
|
||||
FieldType: sInfo.FieldType,
|
||||
FromNodeKey: sInfo.FromNodeKey,
|
||||
@@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
||||
}
|
||||
|
||||
if len(sInfo.SubSources) > 0 {
|
||||
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))
|
||||
resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources))
|
||||
|
||||
for k, subInfo := range sInfo.SubSources {
|
||||
resolvedSub, err := resolver(k, subInfo)
|
||||
@@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
||||
}
|
||||
|
||||
if skipped {
|
||||
resolvedNode.FieldType = FieldSkipped
|
||||
resolvedNode.FieldType = schema.FieldSkipped
|
||||
return resolvedNode, nil
|
||||
}
|
||||
|
||||
if sInfo.FieldType == FieldMaybeStream {
|
||||
if sInfo.FieldType == schema.FieldMaybeStream {
|
||||
if len(sInfo.SubSources) > 0 {
|
||||
panic("a maybe stream field should not have sub sources")
|
||||
}
|
||||
|
||||
var streamType FieldStreamType
|
||||
var streamType schema.FieldStreamType
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
|
||||
var e error
|
||||
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
|
||||
@@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SourceInfo{
|
||||
return &schema.SourceInfo{
|
||||
IsIntermediate: sInfo.IsIntermediate,
|
||||
FieldType: streamType,
|
||||
FromNodeKey: sInfo.FromNodeKey,
|
||||
@@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface {
|
||||
NodeExecuted(key vo.NodeKey) bool
|
||||
}
|
||||
|
||||
func (s *SourceInfo) Skipped() bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FieldType == FieldSkipped
|
||||
func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error) {
|
||||
sg, ok := s.Configs.(StreamGenerator)
|
||||
if !ok {
|
||||
return schema.FieldNotStream, nil
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if !sub.Skipped() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FromNodeKey == nodeKey
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if sub.FromNode(nodeKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return sg.FieldStreamType(path, s, sc)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ package subworkflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
@@ -29,35 +28,56 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Runner compose.Runnable[map[string]any, map[string]any]
|
||||
WorkflowID int64
|
||||
WorkflowVersion string
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
innerWF := ns.SubWorkflowSchema
|
||||
|
||||
if !innerWF.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
innerExit := innerWF.GetNode(entity.ExitNodeKey)
|
||||
|
||||
if innerExit.Configs.(*exit.Config).TerminatePlan == vo.ReturnVariables {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if !innerExit.StreamConfigs.RequireStreamingInput {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) > 1 || path[0] != "output" {
|
||||
return schema2.FieldNotStream, fmt.Errorf(
|
||||
"streaming answering sub-workflow node can only have out field 'output'")
|
||||
}
|
||||
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
type SubWorkflow struct {
|
||||
cfg *Config
|
||||
Runner compose.Runnable[map[string]any, map[string]any]
|
||||
}
|
||||
|
||||
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("runnable is nil")
|
||||
}
|
||||
|
||||
return &SubWorkflow{cfg: cfg}, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
|
||||
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
|
||||
out, err := s.Runner.Invoke(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
@@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
|
||||
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
|
||||
out, err := s.Runner.Stream(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
@@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) {
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
nestedOpts := options.GetOptsForNested()
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/bytedance/sonic/ast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@@ -156,7 +157,7 @@ func removeSlice(s string) string {
|
||||
|
||||
type renderOptions struct {
|
||||
type2CustomRenderer map[reflect.Type]func(any) (string, error)
|
||||
reservedKey map[string]struct{}
|
||||
reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping
|
||||
nilRenderer func() (string, error)
|
||||
}
|
||||
|
||||
@@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
|
||||
func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
|
||||
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
|
||||
return false, false
|
||||
}
|
||||
@@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
|
||||
}
|
||||
|
||||
if !matchingSource.IsIntermediate {
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
for _, subPath := range tp.SubPathsBeforeSlice {
|
||||
@@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
|
||||
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
|
||||
return false, true
|
||||
}
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
matchingSource = subSource
|
||||
}
|
||||
|
||||
if !matchingSource.IsIntermediate {
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
var checkSourceSkipped func(sInfo *SourceInfo) bool
|
||||
checkSourceSkipped = func(sInfo *SourceInfo) bool {
|
||||
var checkSourceSkipped func(sInfo *schema.SourceInfo) bool
|
||||
checkSourceSkipped = func(sInfo *schema.SourceInfo) bool {
|
||||
if !sInfo.IsIntermediate {
|
||||
return sInfo.FieldType == FieldSkipped
|
||||
return sInfo.FieldType == schema.FieldSkipped
|
||||
}
|
||||
for _, subSource := range sInfo.SubSources {
|
||||
if !checkSourceSkipped(subSource) {
|
||||
@@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
|
||||
return currentType
|
||||
}
|
||||
|
||||
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
|
||||
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
|
||||
mi, err := sonic.Marshal(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -22,7 +22,11 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
@@ -34,42 +38,92 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Type Type `json:"type"`
|
||||
Tpl string `json:"tpl"`
|
||||
ConcatChar string `json:"concatChar"`
|
||||
Separators []string `json:"separator"`
|
||||
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
|
||||
Type Type `json:"type"`
|
||||
Tpl string `json:"tpl"`
|
||||
ConcatChar string `json:"concatChar"`
|
||||
Separators []string `json:"separator"`
|
||||
}
|
||||
|
||||
type TextProcessor struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config requried")
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
|
||||
|
||||
if n.Data.Inputs.Method == vo.Concat {
|
||||
c.Type = ConcatText
|
||||
params := n.Data.Inputs.ConcatParams
|
||||
for _, param := range params {
|
||||
if param.Name == "concatResult" {
|
||||
c.Tpl = param.Input.Value.Content.(string)
|
||||
} else if param.Name == "arrayItemConcatChar" {
|
||||
c.ConcatChar = param.Input.Value.Content.(string)
|
||||
}
|
||||
}
|
||||
} else if n.Data.Inputs.Method == vo.Split {
|
||||
c.Type = SplitText
|
||||
params := n.Data.Inputs.SplitParams
|
||||
separators := make([]string, 0, len(params))
|
||||
for _, param := range params {
|
||||
if param.Name == "delimiters" {
|
||||
delimiters := param.Input.Value.Content.([]any)
|
||||
for _, d := range delimiters {
|
||||
separators = append(separators, d.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Separators = separators
|
||||
|
||||
} else {
|
||||
return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method)
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.Type == ConcatText && len(c.Tpl) == 0 {
|
||||
return nil, fmt.Errorf("config tpl requried")
|
||||
}
|
||||
|
||||
return &TextProcessor{
|
||||
config: cfg,
|
||||
typ: c.Type,
|
||||
tpl: c.Tpl,
|
||||
concatChar: c.ConcatChar,
|
||||
separators: c.Separators,
|
||||
fullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TextProcessor struct {
|
||||
typ Type
|
||||
tpl string
|
||||
concatChar string
|
||||
separators []string
|
||||
fullSources map[string]*schema.SourceInfo
|
||||
}
|
||||
|
||||
const OutputKey = "output"
|
||||
|
||||
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
switch t.config.Type {
|
||||
switch t.typ {
|
||||
case ConcatText:
|
||||
arrayRenderer := func(i any) (string, error) {
|
||||
vs := i.([]any)
|
||||
return join(vs, t.config.ConcatChar)
|
||||
return join(vs, t.concatChar)
|
||||
}
|
||||
|
||||
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
|
||||
result, err := nodes.Render(ctx, t.tpl, input, t.fullSources,
|
||||
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
|
||||
}
|
||||
values := strings.Split(valueString, t.config.Separators[0])
|
||||
values := strings.Split(valueString, t.separators[0])
|
||||
// Iterate over each delimiter
|
||||
for _, sep := range t.config.Separators[1:] {
|
||||
for _, sep := range t.separators[1:] {
|
||||
var tempParts []string
|
||||
for _, part := range values {
|
||||
tempParts = append(tempParts, strings.Split(part, sep)...)
|
||||
@@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
|
||||
|
||||
return map[string]any{OutputKey: anyValues}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("not support type %s", t.config.Type)
|
||||
return nil, fmt.Errorf("not support type %s", t.typ)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
||||
@@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
||||
Type: SplitText,
|
||||
Separators: []string{",", "|", "."},
|
||||
}
|
||||
p, err := NewTextProcessor(ctx, cfg)
|
||||
p, err := cfg.Build(ctx, &schema2.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := p.Invoke(ctx, map[string]any{
|
||||
result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{
|
||||
"String": "a,b|c.d,e|f|g",
|
||||
})
|
||||
|
||||
@@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
||||
ConcatChar: `\t`,
|
||||
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
|
||||
}
|
||||
p, err := NewTextProcessor(context.Background(), cfg)
|
||||
p, err := cfg.Build(context.Background(), &schema2.NodeSchema{})
|
||||
|
||||
result, err := p.Invoke(ctx, in)
|
||||
result, err := p.(*TextProcessor).Invoke(ctx, in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
|
||||
})
|
||||
|
||||
@@ -32,8 +32,11 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
@@ -48,24 +51,147 @@ const (
|
||||
type Config struct {
|
||||
MergeStrategy MergeStrategy
|
||||
GroupLen map[string]int
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
NodeKey vo.NodeKey
|
||||
InputSources []*vo.FieldInfo
|
||||
GroupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
config *Config
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
c.MergeStrategy = FirstNotNullValue
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
tInfo := &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: make(map[string]*vo.TypeInfo),
|
||||
}
|
||||
ns.SetInputType(group.Name, tInfo)
|
||||
for ii, v := range group.Variables {
|
||||
name := strconv.Itoa(ii)
|
||||
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[name] = valueTypeInfo
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
length := len(group.Variables)
|
||||
groupToLen[group.Name] = length
|
||||
}
|
||||
|
||||
groupOrder := make([]string, 0, len(groupToLen))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
groupOrder = append(groupOrder, group.Name)
|
||||
}
|
||||
|
||||
c.GroupLen = groupToLen
|
||||
c.GroupOrder = groupOrder
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if c.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
|
||||
}
|
||||
if cfg.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
|
||||
|
||||
return &VariableAggregator{
|
||||
groupLen: c.GroupLen,
|
||||
fullSources: ns.FullSources,
|
||||
nodeKey: ns.Key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
return &VariableAggregator{config: cfg}, nil
|
||||
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return schema2.FieldNotStream, nil // variables or static values
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode := sc.GetNode(fromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
if err != nil {
|
||||
return schema2.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == schema2.FieldMaybeStream {
|
||||
return schema2.FieldMaybeStream, nil
|
||||
} else if subStreamType == schema2.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldMaybeStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
groupLen map[string]int
|
||||
fullSources map[string]*schema2.SourceInfo
|
||||
nodeKey vo.NodeKey
|
||||
groupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
|
||||
@@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
||||
|
||||
result := make(map[string]any)
|
||||
groupToChoice := make(map[string]int)
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
for i := 0; i < length; i++ {
|
||||
if value, ok := in[group][i]; ok {
|
||||
if value != nil {
|
||||
@@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
@@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
_ *schema.StreamReader[map[string]any], err error) {
|
||||
inStream := streamInputConverter(input)
|
||||
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolvesSources from ctx cache.")
|
||||
}
|
||||
@@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
|
||||
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
|
||||
for group, choice := range groupToChoice {
|
||||
if choice != -1 {
|
||||
item := groupToItems[group][choice]
|
||||
if _, ok := item.(stream); ok {
|
||||
groupChoiceToStreamType[group] = nodes.FieldIsStream
|
||||
groupChoiceToStreamType[group] = schema2.FieldIsStream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
@@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
// - if an element is not stream, actually receive from the stream to check if it's non-nil
|
||||
|
||||
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
groupToItems[group] = make([]any, length)
|
||||
groupToCurrentIndex[group] = math.MaxInt
|
||||
for i := 0; i < length; i++ {
|
||||
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
|
||||
if fType == nodes.FieldSkipped {
|
||||
if fType == schema2.FieldSkipped {
|
||||
groupToItems[group][i] = skipped{}
|
||||
continue
|
||||
}
|
||||
if fType == nodes.FieldIsStream {
|
||||
if fType == schema2.FieldIsStream {
|
||||
groupToItems[group][i] = stream{}
|
||||
if ci, _ := groupToCurrentIndex[group]; i < ci {
|
||||
groupToCurrentIndex[group] = i
|
||||
@@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
}
|
||||
|
||||
allDone := func() bool {
|
||||
for group := range v.config.GroupLen {
|
||||
for group := range v.groupLen {
|
||||
_, ok := groupToChoice[group]
|
||||
if !ok {
|
||||
return false
|
||||
@@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
alreadyDone := allDone()
|
||||
if alreadyDone { // all groups have made their choices, no need to actually read input streams
|
||||
result := make(map[string]any, len(v.config.GroupLen))
|
||||
result := make(map[string]any, len(v.groupLen))
|
||||
allSkip := true
|
||||
for group := range groupToChoice {
|
||||
choice := groupToChoice[group]
|
||||
@@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
|
||||
if allSkip { // no need to convert input streams for the output, because all groups are skipped
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
return schema.StreamReaderFromArray([]map[string]any{result}), nil
|
||||
@@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
||||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -416,26 +542,12 @@ type vaCallbackInput struct {
|
||||
Variables []any `json:"variables"`
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
type streamMarkerType string
|
||||
|
||||
const streamMarker streamMarkerType = "<Stream Data...>"
|
||||
|
||||
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolved_sources from ctx cache")
|
||||
}
|
||||
@@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
||||
|
||||
merged := make([]vaCallbackInput, 0, len(in))
|
||||
|
||||
groupLen := v.config.GroupLen
|
||||
groupLen := v.groupLen
|
||||
|
||||
for groupName, vars := range in {
|
||||
orderedVars := make([]any, groupLen[groupName])
|
||||
for index := range vars {
|
||||
orderedVars[index] = vars[index]
|
||||
if len(resolvedSources) > 0 {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
|
||||
// replace the streams with streamMarker,
|
||||
// because we won't read, save to execution history, or display these streams to user
|
||||
orderedVars[index] = streamMarker
|
||||
@@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get dynamic stream types from ctx cache")
|
||||
}
|
||||
@@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
|
||||
|
||||
newOut := maps.Clone(output)
|
||||
for k := range output {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
|
||||
newOut[k] = streamMarker
|
||||
}
|
||||
}
|
||||
@@ -594,3 +706,15 @@ func init() {
|
||||
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
|
||||
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@@ -25,29 +25,75 @@ import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type VariableAssigner struct {
|
||||
config *Config
|
||||
pairs []*Pair
|
||||
handler *variable.Handler
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Pairs []*Pair
|
||||
Handler *variable.Handler
|
||||
Pairs []*Pair
|
||||
}
|
||||
|
||||
type Pair struct {
|
||||
Left vo.Reference
|
||||
Right compose.FieldPath
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAssigner,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
var pairs = make([]*Pair, 0, len(n.Data.Inputs.InputParameters))
|
||||
for i, param := range n.Data.Inputs.InputParameters {
|
||||
if param.Left == nil || param.Input == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left or input is nil")
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left source ref is nil")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref.VariableType == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil")
|
||||
}
|
||||
|
||||
if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem {
|
||||
return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem")
|
||||
}
|
||||
|
||||
inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(inputSource...)
|
||||
pair := &Pair{
|
||||
Left: *leftSources[0].Source.Ref,
|
||||
Right: inputSource[0].Path,
|
||||
}
|
||||
pairs = append(pairs, pair)
|
||||
}
|
||||
|
||||
c.Pairs = pairs
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
|
||||
for _, pair := range conf.Pairs {
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
for _, pair := range c.Pairs {
|
||||
if pair.Left.VariableType == nil {
|
||||
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
|
||||
}
|
||||
@@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er
|
||||
}
|
||||
|
||||
return &VariableAssigner{
|
||||
config: conf,
|
||||
pairs: c.Pairs,
|
||||
handler: variable.GetVariableHandler(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
for _, pair := range v.config.Pairs {
|
||||
type Pair struct {
|
||||
Left vo.Reference
|
||||
Right compose.FieldPath
|
||||
}
|
||||
|
||||
func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
for _, pair := range v.pairs {
|
||||
right, ok := nodes.TakeMapValue(in, pair.Right)
|
||||
if !ok {
|
||||
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
|
||||
@@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
|
||||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
|
||||
err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
|
||||
}
|
||||
|
||||
@@ -20,25 +20,93 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type InLoop struct {
|
||||
config *Config
|
||||
intermediateVarStore variable.Store
|
||||
type InLoopConfig struct {
|
||||
Pairs []*Pair
|
||||
}
|
||||
|
||||
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
|
||||
func (i *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() == nil {
|
||||
return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAssignerWithinLoop,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
|
||||
var pairs []*Pair
|
||||
for i, param := range n.Data.Inputs.InputParameters {
|
||||
if param.Left == nil || param.Right == nil {
|
||||
return nil, fmt.Errorf("loop set variable node's param left or right is nil")
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(leftSources) != 1 {
|
||||
return nil, fmt.Errorf("loop set variable node's param left is not a single source")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref == nil {
|
||||
return nil, fmt.Errorf("loop set variable node's param left's ref is nil")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate {
|
||||
return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate")
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ns.AddInputSource(rightSources...)
|
||||
|
||||
if len(rightSources) != 1 {
|
||||
return nil, fmt.Errorf("loop set variable node's param right is not a single source")
|
||||
}
|
||||
|
||||
pair := &Pair{
|
||||
Left: *leftSources[0].Source.Ref,
|
||||
Right: rightSources[0].Path,
|
||||
}
|
||||
|
||||
pairs = append(pairs, pair)
|
||||
}
|
||||
|
||||
i.Pairs = pairs
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &InLoop{
|
||||
config: conf,
|
||||
pairs: i.Pairs,
|
||||
intermediateVarStore: &nodes.ParentIntermediateStore{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
for _, pair := range v.config.Pairs {
|
||||
type InLoop struct {
|
||||
pairs []*Pair
|
||||
intermediateVarStore variable.Store
|
||||
}
|
||||
|
||||
func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
for _, pair := range v.pairs {
|
||||
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
|
||||
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
|
||||
}
|
||||
|
||||
@@ -37,36 +37,34 @@ func TestVariableAssigner(t *testing.T) {
|
||||
arrVar := any([]any{1, "2"})
|
||||
|
||||
va := &InLoop{
|
||||
config: &Config{
|
||||
Pairs: []*Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"int_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"int_var_t"},
|
||||
pairs: []*Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"int_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"str_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"str_var_t"},
|
||||
Right: compose.FieldPath{"int_var_t"},
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"str_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"obj_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"obj_var_t"},
|
||||
Right: compose.FieldPath{"str_var_t"},
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"obj_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"arr_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"arr_var_t"},
|
||||
Right: compose.FieldPath{"obj_var_t"},
|
||||
},
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"arr_var_s"},
|
||||
VariableType: ptr.Of(vo.ParentIntermediate),
|
||||
},
|
||||
Right: compose.FieldPath{"arr_var_t"},
|
||||
},
|
||||
},
|
||||
intermediateVarStore: &nodes.ParentIntermediateStore{},
|
||||
@@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) {
|
||||
"arr_var_s": &arrVar,
|
||||
}, nil)
|
||||
|
||||
_, err := va.Assign(ctx, map[string]any{
|
||||
_, err := va.Invoke(ctx, map[string]any{
|
||||
"int_var_t": 2,
|
||||
"str_var_t": "str2",
|
||||
"obj_var_t": map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user