refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -105,26 +105,28 @@ import (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler()) callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
service.RegisterAllNodeAdaptors()
os.Exit(m.Run()) os.Exit(m.Run())
} }
type wfTestRunner struct { type wfTestRunner struct {
t *testing.T t *testing.T
h *server.Hertz h *server.Hertz
ctrl *gomock.Controller ctrl *gomock.Controller
idGen *mock.MockIDGenerator idGen *mock.MockIDGenerator
search *searchmock.MockNotifier search *searchmock.MockNotifier
appVarS *mockvar.MockStore appVarS *mockvar.MockStore
userVarS *mockvar.MockStore userVarS *mockvar.MockStore
varGetter *mockvar.MockVariablesMetaGetter varGetter *mockvar.MockVariablesMetaGetter
modelManage *mockmodel.MockManager modelManage *mockmodel.MockManager
plugin *mockPlugin.MockPluginService plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledgeOperator knowledge *knowledgemock.MockKnowledgeOperator
database *databasemock.MockDatabaseOperator database *databasemock.MockDatabaseOperator
pluginSrv *pluginmock.MockService pluginSrv *pluginmock.MockService
ctx context.Context internalModel *testutil.UTChatModel
closeFn func() ctx context.Context
closeFn func()
} }
var req2URL = map[reflect.Type]string{ var req2URL = map[reflect.Type]string{
@@ -243,9 +245,11 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
cpStore := checkpoint.NewRedisStore(redisClient) cpStore := checkpoint.NewRedisStore(redisClient)
utChatModel := &testutil.UTChatModel{}
mockTos := storageMock.NewMockStorage(ctrl) mockTos := storageMock.NewMockStorage(ctrl)
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, nil) workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build() mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build() mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
@@ -312,22 +316,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
} }
return &wfTestRunner{ return &wfTestRunner{
t: t, t: t,
h: h, h: h,
ctrl: ctrl, ctrl: ctrl,
idGen: mockIDGen, idGen: mockIDGen,
search: mockSearchNotify, search: mockSearchNotify,
appVarS: mockGlobalAppVarStore, appVarS: mockGlobalAppVarStore,
userVarS: mockGlobalUserVarStore, userVarS: mockGlobalUserVarStore,
varGetter: mockVarGetter, varGetter: mockVarGetter,
modelManage: mockModelManage, modelManage: mockModelManage,
plugin: mPlugin, plugin: mPlugin,
tos: mockTos, tos: mockTos,
knowledge: mockKwOperator, knowledge: mockKwOperator,
database: mockDatabaseOperator, database: mockDatabaseOperator,
ctx: context.Background(), internalModel: utChatModel,
closeFn: f, ctx: context.Background(),
pluginSrv: mockPluginSrv, closeFn: f,
pluginSrv: mockPluginSrv,
} }
} }
@@ -1110,7 +1115,8 @@ func TestValidateTree(t *testing.T) {
assert.Equal(t, i.Message, `node "代码_1" not connected`) assert.Equal(t, i.Message, `node "代码_1" not connected`)
} }
if i.NodeError.NodeID == "160892" { if i.NodeError.NodeID == "160892" {
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`) assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
} }
} }
@@ -1157,7 +1163,8 @@ func TestValidateTree(t *testing.T) {
assert.Equal(t, i.Message, `node "代码_1" not connected`) assert.Equal(t, i.Message, `node "代码_1" not connected`)
} }
if i.NodeError.NodeID == "160892" { if i.NodeError.NodeID == "160892" {
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`) assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
} }
} }
} }
@@ -2950,41 +2957,41 @@ func TestLLMWithSkills(t *testing.T) {
r := newWfTestRunner(t) r := newWfTestRunner(t)
defer r.closeFn() defer r.closeFn()
utChatModel := &testutil.UTChatModel{ utChatModel := r.internalModel
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) { utChatModel.InvokeResultProvider = func(index int, in []*schema.Message) (*schema.Message, error) {
if index == 0 { if index == 0 {
assert.Equal(t, 1, len(in)) assert.Equal(t, 1, len(in))
assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点") assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点")
return &schema.Message{ return &schema.Message{
Role: schema.Assistant, Role: schema.Assistant,
Content: "7512369185624686592", Content: "7512369185624686592",
ResponseMeta: &schema.ResponseMeta{ ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{ Usage: &schema.TokenUsage{
PromptTokens: 10, PromptTokens: 10,
CompletionTokens: 11, CompletionTokens: 11,
TotalTokens: 21, TotalTokens: 21,
},
}, },
}, nil },
}, nil
} else if index == 1 { } else if index == 1 {
assert.Equal(t, 2, len(in)) assert.Equal(t, 2, len(in))
for _, message := range in { for _, message := range in {
if message.Role == schema.System { if message.Role == schema.System {
assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content) assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content)
} }
if message.Role == schema.User { if message.Role == schema.User {
assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉") assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉")
}
} }
return &schema.Message{
Role: schema.Assistant,
Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`,
}, nil
} }
return nil, fmt.Errorf("unexpected index: %d", index) return &schema.Message{
}, Role: schema.Assistant,
Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`,
}, nil
}
return nil, fmt.Errorf("unexpected index: %d", index)
} }
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes() r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes()
r.knowledge.EXPECT().ListKnowledgeDetail(gomock.Any(), gomock.Any()).Return(&knowledge.ListKnowledgeDetailResponse{ r.knowledge.EXPECT().ListKnowledgeDetail(gomock.Any(), gomock.Any()).Return(&knowledge.ListKnowledgeDetailResponse{
@@ -3767,10 +3774,10 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
var validateCv func(ns []*vo.Node) var validateCv func(ns []*vo.Node)
validateCv = func(ns []*vo.Node) { validateCv = func(ns []*vo.Node) {
for _, n := range ns { for _, n := range ns {
if n.Type == vo.BlockTypeBotSubWorkflow { if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.Equal(t, n.Data.Inputs.WorkflowVersion, version) assert.Equal(t, n.Data.Inputs.WorkflowVersion, version)
} }
if n.Type == vo.BlockTypeBotAPI { if n.Type == entity.NodeTypePlugin.IDStr() {
for _, apiParam := range n.Data.Inputs.APIParams { for _, apiParam := range n.Data.Inputs.APIParams {
// In the application, the workflow plugin node When the plugin version is equal to 0, the plugin is a plugin created in the application // In the application, the workflow plugin node When the plugin version is equal to 0, the plugin is a plugin created in the application
if apiParam.Name == "pluginVersion" { if apiParam.Name == "pluginVersion" {
@@ -3779,7 +3786,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
} }
} }
if n.Type == vo.BlockTypeBotLLM { if n.Type == entity.NodeTypeLLM.IDStr() {
if n.Data.Inputs.FCParam != nil && n.Data.Inputs.FCParam.PluginFCParam != nil { if n.Data.Inputs.FCParam != nil && n.Data.Inputs.FCParam.PluginFCParam != nil {
// In the application, the workflow llm node When the plugin version is equal to 0, the plugin is a plugin created in the application // In the application, the workflow llm node When the plugin version is equal to 0, the plugin is a plugin created in the application
for _, p := range n.Data.Inputs.FCParam.PluginFCParam.PluginList { for _, p := range n.Data.Inputs.FCParam.PluginFCParam.PluginList {
@@ -4063,8 +4070,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node) var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) { validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes { for _, node := range nodes {
switch node.Type { switch entity.IDStrToNodeType(node.Type) {
case vo.BlockTypeBotAPI: case entity.NodeTypePlugin:
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e return e.Name, e
}) })
@@ -4082,7 +4089,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", pID) assert.Equal(t, "100100", pID)
} }
case vo.BlockTypeBotSubWorkflow: case entity.NodeTypeSubWorkflow:
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
wfId, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) wfId, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
assert.NoError(t, err) assert.NoError(t, err)
@@ -4096,7 +4103,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
err = sonic.UnmarshalString(subWf.Canvas, subworkflowCanvas) err = sonic.UnmarshalString(subWf.Canvas, subworkflowCanvas)
assert.NoError(t, err) assert.NoError(t, err)
validateSubWorkflowIDs(subworkflowCanvas.Nodes) validateSubWorkflowIDs(subworkflowCanvas.Nodes)
case vo.BlockTypeBotLLM: case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID]) assert.True(t, copiedIDMap[w.WorkflowID])
@@ -4116,13 +4123,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", k.ID) assert.Equal(t, "100100", k.ID)
} }
} }
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0] datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any) knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
for idx := range knowledgeIDs { for idx := range knowledgeIDs {
assert.Equal(t, "100100", knowledgeIDs[idx].(string)) assert.Equal(t, "100100", knowledgeIDs[idx].(string))
} }
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
for _, d := range node.Data.Inputs.DatabaseInfoList { for _, d := range node.Data.Inputs.DatabaseInfoList {
assert.Equal(t, "100100", d.DatabaseInfoID) assert.Equal(t, "100100", d.DatabaseInfoID)
} }
@@ -4208,10 +4215,10 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node) var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) { validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes { for _, node := range nodes {
switch node.Type { switch entity.IDStrToNodeType(node.Type) {
case vo.BlockTypeBotSubWorkflow: case entity.NodeTypeSubWorkflow:
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
case vo.BlockTypeBotLLM: case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID]) assert.True(t, copiedIDMap[w.WorkflowID])
@@ -4229,13 +4236,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", k.ID) assert.Equal(t, "100100", k.ID)
} }
} }
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0] datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any) knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
for idx := range knowledgeIDs { for idx := range knowledgeIDs {
assert.Equal(t, "100100", knowledgeIDs[idx].(string)) assert.Equal(t, "100100", knowledgeIDs[idx].(string))
} }
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
for _, d := range node.Data.Inputs.DatabaseInfoList { for _, d := range node.Data.Inputs.DatabaseInfoList {
assert.Equal(t, "100100", d.DatabaseInfoID) assert.Equal(t, "100100", d.DatabaseInfoID)
} }
@@ -4356,7 +4363,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
err = sonic.Unmarshal(data, mainCanvas) err = sonic.Unmarshal(data, mainCanvas)
assert.NoError(t, err) assert.NoError(t, err)
for _, node := range mainCanvas.Nodes { for _, node := range mainCanvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
if node.Data.Inputs.WorkflowID == "7516826260387921920" { if node.Data.Inputs.WorkflowID == "7516826260387921920" {
node.Data.Inputs.WorkflowID = c1IdStr node.Data.Inputs.WorkflowID = c1IdStr
} }
@@ -4372,7 +4379,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
err = sonic.Unmarshal(cc1Data, cc1Canvas) err = sonic.Unmarshal(cc1Data, cc1Canvas)
assert.NoError(t, err) assert.NoError(t, err)
for _, node := range cc1Canvas.Nodes { for _, node := range cc1Canvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
if node.Data.Inputs.WorkflowID == "7516826283318181888" { if node.Data.Inputs.WorkflowID == "7516826283318181888" {
node.Data.Inputs.WorkflowID = c2IdStr node.Data.Inputs.WorkflowID = c2IdStr
} }
@@ -4423,7 +4430,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
for _, node := range newMainCanvas.Nodes { for _, node := range newMainCanvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID]) assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion) assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
} }
@@ -4437,7 +4444,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
for _, node := range cc1Canvas.Nodes { for _, node := range cc1Canvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID]) assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion) assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
} }
@@ -4508,10 +4515,10 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node) var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) { validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes { for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID]) assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
} }
if node.Type == vo.BlockTypeBotLLM { if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID]) assert.True(t, copiedIDMap[w.WorkflowID])

View File

@@ -24,6 +24,7 @@ import (
"github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/callbacks"
"github.com/coze-dev/coze-studio/backend/application/internal" "github.com/coze-dev/coze-studio/backend/application/internal"
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database" wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge" wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model" wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
@@ -78,6 +79,9 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
if !ok { if !ok {
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured") logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
} }
service.RegisterAllNodeAdaptors()
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache, workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
components.Tos, components.CPStore, bcm) components.Tos, components.CPStore, bcm)
workflow.SetRepository(workflowRepo) workflow.SetRepository(workflowRepo)

View File

@@ -93,8 +93,8 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
toQueryTypes := make(map[entity.NodeType]bool) toQueryTypes := make(map[entity.NodeType]bool)
for _, t := range req.NodeTypes { for _, t := range req.NodeTypes {
entityType, err := nodeType2EntityNodeType(t) entityType := entity.IDStrToNodeType(t)
if err != nil { if len(entityType) == 0 {
logs.Warnf("get node type %v failed, err:=%v", t, err) logs.Warnf("get node type %v failed, err:=%v", t, err)
continue continue
} }
@@ -116,23 +116,19 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
Name: category, Name: category,
} }
for _, nodeMeta := range nodeMetaList { for _, nodeMeta := range nodeMetaList {
tplType, err := entityNodeTypeToAPINodeTemplateType(nodeMeta.Type)
if err != nil {
return nil, err
}
tpl := &workflow.NodeTemplate{ tpl := &workflow.NodeTemplate{
ID: fmt.Sprintf("%d", nodeMeta.ID), ID: fmt.Sprintf("%d", nodeMeta.ID),
Type: tplType, Type: workflow.NodeTemplateType(nodeMeta.ID),
Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name), Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name),
Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc), Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc),
IconURL: nodeMeta.IconURL, IconURL: nodeMeta.IconURL,
SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT), SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT),
NodeType: fmt.Sprintf("%d", tplType), NodeType: fmt.Sprintf("%d", nodeMeta.ID),
Color: nodeMeta.Color, Color: nodeMeta.Color,
} }
resp.Data.TemplateList = append(resp.Data.TemplateList, tpl) resp.Data.TemplateList = append(resp.Data.TemplateList, tpl)
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", tplType)) categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", nodeMeta.ID))
} }
} }
@@ -178,7 +174,7 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C
IconURI: req.IconURI, IconURI: req.IconURI,
AppID: parseInt64(req.ProjectID), AppID: parseInt64(req.ProjectID),
Mode: ternary.IFElse(req.IsSetFlowMode(), req.GetFlowMode(), workflow.WorkflowMode_Workflow), Mode: ternary.IFElse(req.IsSetFlowMode(), req.GetFlowMode(), workflow.WorkflowMode_Workflow),
InitCanvasSchema: entity.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)), InitCanvasSchema: vo.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)),
} }
id, err := GetWorkflowDomainSVC().Create(ctx, wf) id, err := GetWorkflowDomainSVC().Create(ctx, wf)
@@ -1041,7 +1037,8 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
return wf.ID, nil return wf.ID, nil
} }
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (_ int64, _ []*vo.ValidateIssue, err error) { func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/
appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
defer func() { defer func() {
if panicErr := recover(); panicErr != nil { if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack()) err = safego.NewPanicErr(panicErr, debug.Stack())
@@ -1127,15 +1124,10 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
} }
func convertNodeExecution(nodeExe *entity.NodeExecution) (*workflow.NodeResult, error) { func convertNodeExecution(nodeExe *entity.NodeExecution) (*workflow.NodeResult, error) {
nType, err := entityNodeTypeToAPINodeTemplateType(nodeExe.NodeType)
if err != nil {
return nil, err
}
nr := &workflow.NodeResult{ nr := &workflow.NodeResult{
NodeId: nodeExe.NodeID, NodeId: nodeExe.NodeID,
NodeName: nodeExe.NodeName, NodeName: nodeExe.NodeName,
NodeType: nType.String(), NodeType: entity.NodeMetaByNodeType(nodeExe.NodeType).GetDisplayKey(),
NodeStatus: workflow.NodeExeStatus(nodeExe.Status), NodeStatus: workflow.NodeExeStatus(nodeExe.Status),
ErrorInfo: ptr.FromOrDefault(nodeExe.ErrorInfo, ""), ErrorInfo: ptr.FromOrDefault(nodeExe.ErrorInfo, ""),
Input: ptr.FromOrDefault(nodeExe.Input, ""), Input: ptr.FromOrDefault(nodeExe.Input, ""),
@@ -1316,13 +1308,6 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return nil, schema.ErrNoValue return nil, schema.ErrNoValue
} }
var nodeType workflow.NodeTemplateType
nodeType, err = entityNodeTypeToAPINodeTemplateType(msg.NodeType)
if err != nil {
logs.Errorf("convert node type %v failed, err:=%v", msg.NodeType, err)
nodeType = workflow.NodeTemplateType(0)
}
res = &workflow.OpenAPIStreamRunFlowResponse{ res = &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID), ID: strconv.Itoa(messageID),
Event: string(MessageEvent), Event: string(MessageEvent),
@@ -1330,7 +1315,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
Content: ptr.Of(msg.Content), Content: ptr.Of(msg.Content),
ContentType: ptr.Of("text"), ContentType: ptr.Of("text"),
NodeIsFinish: ptr.Of(msg.Last), NodeIsFinish: ptr.Of(msg.Last),
NodeType: ptr.Of(nodeType.String()), NodeType: ptr.Of(entity.NodeMetaByNodeType(msg.NodeType).GetDisplayKey()),
NodeID: ptr.Of(msg.NodeID), NodeID: ptr.Of(msg.NodeID),
} }
@@ -3344,178 +3329,6 @@ func toWorkflowParameter(nType *vo.NamedTypeInfo) (*workflow.Parameter, error) {
return wp, nil return wp, nil
} }
func nodeType2EntityNodeType(t string) (entity.NodeType, error) {
i, err := strconv.Atoi(t)
if err != nil {
return "", fmt.Errorf("invalid node type string '%s': %w", t, err)
}
switch i {
case 1:
return entity.NodeTypeEntry, nil
case 2:
return entity.NodeTypeExit, nil
case 3:
return entity.NodeTypeLLM, nil
case 4:
return entity.NodeTypePlugin, nil
case 5:
return entity.NodeTypeCodeRunner, nil
case 6:
return entity.NodeTypeKnowledgeRetriever, nil
case 8:
return entity.NodeTypeSelector, nil
case 9:
return entity.NodeTypeSubWorkflow, nil
case 12:
return entity.NodeTypeDatabaseCustomSQL, nil
case 13:
return entity.NodeTypeOutputEmitter, nil
case 15:
return entity.NodeTypeTextProcessor, nil
case 18:
return entity.NodeTypeQuestionAnswer, nil
case 19:
return entity.NodeTypeBreak, nil
case 20:
return entity.NodeTypeVariableAssignerWithinLoop, nil
case 21:
return entity.NodeTypeLoop, nil
case 22:
return entity.NodeTypeIntentDetector, nil
case 27:
return entity.NodeTypeKnowledgeIndexer, nil
case 28:
return entity.NodeTypeBatch, nil
case 29:
return entity.NodeTypeContinue, nil
case 30:
return entity.NodeTypeInputReceiver, nil
case 32:
return entity.NodeTypeVariableAggregator, nil
case 37:
return entity.NodeTypeMessageList, nil
case 38:
return entity.NodeTypeClearMessage, nil
case 39:
return entity.NodeTypeCreateConversation, nil
case 40:
return entity.NodeTypeVariableAssigner, nil
case 42:
return entity.NodeTypeDatabaseUpdate, nil
case 43:
return entity.NodeTypeDatabaseQuery, nil
case 44:
return entity.NodeTypeDatabaseDelete, nil
case 45:
return entity.NodeTypeHTTPRequester, nil
case 46:
return entity.NodeTypeDatabaseInsert, nil
case 58:
return entity.NodeTypeJsonSerialization, nil
case 59:
return entity.NodeTypeJsonDeserialization, nil
case 60:
return entity.NodeTypeKnowledgeDeleter, nil
default:
// Handle all unknown or unsupported types here
return "", fmt.Errorf("unsupported or unknown node type ID: %d", i)
}
}
// entityNodeTypeToAPINodeTemplateType converts an entity.NodeType to the corresponding workflow.NodeTemplateType.
func entityNodeTypeToAPINodeTemplateType(nodeType entity.NodeType) (workflow.NodeTemplateType, error) {
switch nodeType {
case entity.NodeTypeEntry:
return workflow.NodeTemplateType_Start, nil
case entity.NodeTypeExit:
return workflow.NodeTemplateType_End, nil
case entity.NodeTypeLLM:
return workflow.NodeTemplateType_LLM, nil
case entity.NodeTypePlugin:
// Maps to Api type in the API model
return workflow.NodeTemplateType_Api, nil
case entity.NodeTypeCodeRunner:
return workflow.NodeTemplateType_Code, nil
case entity.NodeTypeKnowledgeRetriever:
// Maps to Dataset type in the API model
return workflow.NodeTemplateType_Dataset, nil
case entity.NodeTypeSelector:
// Maps to If type in the API model
return workflow.NodeTemplateType_If, nil
case entity.NodeTypeSubWorkflow:
return workflow.NodeTemplateType_SubWorkflow, nil
case entity.NodeTypeDatabaseCustomSQL:
// Maps to the generic Database type in the API model
return workflow.NodeTemplateType_Database, nil
case entity.NodeTypeOutputEmitter:
// Maps to Message type in the API model
return workflow.NodeTemplateType_Message, nil
case entity.NodeTypeTextProcessor:
return workflow.NodeTemplateType_Text, nil
case entity.NodeTypeQuestionAnswer:
return workflow.NodeTemplateType_Question, nil
case entity.NodeTypeBreak:
return workflow.NodeTemplateType_Break, nil
case entity.NodeTypeVariableAssigner:
return workflow.NodeTemplateType_AssignVariable, nil
case entity.NodeTypeVariableAssignerWithinLoop:
return workflow.NodeTemplateType_LoopSetVariable, nil
case entity.NodeTypeLoop:
return workflow.NodeTemplateType_Loop, nil
case entity.NodeTypeIntentDetector:
return workflow.NodeTemplateType_Intent, nil
case entity.NodeTypeKnowledgeIndexer:
// Maps to DatasetWrite type in the API model
return workflow.NodeTemplateType_DatasetWrite, nil
case entity.NodeTypeBatch:
return workflow.NodeTemplateType_Batch, nil
case entity.NodeTypeContinue:
return workflow.NodeTemplateType_Continue, nil
case entity.NodeTypeInputReceiver:
return workflow.NodeTemplateType_Input, nil
case entity.NodeTypeMessageList:
return workflow.NodeTemplateType(37), nil
case entity.NodeTypeVariableAggregator:
return workflow.NodeTemplateType(32), nil
case entity.NodeTypeClearMessage:
return workflow.NodeTemplateType(38), nil
case entity.NodeTypeCreateConversation:
return workflow.NodeTemplateType(39), nil
// Note: entity.NodeTypeVariableAggregator (ID 32) has no direct mapping in NodeTemplateType
// Note: entity.NodeTypeMessageList (ID 37) has no direct mapping in NodeTemplateType
// Note: entity.NodeTypeClearMessage (ID 38) has no direct mapping in NodeTemplateType
// Note: entity.NodeTypeCreateConversation (ID 39) has no direct mapping in NodeTemplateType
case entity.NodeTypeDatabaseUpdate:
return workflow.NodeTemplateType_DatabaseUpdate, nil
case entity.NodeTypeDatabaseQuery:
// Maps to DatabasesELECT (ID 43) in the API model (note potential typo)
return workflow.NodeTemplateType_DatabasesELECT, nil
case entity.NodeTypeDatabaseDelete:
return workflow.NodeTemplateType_DatabaseDelete, nil
// Note: entity.NodeTypeHTTPRequester (ID 45) has no direct mapping in NodeTemplateType
case entity.NodeTypeHTTPRequester:
return workflow.NodeTemplateType(45), nil
case entity.NodeTypeDatabaseInsert:
// Maps to DatabaseInsert (ID 41) in the API model, despite entity ID being 46.
// return workflow.NodeTemplateType_DatabaseInsert, nil
return workflow.NodeTemplateType(46), nil
case entity.NodeTypeJsonSerialization:
return workflow.NodeTemplateType(58), nil
case entity.NodeTypeJsonDeserialization:
return workflow.NodeTemplateType_JsonDeserialization, nil
case entity.NodeTypeKnowledgeDeleter:
return workflow.NodeTemplateType_DatasetDelete, nil
case entity.NodeTypeLambda:
return 0, nil
default:
// Handle entity types that don't have a corresponding NodeTemplateType
return workflow.NodeTemplateType(0), fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType)
}
}
func i64PtrToStringPtr(i *int64) *string { func i64PtrToStringPtr(i *int64) *string {
if i == nil { if i == nil {
return nil return nil
@@ -3761,7 +3574,7 @@ func mergeWorkflowAPIParameters(latestAPIParameters []*workflow.APIParameter, ex
func parseWorkflowTerminatePlanType(c *vo.Canvas) (int32, error) { func parseWorkflowTerminatePlanType(c *vo.Canvas) (int32, error) {
var endNode *vo.Node var endNode *vo.Node
for _, n := range c.Nodes { for _, n := range c.Nodes {
if n.Type == vo.BlockTypeBotEnd { if n.Type == entity.NodeTypeExit.IDStr() {
endNode = n endNode = n
break break
} }

View File

@@ -1,49 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"net/http"
"net/url"
"testing"
. "github.com/bytedance/mockey"
"github.com/stretchr/testify/assert"
)
func TestGenRequestString(t *testing.T) {
PatchConvey("", t, func() {
requestStr, err := genRequestString(&http.Request{
Header: http.Header{
"Content-Type": []string{"application/json"},
},
Method: http.MethodPost,
URL: &url.URL{Path: "/test"},
}, []byte(`{"a": 1}`))
assert.NoError(t, err)
assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr)
})
PatchConvey("", t, func() {
var body []byte
requestStr, err := genRequestString(&http.Request{
URL: &url.URL{Path: "/test"},
}, body)
assert.NoError(t, err)
assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr)
})
}

View File

@@ -1,4 +1,5 @@
/* /*
* Copyright 2025 coze-dev Authors * Copyright 2025 coze-dev Authors
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,58 +17,85 @@
package entity package entity
import (
"fmt"
"strconv"
)
type NodeType string type NodeType string
func (nt NodeType) IDStr() string {
m := NodeMetaByNodeType(nt)
if m == nil {
return ""
}
return fmt.Sprintf("%d", m.ID)
}
func IDStrToNodeType(s string) NodeType {
id, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return ""
}
for _, m := range NodeTypeMetas {
if m.ID == id {
return m.Key
}
}
return ""
}
type NodeTypeMeta struct { type NodeTypeMeta struct {
ID int64 `json:"id"` ID int64
Name string `json:"name"` Key NodeType
Type NodeType `json:"type"` DisplayKey string
Category string `json:"category"` Name string `json:"name"`
Color string `json:"color"` Category string `json:"category"`
Desc string `json:"desc"` Color string `json:"color"`
IconURL string `json:"icon_url"` Desc string `json:"desc"`
SupportBatch bool `json:"support_batch"` IconURL string `json:"icon_url"`
Disabled bool `json:"disabled,omitempty"` SupportBatch bool `json:"support_batch"`
EnUSName string `json:"en_us_name,omitempty"` Disabled bool `json:"disabled,omitempty"`
EnUSDescription string `json:"en_us_description,omitempty"` EnUSName string `json:"en_us_name,omitempty"`
EnUSDescription string `json:"en_us_description,omitempty"`
ExecutableMeta ExecutableMeta
} }
func (ntm *NodeTypeMeta) GetDisplayKey() string {
if len(ntm.DisplayKey) > 0 {
return ntm.DisplayKey
}
return string(ntm.Key)
}
type Category struct { type Category struct {
Key string `json:"key"` Key string `json:"key"`
Name string `json:"name"` Name string `json:"name"`
EnUSName string `json:"en_us_name"` EnUSName string `json:"en_us_name"`
} }
type StreamingParadigm string
const (
Invoke StreamingParadigm = "invoke"
Stream StreamingParadigm = "stream"
Collect StreamingParadigm = "collect"
Transform StreamingParadigm = "transform"
)
type ExecutableMeta struct { type ExecutableMeta struct {
IsComposite bool `json:"is_composite,omitempty"` IsComposite bool `json:"is_composite,omitempty"`
DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout
PreFillZero bool `json:"pre_fill_zero,omitempty"` PreFillZero bool `json:"pre_fill_zero,omitempty"`
PostFillNil bool `json:"post_fill_nil,omitempty"` PostFillNil bool `json:"post_fill_nil,omitempty"`
CallbackEnabled bool `json:"callback_enabled,omitempty"` // is false, Eino framework will inject callbacks for this node MayUseChatModel bool `json:"may_use_chat_model,omitempty"`
MayUseChatModel bool `json:"may_use_chat_model,omitempty"` InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources
InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error
StreamingParadigms map[StreamingParadigm]bool `json:"streaming_paradigms,omitempty"`
StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error // IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming.
/* // This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output')
IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming. //from those that are merely capable of streaming internally (defined by StreamingParadigms),
This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output') // In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
from those that are merely capable of streaming internally (defined by StreamingParadigms), // When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
whose output is consumed by other nodes.
In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
*/
IncrementalOutput bool `json:"incremental_output,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"`
// UseCtxCache indicates that the node would require a newly initialized ctx cache for each invocation.
// example use cases:
// - write warnings to the ctx cache during Invoke, and read from the ctx within Callback output converter
UseCtxCache bool `json:"use_ctx_cache"`
} }
type PluginNodeMeta struct { type PluginNodeMeta struct {
@@ -125,9 +153,700 @@ const (
NodeTypeSubWorkflow NodeType = "SubWorkflow" NodeTypeSubWorkflow NodeType = "SubWorkflow"
NodeTypeJsonSerialization NodeType = "JsonSerialization" NodeTypeJsonSerialization NodeType = "JsonSerialization"
NodeTypeJsonDeserialization NodeType = "JsonDeserialization" NodeTypeJsonDeserialization NodeType = "JsonDeserialization"
NodeTypeComment NodeType = "Comment"
) )
const ( const (
EntryNodeKey = "100001" EntryNodeKey = "100001"
ExitNodeKey = "900001" ExitNodeKey = "900001"
) )
var Categories = []Category{
{
Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
Name: "",
EnUSName: "",
},
{
Key: "logic",
Name: "业务逻辑",
EnUSName: "Logic",
},
{
Key: "input&output",
Name: "输入&输出",
EnUSName: "Input&Output",
},
{
Key: "database",
Name: "数据库",
EnUSName: "Database",
},
{
Key: "data",
Name: "知识库&数据",
EnUSName: "Data",
},
{
Key: "image",
Name: "图像处理",
EnUSName: "Image",
},
{
Key: "audio&video",
Name: "音视频处理",
EnUSName: "Audio&Video",
},
{
Key: "utilities",
Name: "组件",
EnUSName: "Utilities",
},
{
Key: "conversation_management",
Name: "会话管理",
EnUSName: "Conversation management",
},
{
Key: "conversation_history",
Name: "会话历史",
EnUSName: "Conversation history",
},
{
Key: "message",
Name: "消息",
EnUSName: "Message",
},
}
// NodeTypeMetas holds the metadata for all available node types.
// It is initialized with built-in node types and potentially extended by loading from external sources.
var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
NodeTypeEntry: {
ID: 1,
Key: NodeTypeEntry,
DisplayKey: "Start",
Name: "开始",
Category: "input&output",
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
},
EnUSName: "Start",
EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
},
NodeTypeExit: {
ID: 2,
Key: NodeTypeExit,
DisplayKey: "End",
Name: "结束",
Category: "input&output",
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
InputSourceAware: true,
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "End",
EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
},
NodeTypeLLM: {
ID: 3,
Key: NodeTypeLLM,
DisplayKey: "LLM",
Name: "大模型",
Category: "",
Desc: "调用大语言模型,使用变量和提示词生成回复",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
SupportBatch: true,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
InputSourceAware: true,
MayUseChatModel: true,
},
EnUSName: "LLM",
EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
},
NodeTypePlugin: {
ID: 4,
Key: NodeTypePlugin,
DisplayKey: "Api",
Name: "插件",
Category: "",
Desc: "通过添加工具访问实时数据和执行外部操作",
Color: "#CA61FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
SupportBatch: true,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Plugin",
EnUSDescription: "Used to access external real-time data and perform operations",
},
NodeTypeCodeRunner: {
ID: 5,
Key: NodeTypeCodeRunner,
DisplayKey: "Code",
Name: "代码",
Category: "logic",
Desc: "编写代码,处理输入变量来生成返回值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
UseCtxCache: true,
},
EnUSName: "Code",
EnUSDescription: "Write code to process input variables to generate return values.",
},
NodeTypeKnowledgeRetriever: {
ID: 6,
Key: NodeTypeKnowledgeRetriever,
DisplayKey: "Dataset",
Name: "知识库检索",
Category: "data",
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Knowledge retrieval",
EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
},
NodeTypeSelector: {
ID: 8,
Key: NodeTypeSelector,
DisplayKey: "If",
Name: "选择器",
Category: "logic",
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Condition",
EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
},
NodeTypeSubWorkflow: {
ID: 9,
Key: NodeTypeSubWorkflow,
DisplayKey: "SubWorkflow",
Name: "工作流",
Category: "",
Desc: "集成已发布工作流,可以执行嵌套子任务",
Color: "#00B83E",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
SupportBatch: true,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Workflow",
EnUSDescription: "Add published workflows to execute subtasks",
},
NodeTypeDatabaseCustomSQL: {
ID: 12,
Key: NodeTypeDatabaseCustomSQL,
DisplayKey: "End",
Name: "SQL自定义",
Category: "database",
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "SQL Customization",
EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
},
NodeTypeOutputEmitter: {
ID: 13,
Key: NodeTypeOutputEmitter,
DisplayKey: "Message",
Name: "输出",
Category: "input&output",
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
InputSourceAware: true,
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "Output",
EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
},
NodeTypeTextProcessor: {
ID: 15,
Key: NodeTypeTextProcessor,
DisplayKey: "Text",
Name: "文本处理",
Category: "utilities",
Desc: "用于处理多个字符串类型变量的格式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
InputSourceAware: true,
},
EnUSName: "Text Processing",
EnUSDescription: "The format used for handling multiple string-type variables.",
},
NodeTypeQuestionAnswer: {
ID: 18,
Key: NodeTypeQuestionAnswer,
DisplayKey: "Question",
Name: "问答",
Category: "utilities",
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
MayUseChatModel: true,
},
EnUSName: "Question",
EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
},
NodeTypeBreak: {
ID: 19,
Key: NodeTypeBreak,
DisplayKey: "Break",
Name: "终止循环",
Category: "logic",
Desc: "用于立即终止当前所在的循环,跳出循环体",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Break",
EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
},
NodeTypeVariableAssignerWithinLoop: {
ID: 20,
Key: NodeTypeVariableAssignerWithinLoop,
DisplayKey: "LoopSetVariable",
Name: "设置变量",
Category: "logic",
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Set Variable",
EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
},
NodeTypeLoop: {
ID: 21,
Key: NodeTypeLoop,
DisplayKey: "Loop",
Name: "循环",
Category: "logic",
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Loop",
EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
},
NodeTypeIntentDetector: {
ID: 22,
Key: NodeTypeIntentDetector,
DisplayKey: "Intent",
Name: "意图识别",
Category: "logic",
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
MayUseChatModel: true,
},
EnUSName: "Intent recognition",
EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
},
NodeTypeKnowledgeIndexer: {
ID: 27,
Key: NodeTypeKnowledgeIndexer,
DisplayKey: "DatasetWrite",
Name: "知识库写入",
Category: "data",
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Knowledge writing",
EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
},
NodeTypeBatch: {
ID: 28,
Key: NodeTypeBatch,
DisplayKey: "Batch",
Name: "批处理",
Category: "logic",
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Batch",
EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
},
NodeTypeContinue: {
ID: 29,
Key: NodeTypeContinue,
DisplayKey: "Continue",
Name: "继续循环",
Category: "logic",
Desc: "用于终止当前循环,执行下次循环",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Continue",
EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
},
NodeTypeInputReceiver: {
ID: 30,
Key: NodeTypeInputReceiver,
DisplayKey: "Input",
Name: "输入",
Category: "input&output",
Desc: "支持中间过程的信息输入",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
},
EnUSName: "Input",
EnUSDescription: "Support intermediate information input",
},
NodeTypeComment: {
ID: 31,
Key: "",
Name: "注释",
Category: "", // Not found in cate_list
Desc: "comment_desc", // Placeholder from JSON
Color: "",
IconURL: "comment_icon", // Placeholder from JSON
SupportBatch: false, // supportBatch: 1
EnUSName: "Comment",
},
NodeTypeVariableAggregator: {
ID: 32,
Key: NodeTypeVariableAggregator,
Name: "变量聚合",
Category: "logic",
Desc: "对多个分支的输出进行聚合处理",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
InputSourceAware: true,
UseCtxCache: true,
},
EnUSName: "Variable Merge",
EnUSDescription: "Aggregate the outputs of multiple branches.",
},
NodeTypeMessageList: {
ID: 37,
Key: NodeTypeMessageList,
Name: "查询消息列表",
Category: "message",
Desc: "用于查询消息列表",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
SupportBatch: false,
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Query message list",
EnUSDescription: "Used to query the message list",
},
NodeTypeClearMessage: {
ID: 38,
Key: NodeTypeClearMessage,
Name: "清除上下文",
Category: "conversation_history",
Desc: "用于清空会话历史清空后LLM看到的会话历史为空",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
SupportBatch: false,
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Clear conversation history",
EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
},
NodeTypeCreateConversation: {
ID: 39,
Key: NodeTypeCreateConversation,
Name: "创建会话",
Category: "conversation_management",
Desc: "用于创建会话",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
SupportBatch: false,
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Create conversation",
EnUSDescription: "This node is used to create a conversation.",
},
NodeTypeVariableAssigner: {
ID: 40,
Key: NodeTypeVariableAssigner,
DisplayKey: "AssignVariable",
Name: "变量赋值",
Category: "data",
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{},
EnUSName: "Variable assign",
EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
},
NodeTypeDatabaseUpdate: {
ID: 42,
Key: NodeTypeDatabaseUpdate,
DisplayKey: "DatabaseUpdate",
Name: "更新数据",
Category: "database",
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
},
EnUSName: "Update Data",
EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
},
NodeTypeDatabaseQuery: {
ID: 43,
Key: NodeTypeDatabaseQuery,
DisplayKey: "DatabaseSelect",
Name: "查询数据",
Category: "database",
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
},
EnUSName: "Query Data",
EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
},
NodeTypeDatabaseDelete: {
ID: 44,
Key: NodeTypeDatabaseDelete,
DisplayKey: "DatabaseDelete",
Name: "删除数据",
Category: "database",
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
},
EnUSName: "Delete Data",
EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
},
NodeTypeHTTPRequester: {
ID: 45,
Key: NodeTypeHTTPRequester,
DisplayKey: "Http",
Name: "HTTP 请求",
Category: "utilities",
Desc: "用于发送API请求从接口返回数据",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "HTTP request",
EnUSDescription: "It is used to send API requests and return data from the interface.",
},
NodeTypeDatabaseInsert: {
ID: 46,
Key: NodeTypeDatabaseInsert,
DisplayKey: "DatabaseInsert",
Name: "新增数据",
Category: "database",
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
},
EnUSName: "Add Data",
EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
},
NodeTypeJsonSerialization: {
// ID is the unique identifier of this node type. Used in various front-end APIs.
ID: 58,
// Key is the unique NodeType of this node. Used in backend code as well as saved in DB.
Key: NodeTypeJsonSerialization,
// DisplayKey is the string used in frontend to identify this node.
// Example use cases:
// - used during querying test-run results for nodes
// - used in returned messages from streaming openAPI Runs.
// If empty, will use Key as DisplayKey.
DisplayKey: "ToJSON",
// Name is the node in ZH_CN, will be displayed on Canvas.
Name: "JSON 序列化",
// Category is the category of this node, determines which category this node will be displayed in.
Category: "utilities",
// Desc is the desc in ZH_CN, will be displayed as tooltip on Canvas.
Desc: "用于把变量转化为JSON字符串",
// Color is the color of the upper edge of the node displayed on Canvas.
Color: "F2B600",
// IconURL is the URL of the icon displayed on Canvas.
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
// SupportBatch indicates whether this node can set batch mode.
// NOTE: ultimately it's frontend that decides which node can enable batch mode.
SupportBatch: false,
// ExecutableMeta configures certain common aspects of request-time behaviors for this node.
ExecutableMeta: ExecutableMeta{
// DefaultTimeoutMS configures the default timeout for this node, in milliseconds. 0 means no timeout.
DefaultTimeoutMS: 60 * 1000, // 1 minute
// PreFillZero decides whether to pre-fill zero value for any missing fields in input.
PreFillZero: true,
// PostFillNil decides whether to post-fill nil value for any missing fields in output.
PostFillNil: true,
},
// EnUSName is the name in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
EnUSName: "JSON serialization",
// EnUSDescription is the description in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
EnUSDescription: "Convert variable to JSON string",
},
NodeTypeJsonDeserialization: {
ID: 59,
Key: NodeTypeJsonDeserialization,
DisplayKey: "FromJSON",
Name: "JSON 反序列化",
Category: "utilities",
Desc: "用于将JSON字符串解析为变量",
Color: "F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
UseCtxCache: true,
},
EnUSName: "JSON deserialization",
EnUSDescription: "Parse JSON string to variable",
},
NodeTypeKnowledgeDeleter: {
ID: 60,
Key: NodeTypeKnowledgeDeleter,
DisplayKey: "KnowledgeDelete",
Name: "知识库删除",
Category: "data",
Desc: "用于删除知识库中的文档",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
},
EnUSName: "Knowledge delete",
EnUSDescription: "The delete node can delete a document in knowledge base.",
},
NodeTypeLambda: {
ID: 1000,
Key: NodeTypeLambda,
Name: "Lambda",
EnUSName: "Comment",
},
}
// PluginNodeMetas holds metadata for specific plugin API entity.
var PluginNodeMetas []*PluginNodeMeta
// PluginCategoryMetas holds metadata for plugin category entity.
var PluginCategoryMetas []*PluginCategoryMeta
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
if m, ok := NodeTypeMetas[t]; ok {
return m
}
return nil
}

View File

@@ -1,867 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
import (
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
var Categories = []Category{
{
Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
Name: "",
EnUSName: "",
},
{
Key: "logic",
Name: "业务逻辑",
EnUSName: "Logic",
},
{
Key: "input&output",
Name: "输入&输出",
EnUSName: "Input&Output",
},
{
Key: "database",
Name: "数据库",
EnUSName: "Database",
},
{
Key: "data",
Name: "知识库&数据",
EnUSName: "Data",
},
{
Key: "image",
Name: "图像处理",
EnUSName: "Image",
},
{
Key: "audio&video",
Name: "音视频处理",
EnUSName: "Audio&Video",
},
{
Key: "utilities",
Name: "组件",
EnUSName: "Utilities",
},
{
Key: "conversation_management",
Name: "会话管理",
EnUSName: "Conversation management",
},
{
Key: "conversation_history",
Name: "会话历史",
EnUSName: "Conversation history",
},
{
Key: "message",
Name: "消息",
EnUSName: "Message",
},
}
// NodeTypeMetas holds the metadata for all available node types.
// It is initialized with built-in types and potentially extended by loading from external sources.
var NodeTypeMetas = []*NodeTypeMeta{
{
ID: 1,
Name: "开始",
Type: NodeTypeEntry,
Category: "input&output", // Mapped from cate_list
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Start",
EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
},
{
ID: 2,
Name: "结束",
Type: NodeTypeExit,
Category: "input&output", // Mapped from cate_list
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "End",
EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
},
{
ID: 3,
Name: "大模型",
Type: NodeTypeLLM,
Category: "", // Mapped from cate_list
Desc: "调用大语言模型,使用变量和提示词生成回复",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
InputSourceAware: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
},
EnUSName: "LLM",
EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
},
{
ID: 4,
Name: "插件",
Type: NodeTypePlugin,
Category: "", // Mapped from cate_list
Desc: "通过添加工具访问实时数据和执行外部操作",
Color: "#CA61FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Plugin",
EnUSDescription: "Used to access external real-time data and perform operations",
},
{
ID: 5,
Name: "代码",
Type: NodeTypeCodeRunner,
Category: "logic", // Mapped from cate_list
Desc: "编写代码,处理输入变量来生成返回值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Code",
EnUSDescription: "Write code to process input variables to generate return values.",
},
{
ID: 6,
Name: "知识库检索",
Type: NodeTypeKnowledgeRetriever,
Category: "data", // Mapped from cate_list
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge retrieval",
EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
},
{
ID: 8,
Name: "选择器",
Type: NodeTypeSelector,
Category: "logic", // Mapped from cate_list
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Condition",
EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
},
{
ID: 9,
Name: "工作流",
Type: NodeTypeSubWorkflow,
Category: "", // Mapped from cate_list
Desc: "集成已发布工作流,可以执行嵌套子任务",
Color: "#00B83E",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Workflow",
EnUSDescription: "Add published workflows to execute subtasks",
},
{
ID: 12,
Name: "SQL自定义",
Type: NodeTypeDatabaseCustomSQL,
Category: "database", // Mapped from cate_list
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
SupportBatch: false, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "SQL Customization",
EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
},
{
ID: 13,
Name: "输出",
Type: NodeTypeOutputEmitter,
Category: "input&output", // Mapped from cate_list
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "Output",
EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
},
{
ID: 15,
Name: "文本处理",
Type: NodeTypeTextProcessor,
Category: "utilities", // Mapped from cate_list
Desc: "用于处理多个字符串类型变量的格式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
SupportBatch: false, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Text Processing",
EnUSDescription: "The format used for handling multiple string-type variables.",
},
{
ID: 18,
Name: "问答",
Type: NodeTypeQuestionAnswer,
Category: "utilities", // Mapped from cate_list
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Question",
EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
},
{
ID: 19,
Name: "终止循环",
Type: NodeTypeBreak,
Category: "logic", // Mapped from cate_list
Desc: "用于立即终止当前所在的循环,跳出循环体",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Break",
EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
},
{
ID: 20,
Name: "设置变量",
Type: NodeTypeVariableAssignerWithinLoop,
Category: "logic", // Mapped from cate_list
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Set Variable",
EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
},
{
ID: 21,
Name: "循环",
Type: NodeTypeLoop,
Category: "logic", // Mapped from cate_list
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Loop",
EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
},
{
ID: 22,
Name: "意图识别",
Type: NodeTypeIntentDetector,
Category: "logic", // Mapped from cate_list
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Intent recognition",
EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
},
{
ID: 27,
Name: "知识库写入",
Type: NodeTypeKnowledgeIndexer,
Category: "data", // Mapped from cate_list
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge writing",
EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
},
{
ID: 28,
Name: "批处理",
Type: NodeTypeBatch,
Category: "logic", // Mapped from cate_list
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
SupportBatch: false, // supportBatch: 1 (Corrected from previous assumption)
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Batch",
EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
},
{
ID: 29,
Name: "继续循环",
Type: NodeTypeContinue,
Category: "logic", // Mapped from cate_list
Desc: "用于终止当前循环,执行下次循环",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Continue",
EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
},
{
ID: 30,
Name: "输入",
Type: NodeTypeInputReceiver,
Category: "input&output", // Mapped from cate_list
Desc: "支持中间过程的信息输入",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Input",
EnUSDescription: "Support intermediate information input",
},
{
ID: 31,
Name: "注释",
Type: "",
Category: "", // Not found in cate_list
Desc: "comment_desc", // Placeholder from JSON
Color: "",
IconURL: "comment_icon", // Placeholder from JSON
SupportBatch: false, // supportBatch: 1
EnUSName: "Comment",
},
{
ID: 32,
Name: "变量聚合",
Type: NodeTypeVariableAggregator,
Category: "logic", // Mapped from cate_list
Desc: "对多个分支的输出进行聚合处理",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
},
EnUSName: "Variable Merge",
EnUSDescription: "Aggregate the outputs of multiple branches.",
},
{
ID: 37,
Name: "查询消息列表",
Type: NodeTypeMessageList,
Category: "message", // Mapped from cate_list
Desc: "用于查询消息列表",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Query message list",
EnUSDescription: "Used to query the message list",
},
{
ID: 38,
Name: "清除上下文",
Type: NodeTypeClearMessage,
Category: "conversation_history", // Mapped from cate_list
Desc: "用于清空会话历史清空后LLM看到的会话历史为空",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Clear conversation history",
EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
},
{
ID: 39,
Name: "创建会话",
Type: NodeTypeCreateConversation,
Category: "conversation_management", // Mapped from cate_list
Desc: "用于创建会话",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Create conversation",
EnUSDescription: "This node is used to create a conversation.",
},
{
ID: 40,
Name: "变量赋值",
Type: NodeTypeVariableAssigner,
Category: "data", // Mapped from cate_list
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Variable assign",
EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
},
{
ID: 42,
Name: "更新数据",
Type: NodeTypeDatabaseUpdate,
Category: "database", // Mapped from cate_list
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Update Data",
EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
},
{
ID: 43,
Name: "查询数据", // Corrected Name from JSON (was "insert data")
Type: NodeTypeDatabaseQuery,
Category: "database", // Mapped from cate_list
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Query Data",
EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
},
{
ID: 44,
Name: "删除数据",
Type: NodeTypeDatabaseDelete,
Category: "database", // Mapped from cate_list
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Delete Data",
EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
},
{
ID: 45,
Name: "HTTP 请求",
Type: NodeTypeHTTPRequester,
Category: "utilities", // Mapped from cate_list
Desc: "用于发送API请求从接口返回数据", // Corrected Desc from JSON
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "HTTP request",
EnUSDescription: "It is used to send API requests and return data from the interface.",
},
{
ID: 46,
Name: "新增数据", // Corrected Name from JSON (was "Query Data")
Type: NodeTypeDatabaseInsert,
Category: "database", // Mapped from cate_list
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Add Data",
EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
},
{
ID: 58,
Name: "JSON 序列化",
Type: NodeTypeJsonSerialization,
Category: "utilities",
Desc: "用于把变量转化为JSON字符串",
Color: "F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "JSON serialization",
EnUSDescription: "Convert variable to JSON string",
},
{
ID: 59,
Name: "JSON 反序列化",
Type: NodeTypeJsonDeserialization,
Category: "utilities",
Desc: "用于将JSON字符串解析为变量",
Color: "F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "JSON deserialization",
EnUSDescription: "Parse JSON string to variable",
},
{
ID: 60,
Name: "知识库删除",
Type: NodeTypeKnowledgeDeleter,
Category: "data", // Mapped from cate_list
Desc: "用于删除知识库中的文档",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge delete",
EnUSDescription: "The delete node can delete a document in knowledge base.",
},
// --- End of nodes parsed from template_list ---
}
// PluginNodeMetas holds metadata for specific plugin API entity.
var PluginNodeMetas []*PluginNodeMeta
// PluginCategoryMetas holds metadata for plugin category entity.
var PluginCategoryMetas []*PluginCategoryMeta
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
for _, meta := range NodeTypeMetas {
if meta.Type == t {
return meta
}
}
return nil
}
const defaultZhCNInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "结束"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
const defaultEnUSInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "Start"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "End"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema)
}

View File

@@ -19,24 +19,48 @@ package vo
import ( import (
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
) )
// Canvas is the definition of FRONTEND schema for a workflow.
type Canvas struct { type Canvas struct {
Nodes []*Node `json:"nodes"` Nodes []*Node `json:"nodes"`
Edges []*Edge `json:"edges"` Edges []*Edge `json:"edges"`
Versions any `json:"versions"` Versions any `json:"versions"`
} }
// Node represents a node within a workflow canvas.
type Node struct { type Node struct {
ID string `json:"id"` // ID is the unique node ID within the workflow.
Type BlockType `json:"type"` // In normal use cases, this ID is generated by frontend.
Meta any `json:"meta"` // It does NOT need to be unique between parent workflow and sub workflows.
Data *Data `json:"data"` // The Entry node and Exit node of a workflow always have fixed node IDs: 100001 and 900001.
Blocks []*Node `json:"blocks,omitempty"` ID string `json:"id"`
Edges []*Edge `json:"edges,omitempty"`
Version string `json:"version,omitempty"`
parent *Node // Type is the Node Type of this node instance.
// It corresponds to the string value of 'ID' field of NodeMeta.
Type string `json:"type"`
// Meta holds meta data used by frontend, such as the node's position within canvas.
Meta any `json:"meta"`
// Data holds the actual configurations of a node, such as inputs, outputs and exception handling.
// It also holds exclusive configurations for different node types, such as LLM configurations.
Data *Data `json:"data"`
// Blocks holds the sub nodes of this node.
// It is only used when the node type is Composite, such as NodeTypeBatch and NodeTypeLoop.
Blocks []*Node `json:"blocks,omitempty"`
// Edges are the connections between nodes.
// Strictly corresponds to connections drawn on canvas.
Edges []*Edge `json:"edges,omitempty"`
// Version is the version of this node type's schema.
Version string `json:"version,omitempty"`
parent *Node // if this node is within a composite node, coze will set this. No need to set manually
} }
func (n *Node) SetParent(parent *Node) { func (n *Node) SetParent(parent *Node) {
@@ -47,7 +71,7 @@ func (n *Node) Parent() *Node {
return n.parent return n.parent
} }
type NodeMeta struct { type NodeMetaFE struct {
Title string `json:"title,omitempty"` Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"` Icon string `json:"icon,omitempty"`
@@ -62,52 +86,88 @@ type Edge struct {
TargetPortID string `json:"targetPortID,omitempty"` TargetPortID string `json:"targetPortID,omitempty"`
} }
// Data holds the actual configuration of a Node.
type Data struct { type Data struct {
Meta *NodeMeta `json:"nodeMeta,omitempty"` // Meta is the meta data of this node. Only used by frontend.
Outputs []any `json:"outputs,omitempty"` // either []*Variable or []*Param Meta *NodeMetaFE `json:"nodeMeta,omitempty"`
Inputs *Inputs `json:"inputs,omitempty"`
Size any `json:"size,omitempty"` // Outputs configures the output fields and their types.
// Outputs can be either []*Variable (most of the cases, just fields and types),
// or []*Param (used by composite nodes as they need to refer to outputs of sub nodes)
Outputs []any `json:"outputs,omitempty"`
// Inputs configures ALL input information of a node,
// including fixed input fields and dynamic input fields defined by user.
Inputs *Inputs `json:"inputs,omitempty"`
// Size configures the size of this node in frontend.
// Only used by NodeTypeComment.
Size any `json:"size,omitempty"`
} }
type Inputs struct { type Inputs struct {
InputParameters []*Param `json:"inputParameters"` // InputParameters are the fields defined by user for this particular node.
Content *BlockInput `json:"content"` InputParameters []*Param `json:"inputParameters"`
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
StreamingOutput bool `json:"streamingOutput,omitempty"`
CallTransferVoice bool `json:"callTransferVoice,omitempty"`
ChatHistoryWriting string `json:"chatHistoryWriting,omitempty"`
LLMParam any `json:"llmParam,omitempty"` // The LLMParam type may be one of the LLMParam or IntentDetectorLLMParam type or QALLMParam type
FCParam *FCParam `json:"fcParam,omitempty"`
SettingOnError *SettingOnError `json:"settingOnError,omitempty"`
// SettingOnError configures common error handling strategy for nodes.
// NOTE: enable in frontend node's form first.
SettingOnError *SettingOnError `json:"settingOnError,omitempty"`
// NodeBatchInfo configures batch mode for nodes.
// NOTE: not to be confused with NodeTypeBatch.
NodeBatchInfo *NodeBatch `json:"batch,omitempty"`
// LLMParam may be one of the LLMParam or IntentDetectorLLMParam or SimpleLLMParam.
// Shared between most nodes requiring an ChatModel to function.
LLMParam any `json:"llmParam,omitempty"`
*OutputEmitter // exclusive configurations for NodeTypeEmitter and NodeTypeExit in Answer mode
*Exit // exclusive configurations for NodeTypeExit
*LLM // exclusive configurations for NodeTypeLLM
*Loop // exclusive configurations for NodeTypeLoop
*Selector // exclusive configurations for NodeTypeSelector
*TextProcessor // exclusive configurations for NodeTypeTextProcessor
*SubWorkflow // exclusive configurations for NodeTypeSubWorkflow
*IntentDetector // exclusive configurations for NodeTypeIntentDetector
*DatabaseNode // exclusive configurations for various Database nodes
*HttpRequestNode // exclusive configurations for NodeTypeHTTPRequester
*Knowledge // exclusive configurations for various Knowledge nodes
*CodeRunner // exclusive configurations for NodeTypeCodeRunner
*PluginAPIParam // exclusive configurations for NodeTypePlugin
*VariableAggregator // exclusive configurations for NodeTypeVariableAggregator
*VariableAssigner // exclusive configurations for NodeTypeVariableAssigner
*QA // exclusive configurations for NodeTypeQuestionAnswer
*Batch // exclusive configurations for NodeTypeBatch
*Comment // exclusive configurations for NodeTypeComment
*InputReceiver // exclusive configurations for NodeTypeInputReceiver
}
type OutputEmitter struct {
Content *BlockInput `json:"content"`
StreamingOutput bool `json:"streamingOutput,omitempty"`
}
type Exit struct {
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
}
type LLM struct {
FCParam *FCParam `json:"fcParam,omitempty"`
}
type Loop struct {
LoopType LoopType `json:"loopType,omitempty"` LoopType LoopType `json:"loopType,omitempty"`
LoopCount *BlockInput `json:"loopCount,omitempty"` LoopCount *BlockInput `json:"loopCount,omitempty"`
VariableParameters []*Param `json:"variableParameters,omitempty"` VariableParameters []*Param `json:"variableParameters,omitempty"`
}
type Selector struct {
Branches []*struct { Branches []*struct {
Condition struct { Condition struct {
Logic LogicType `json:"logic"` Logic LogicType `json:"logic"`
Conditions []*Condition `json:"conditions"` Conditions []*Condition `json:"conditions"`
} `json:"condition"` } `json:"condition"`
} `json:"branches,omitempty"` } `json:"branches,omitempty"`
NodeBatchInfo *NodeBatch `json:"batch,omitempty"` // node in batch mode
*TextProcessor
*SubWorkflow
*IntentDetector
*DatabaseNode
*HttpRequestNode
*KnowledgeIndexer
*CodeRunner
*PluginAPIParam
*VariableAggregator
*VariableAssigner
*QA
*Batch
*Comment
OutputSchema string `json:"outputSchema,omitempty"`
} }
type Comment struct { type Comment struct {
@@ -127,7 +187,7 @@ type VariableAssigner struct {
type LLMParam = []*Param type LLMParam = []*Param
type IntentDetectorLLMParam = map[string]any type IntentDetectorLLMParam = map[string]any
type QALLMParam struct { type SimpleLLMParam struct {
GenerationDiversity string `json:"generationDiversity"` GenerationDiversity string `json:"generationDiversity"`
MaxTokens int `json:"maxTokens"` MaxTokens int `json:"maxTokens"`
ModelName string `json:"modelName"` ModelName string `json:"modelName"`
@@ -248,7 +308,7 @@ type CodeRunner struct {
Language int64 `json:"language"` Language int64 `json:"language"`
} }
type KnowledgeIndexer struct { type Knowledge struct {
DatasetParam []*Param `json:"datasetParam,omitempty"` DatasetParam []*Param `json:"datasetParam,omitempty"`
StrategyParam StrategyParam `json:"strategyParam,omitempty"` StrategyParam StrategyParam `json:"strategyParam,omitempty"`
} }
@@ -384,23 +444,54 @@ type ChatHistorySetting struct {
type Intent struct { type Intent struct {
Name string `json:"name"` Name string `json:"name"`
} }
// Param is a node's field with type and source info.
type Param struct { type Param struct {
Name string `json:"name,omitempty"` // Name is the field's name.
Input *BlockInput `json:"input,omitempty"` Name string `json:"name,omitempty"`
Left *BlockInput `json:"left,omitempty"`
Right *BlockInput `json:"right,omitempty"` // Input is the configurations for normal, singular field.
Input *BlockInput `json:"input,omitempty"`
// Left is the configurations for the left half of an expression,
// such as an assignment in NodeTypeVariableAssigner.
Left *BlockInput `json:"left,omitempty"`
// Right is the configuration for the right half of an expression.
Right *BlockInput `json:"right,omitempty"`
// Variables are configurations for a group of fields.
// Only used in NodeTypeVariableAggregator.
Variables []*BlockInput `json:"variables,omitempty"` Variables []*BlockInput `json:"variables,omitempty"`
} }
// Variable is the configuration of a node's field, either input or output.
type Variable struct { type Variable struct {
Name string `json:"name"` // Name is the field's name as defined on canvas.
Type VariableType `json:"type"` Name string `json:"name"`
Required bool `json:"required,omitempty"`
AssistType AssistType `json:"assistType,omitempty"` // Type is the field's data type, such as string, integer, number, object, array, etc.
Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list) Type VariableType `json:"type"`
Description string `json:"description,omitempty"`
ReadOnly bool `json:"readOnly,omitempty"` // Required is set to true if you checked the 'required box' on this field
DefaultValue any `json:"defaultValue,omitempty"` Required bool `json:"required,omitempty"`
// AssistType is the 'secondary' type of string fields, such as different types of file and image, or time.
AssistType AssistType `json:"assistType,omitempty"`
// Schema contains detailed info for sub-fields of an object field, or element type of an array.
Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list)
// Description describes the field's intended use. Used on Entry node. Useful for workflow tools.
Description string `json:"description,omitempty"`
// ReadOnly indicates a field is not to be set by Node's business logic.
// e.g. the ErrorBody field when exception strategy is configured.
ReadOnly bool `json:"readOnly,omitempty"`
// DefaultValue configures the 'default value' if this field is missing in input.
// Effective only in Entry node.
DefaultValue any `json:"defaultValue,omitempty"`
} }
type BlockInput struct { type BlockInput struct {
@@ -436,48 +527,6 @@ type SubWorkflow struct {
SpaceID string `json:"spaceId,omitempty"` SpaceID string `json:"spaceId,omitempty"`
} }
// BlockType is the enumeration of node types for front-end canvas schema.
// To add a new BlockType, start from a really big number such as 1000, to avoid conflict with future extensions.
type BlockType string
func (b BlockType) String() string {
return string(b)
}
const (
BlockTypeBotStart BlockType = "1"
BlockTypeBotEnd BlockType = "2"
BlockTypeBotLLM BlockType = "3"
BlockTypeBotAPI BlockType = "4"
BlockTypeBotCode BlockType = "5"
BlockTypeBotDataset BlockType = "6"
BlockTypeCondition BlockType = "8"
BlockTypeBotSubWorkflow BlockType = "9"
BlockTypeDatabase BlockType = "12"
BlockTypeBotMessage BlockType = "13"
BlockTypeBotText BlockType = "15"
BlockTypeQuestion BlockType = "18"
BlockTypeBotBreak BlockType = "19"
BlockTypeBotLoopSetVariable BlockType = "20"
BlockTypeBotLoop BlockType = "21"
BlockTypeBotIntent BlockType = "22"
BlockTypeBotDatasetWrite BlockType = "27"
BlockTypeBotInput BlockType = "30"
BlockTypeBotBatch BlockType = "28"
BlockTypeBotContinue BlockType = "29"
BlockTypeBotComment BlockType = "31"
BlockTypeBotVariableMerge BlockType = "32"
BlockTypeBotAssignVariable BlockType = "40"
BlockTypeDatabaseUpdate BlockType = "42"
BlockTypeDatabaseSelect BlockType = "43"
BlockTypeDatabaseDelete BlockType = "44"
BlockTypeBotHttp BlockType = "45"
BlockTypeDatabaseInsert BlockType = "46"
BlockTypeJsonSerialization BlockType = "58"
BlockTypeJsonDeserialization BlockType = "59"
BlockTypeBotDatasetDelete BlockType = "60"
)
type VariableType string type VariableType string
const ( const (
@@ -536,19 +585,31 @@ const (
type ErrorProcessType int type ErrorProcessType int
const ( const (
ErrorProcessTypeThrow ErrorProcessType = 1 ErrorProcessTypeThrow ErrorProcessType = 1 // throws the error as usual
ErrorProcessTypeDefault ErrorProcessType = 2 ErrorProcessTypeReturnDefaultData ErrorProcessType = 2 // return DataOnErr configured in SettingOnError
ErrorProcessTypeExceptionBranch ErrorProcessType = 3 ErrorProcessTypeExceptionBranch ErrorProcessType = 3 // executes the exception branch on error
) )
// SettingOnError contains common error handling strategy.
type SettingOnError struct { type SettingOnError struct {
DataOnErr string `json:"dataOnErr,omitempty"` // DataOnErr defines the JSON result to be returned on error.
Switch bool `json:"switch,omitempty"` DataOnErr string `json:"dataOnErr,omitempty"`
// Switch defines whether ANY error handling strategy is active.
// If set to false, it's equivalent to set ProcessType = ErrorProcessTypeThrow
Switch bool `json:"switch,omitempty"`
// ProcessType determines the error handling strategy for this node.
ProcessType *ErrorProcessType `json:"processType,omitempty"` ProcessType *ErrorProcessType `json:"processType,omitempty"`
RetryTimes int64 `json:"retryTimes,omitempty"` // RetryTimes determines how many times to retry. 0 means no retry.
TimeoutMs int64 `json:"timeoutMs,omitempty"` // If positive, any retries will be executed immediately after error.
Ext *struct { RetryTimes int64 `json:"retryTimes,omitempty"`
BackupLLMParam string `json:"backupLLMParam,omitempty"` // only for LLM Node, marshaled from QALLMParam // TimeoutMs sets the timeout duration in millisecond.
// If any retry happens, ALL retry attempts accumulates to the same timeout threshold.
TimeoutMs int64 `json:"timeoutMs,omitempty"`
// Ext sets any extra settings specific to NodeType
Ext *struct {
// BackupLLMParam is only for LLM Node, marshaled from SimpleLLMParam.
// If retry happens, the backup LLM will be used instead of the main LLM.
BackupLLMParam string `json:"backupLLMParam,omitempty"`
} `json:"ext,omitempty"` } `json:"ext,omitempty"`
} }
@@ -597,32 +658,8 @@ const (
LoopTypeInfinite LoopType = "infinite" LoopTypeInfinite LoopType = "infinite"
) )
type WorkflowIdentity struct { type InputReceiver struct {
ID string `json:"id"` OutputSchema string `json:"outputSchema,omitempty"`
Version string `json:"version"`
}
func (c *Canvas) GetAllSubWorkflowIdentities() []*WorkflowIdentity {
workflowEntities := make([]*WorkflowIdentity, 0)
var collectSubWorkFlowEntities func(nodes []*Node)
collectSubWorkFlowEntities = func(nodes []*Node) {
for _, n := range nodes {
if n.Type == BlockTypeBotSubWorkflow {
workflowEntities = append(workflowEntities, &WorkflowIdentity{
ID: n.Data.Inputs.WorkflowID,
Version: n.Data.Inputs.WorkflowVersion,
})
}
if len(n.Blocks) > 0 {
collectSubWorkFlowEntities(n.Blocks)
}
}
}
collectSubWorkFlowEntities(c.Nodes)
return workflowEntities
} }
func GenerateNodeIDForBatchMode(key string) string { func GenerateNodeIDForBatchMode(key string) string {
@@ -632,3 +669,163 @@ func GenerateNodeIDForBatchMode(key string) string {
func IsGeneratedNodeForBatchMode(key string, parentKey string) bool { func IsGeneratedNodeForBatchMode(key string, parentKey string) bool {
return key == GenerateNodeIDForBatchMode(parentKey) return key == GenerateNodeIDForBatchMode(parentKey)
} }
const defaultZhCNInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "结束"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
const defaultEnUSInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "Start"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "End"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema)
}

View File

@@ -47,12 +47,6 @@ type FieldSource struct {
Val any `json:"val,omitempty"` Val any `json:"val,omitempty"`
} }
type ImplicitNodeDependency struct {
NodeID string
FieldPath compose.FieldPath
TypeInfo *TypeInfo
}
type TypeInfo struct { type TypeInfo struct {
Type DataType `json:"type"` Type DataType `json:"type"`
ElemTypeInfo *TypeInfo `json:"elem_type_info,omitempty"` ElemTypeInfo *TypeInfo `json:"elem_type_info,omitempty"`

View File

@@ -69,13 +69,6 @@ type IDVersionPair struct {
Version string Version string
} }
type Stage uint8
const (
StageDraft Stage = 1
StagePublished Stage = 2
)
type WorkflowBasic struct { type WorkflowBasic struct {
ID int64 ID int64
Version string Version string

View File

@@ -58,6 +58,11 @@ import (
"github.com/coze-dev/coze-studio/backend/types/consts" "github.com/coze-dev/coze-studio/backend/types/consts"
) )
func TestMain(m *testing.M) {
RegisterAllNodeAdaptors()
m.Run()
}
func TestIntentDetectorAndDatabase(t *testing.T) { func TestIntentDetectorAndDatabase(t *testing.T) {
mockey.PatchConvey("intent detector & database custom sql", t, func() { mockey.PatchConvey("intent detector & database custom sql", t, func() {
data, err := os.ReadFile("../examples/intent_detector_database_custom_sql.json") data, err := os.ReadFile("../examples/intent_detector_database_custom_sql.json")

View File

@@ -26,10 +26,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) ( func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
*compose.WorkflowSchema, error) { *schema.WorkflowSchema, error) {
var ( var (
n *vo.Node n *vo.Node
nodeFinder func(nodes []*vo.Node) *vo.Node nodeFinder func(nodes []*vo.Node) *vo.Node
@@ -62,35 +63,27 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
n = batchN n = batchN
} }
implicitDependencies, err := extractImplicitDependency(n, c.Nodes) nsList, hierarchy, err := NodeToNodeSchema(ctx, n, c)
if err != nil {
return nil, err
}
opts := make([]OptionFn, 0, 1)
if len(implicitDependencies) > 0 {
opts = append(opts, WithImplicitNodeDependencies(implicitDependencies))
}
nsList, hierarchy, err := NodeToNodeSchema(ctx, n, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var ( var (
ns *compose.NodeSchema ns *schema.NodeSchema
innerNodes map[vo.NodeKey]*compose.NodeSchema // inner nodes of the composite node if nodeKey is composite innerNodes map[vo.NodeKey]*schema.NodeSchema // inner nodes of the composite node if nodeKey is composite
connections []*compose.Connection connections []*schema.Connection
) )
if len(nsList) == 1 { if len(nsList) == 1 {
ns = nsList[0] ns = nsList[0]
} else { } else {
innerNodes = make(map[vo.NodeKey]*compose.NodeSchema) innerNodes = make(map[vo.NodeKey]*schema.NodeSchema)
for i := range nsList { for i := range nsList {
one := nsList[i] one := nsList[i]
if _, ok := hierarchy[one.Key]; ok { if _, ok := hierarchy[one.Key]; ok {
innerNodes[one.Key] = one innerNodes[one.Key] = one
if one.Type == entity.NodeTypeContinue || one.Type == entity.NodeTypeBreak { if one.Type == entity.NodeTypeContinue || one.Type == entity.NodeTypeBreak {
connections = append(connections, &compose.Connection{ connections = append(connections, &schema.Connection{
FromNode: one.Key, FromNode: one.Key,
ToNode: vo.NodeKey(nodeID), ToNode: vo.NodeKey(nodeID),
}) })
@@ -106,13 +99,13 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
} }
const inputFillerKey = "input_filler" const inputFillerKey = "input_filler"
connections = append(connections, &compose.Connection{ connections = append(connections, &schema.Connection{
FromNode: einoCompose.START, FromNode: einoCompose.START,
ToNode: inputFillerKey, ToNode: inputFillerKey,
}, &compose.Connection{ }, &schema.Connection{
FromNode: inputFillerKey, FromNode: inputFillerKey,
ToNode: ns.Key, ToNode: ns.Key,
}, &compose.Connection{ }, &schema.Connection{
FromNode: ns.Key, FromNode: ns.Key,
ToNode: einoCompose.END, ToNode: einoCompose.END,
}) })
@@ -209,7 +202,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
return newOutput, nil return newOutput, nil
} }
inputFiller := &compose.NodeSchema{ inputFiller := &schema.NodeSchema{
Key: inputFillerKey, Key: inputFillerKey,
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: einoCompose.InvokableLambda(i), Lambda: einoCompose.InvokableLambda(i),
@@ -227,10 +220,16 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
OutputTypes: startOutputTypes, OutputTypes: startOutputTypes,
} }
trimmedSC := &compose.WorkflowSchema{ branches, err := schema.BuildBranches(connections)
Nodes: append([]*compose.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...), if err != nil {
return nil, err
}
trimmedSC := &schema.WorkflowSchema{
Nodes: append([]*schema.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...),
Connections: connections, Connections: connections,
Hierarchy: hierarchy, Hierarchy: hierarchy,
Branches: branches,
} }
if enabled { if enabled {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,663 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package convert
import (
"fmt"
"strconv"
"strings"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func CanvasVariableToTypeInfo(v *vo.Variable) (*vo.TypeInfo, error) {
tInfo := &vo.TypeInfo{
Required: v.Required,
Desc: v.Description,
}
switch v.Type {
case vo.VariableTypeString:
switch v.AssistType {
case vo.AssistTypeTime:
tInfo.Type = vo.DataTypeTime
case vo.AssistTypeNotSet:
tInfo.Type = vo.DataTypeString
default:
fileType, ok := assistTypeToFileType(v.AssistType)
if ok {
tInfo.Type = vo.DataTypeFile
tInfo.FileType = &fileType
} else {
return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
}
}
case vo.VariableTypeInteger:
tInfo.Type = vo.DataTypeInteger
case vo.VariableTypeFloat:
tInfo.Type = vo.DataTypeNumber
case vo.VariableTypeBoolean:
tInfo.Type = vo.DataTypeBoolean
case vo.VariableTypeObject:
tInfo.Type = vo.DataTypeObject
tInfo.Properties = make(map[string]*vo.TypeInfo)
if v.Schema != nil {
for _, subVAny := range v.Schema.([]any) {
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := CanvasVariableToTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.Properties[subV.Name] = subTInfo
}
}
case vo.VariableTypeList:
tInfo.Type = vo.DataTypeArray
subVAny := v.Schema
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := CanvasVariableToTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.ElemTypeInfo = subTInfo
default:
return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
}
return tInfo, nil
}
func CanvasBlockInputToTypeInfo(b *vo.BlockInput) (tInfo *vo.TypeInfo, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrSchemaConversionFail, err)
}
}()
tInfo = &vo.TypeInfo{}
if b == nil {
return tInfo, nil
}
switch b.Type {
case vo.VariableTypeString:
switch b.AssistType {
case vo.AssistTypeTime:
tInfo.Type = vo.DataTypeTime
case vo.AssistTypeNotSet:
tInfo.Type = vo.DataTypeString
default:
fileType, ok := assistTypeToFileType(b.AssistType)
if ok {
tInfo.Type = vo.DataTypeFile
tInfo.FileType = &fileType
} else {
return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
}
}
case vo.VariableTypeInteger:
tInfo.Type = vo.DataTypeInteger
case vo.VariableTypeFloat:
tInfo.Type = vo.DataTypeNumber
case vo.VariableTypeBoolean:
tInfo.Type = vo.DataTypeBoolean
case vo.VariableTypeObject:
tInfo.Type = vo.DataTypeObject
tInfo.Properties = make(map[string]*vo.TypeInfo)
if b.Schema != nil {
for _, subVAny := range b.Schema.([]any) {
if b.Value.Type == vo.BlockInputValueTypeRef {
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := CanvasVariableToTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.Properties[subV.Name] = subTInfo
} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
subV, err := parseParam(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := CanvasBlockInputToTypeInfo(subV.Input)
if err != nil {
return nil, err
}
tInfo.Properties[subV.Name] = subTInfo
}
}
}
case vo.VariableTypeList:
tInfo.Type = vo.DataTypeArray
subVAny := b.Schema
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := CanvasVariableToTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.ElemTypeInfo = subTInfo
default:
return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
}
return tInfo, nil
}
func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) (sources []*vo.FieldInfo, err error) {
value := b.Value
if value == nil {
return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type)
}
switch value.Type {
case vo.BlockInputValueTypeObjectRef:
sc := b.Schema
if sc == nil {
return nil, fmt.Errorf("input %v has no schema, type= %s", path, b.Type)
}
paramList, ok := sc.([]any)
if !ok {
return nil, fmt.Errorf("input %v schema not []any, type= %T", path, sc)
}
for i := range paramList {
paramAny := paramList[i]
param, err := parseParam(paramAny)
if err != nil {
return nil, err
}
copied := make([]string, len(path))
copy(copied, path)
subFieldInfo, err := CanvasBlockInputToFieldInfo(param.Input, append(copied, param.Name), parentNode)
if err != nil {
return nil, err
}
sources = append(sources, subFieldInfo...)
}
return sources, nil
case vo.BlockInputValueTypeLiteral:
content := value.Content
if content == nil {
return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
}
switch b.Type {
case vo.VariableTypeObject:
m := make(map[string]any)
if err = sonic.UnmarshalString(content.(string), &m); err != nil {
return nil, err
}
content = m
case vo.VariableTypeList:
l := make([]any, 0)
if err = sonic.UnmarshalString(content.(string), &l); err != nil {
return nil, err
}
content = l
case vo.VariableTypeInteger:
switch content.(type) {
case string:
content, err = strconv.ParseInt(content.(string), 10, 64)
if err != nil {
return nil, err
}
case int64:
content = content.(int64)
case float64:
content = int64(content.(float64))
default:
return nil, fmt.Errorf("unsupported variable type fot integer: %s", b.Type)
}
case vo.VariableTypeFloat:
switch content.(type) {
case string:
content, err = strconv.ParseFloat(content.(string), 64)
if err != nil {
return nil, err
}
case int64:
content = float64(content.(int64))
case float64:
content = content.(float64)
default:
return nil, fmt.Errorf("unsupported variable type for float: %s", b.Type)
}
case vo.VariableTypeBoolean:
switch content.(type) {
case string:
content, err = strconv.ParseBool(content.(string))
if err != nil {
return nil, err
}
case bool:
content = content.(bool)
default:
return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type)
}
default:
}
return []*vo.FieldInfo{
{
Path: path,
Source: vo.FieldSource{
Val: content,
},
},
}, nil
case vo.BlockInputValueTypeRef:
content := value.Content
if content == nil {
return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
}
ref, err := parseBlockInputRef(content)
if err != nil {
return nil, err
}
fieldSource, err := CanvasBlockInputRefToFieldSource(ref)
if err != nil {
return nil, err
}
if parentNode != nil {
if fieldSource.Ref != nil && len(fieldSource.Ref.FromNodeKey) > 0 && fieldSource.Ref.FromNodeKey == vo.NodeKey(parentNode.ID) {
varRoot := fieldSource.Ref.FromPath[0]
if parentNode.Data.Inputs.Loop != nil {
for _, p := range parentNode.Data.Inputs.VariableParameters {
if p.Name == varRoot {
fieldSource.Ref.FromNodeKey = ""
pi := vo.ParentIntermediate
fieldSource.Ref.VariableType = &pi
}
}
}
}
}
return []*vo.FieldInfo{
{
Path: path,
Source: *fieldSource,
},
}, nil
default:
return nil, fmt.Errorf("unsupported value type: %s for blockInput type= %s", value.Type, b.Type)
}
}
func parseBlockInputRef(content any) (*vo.BlockInputReference, error) {
if bi, ok := content.(*vo.BlockInputReference); ok {
return bi, nil
}
m, ok := content.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid content type: %T when parse BlockInputRef", content)
}
marshaled, err := sonic.Marshal(m)
if err != nil {
return nil, err
}
p := &vo.BlockInputReference{}
if err := sonic.Unmarshal(marshaled, p); err != nil {
return nil, err
}
return p, nil
}
func parseParam(v any) (*vo.Param, error) {
if pa, ok := v.(*vo.Param); ok {
return pa, nil
}
m, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid content type: %T when parse Param", v)
}
marshaled, err := sonic.Marshal(m)
if err != nil {
return nil, err
}
p := &vo.Param{}
if err := sonic.Unmarshal(marshaled, p); err != nil {
return nil, err
}
return p, nil
}
func CanvasBlockInputRefToFieldSource(r *vo.BlockInputReference) (*vo.FieldSource, error) {
switch r.Source {
case vo.RefSourceTypeBlockOutput:
if len(r.BlockID) == 0 {
return nil, fmt.Errorf("invalid BlockInputReference = %+v, BlockID is empty when source is block output", r)
}
parts := strings.Split(r.Name, ".") // an empty r.Name signals an all-to-all mapping
return &vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: vo.NodeKey(r.BlockID),
FromPath: parts,
},
}, nil
case vo.RefSourceTypeGlobalApp, vo.RefSourceTypeGlobalSystem, vo.RefSourceTypeGlobalUser:
if len(r.Path) == 0 {
return nil, fmt.Errorf("invalid BlockInputReference = %+v, Path is empty when source is variables", r)
}
var varType vo.GlobalVarType
switch r.Source {
case vo.RefSourceTypeGlobalApp:
varType = vo.GlobalAPP
case vo.RefSourceTypeGlobalSystem:
varType = vo.GlobalSystem
case vo.RefSourceTypeGlobalUser:
varType = vo.GlobalUser
default:
return nil, fmt.Errorf("invalid BlockInputReference = %+v, Source is invalid", r)
}
return &vo.FieldSource{
Ref: &vo.Reference{
VariableType: &varType,
FromPath: r.Path,
},
}, nil
default:
return nil, fmt.Errorf("unsupported ref source type: %s", r.Source)
}
}
func assistTypeToFileType(a vo.AssistType) (vo.FileSubType, bool) {
switch a {
case vo.AssistTypeNotSet:
return "", false
case vo.AssistTypeTime:
return "", false
case vo.AssistTypeImage:
return vo.FileTypeImage, true
case vo.AssistTypeAudio:
return vo.FileTypeAudio, true
case vo.AssistTypeVideo:
return vo.FileTypeVideo, true
case vo.AssistTypeDefault:
return vo.FileTypeDefault, true
case vo.AssistTypeDoc:
return vo.FileTypeDocument, true
case vo.AssistTypeExcel:
return vo.FileTypeExcel, true
case vo.AssistTypeCode:
return vo.FileTypeCode, true
case vo.AssistTypePPT:
return vo.FileTypePPT, true
case vo.AssistTypeTXT:
return vo.FileTypeTxt, true
case vo.AssistTypeSvg:
return vo.FileTypeSVG, true
case vo.AssistTypeVoice:
return vo.FileTypeVoice, true
case vo.AssistTypeZip:
return vo.FileTypeZip, true
default:
panic("impossible")
}
}
func SetInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
if n.Data.Inputs == nil {
return nil
}
inputParams := n.Data.Inputs.InputParameters
if len(inputParams) == 0 {
return nil
}
for _, param := range inputParams {
name := param.Name
tInfo, err := CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return err
}
ns.SetInputType(name, tInfo)
sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
return nil
}
func SetOutputTypesForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
for _, vAny := range n.Data.Outputs {
v, err := vo.ParseVariable(vAny)
if err != nil {
return err
}
tInfo, err := CanvasVariableToTypeInfo(v)
if err != nil {
return err
}
if v.ReadOnly {
if v.Name == "errorBody" { // reserved output fields when exception happens
continue
}
}
ns.SetOutputType(v.Name, tInfo)
}
return nil
}
func SetOutputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
for _, vAny := range n.Data.Outputs {
param, err := parseParam(vAny)
if err != nil {
return err
}
name := param.Name
tInfo, err := CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return err
}
ns.SetOutputType(name, tInfo)
sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent())
if err != nil {
return err
}
ns.AddOutputSource(sources...)
}
return nil
}
func BlockInputToNamedTypeInfo(name string, b *vo.BlockInput) (*vo.NamedTypeInfo, error) {
tInfo := &vo.NamedTypeInfo{
Name: name,
}
if b == nil {
return tInfo, nil
}
switch b.Type {
case vo.VariableTypeString:
switch b.AssistType {
case vo.AssistTypeTime:
tInfo.Type = vo.DataTypeTime
case vo.AssistTypeNotSet:
tInfo.Type = vo.DataTypeString
default:
fileType, ok := assistTypeToFileType(b.AssistType)
if ok {
tInfo.Type = vo.DataTypeFile
tInfo.FileType = &fileType
} else {
return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
}
}
case vo.VariableTypeInteger:
tInfo.Type = vo.DataTypeInteger
case vo.VariableTypeFloat:
tInfo.Type = vo.DataTypeNumber
case vo.VariableTypeBoolean:
tInfo.Type = vo.DataTypeBoolean
case vo.VariableTypeObject:
tInfo.Type = vo.DataTypeObject
if b.Schema != nil {
tInfo.Properties = make([]*vo.NamedTypeInfo, 0, len(b.Schema.([]any)))
for _, subVAny := range b.Schema.([]any) {
if b.Value.Type == vo.BlockInputValueTypeRef {
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subNInfo, err := VariableToNamedTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.Properties = append(tInfo.Properties, subNInfo)
} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
subV, err := parseParam(subVAny)
if err != nil {
return nil, err
}
subNInfo, err := BlockInputToNamedTypeInfo(subV.Name, subV.Input)
if err != nil {
return nil, err
}
tInfo.Properties = append(tInfo.Properties, subNInfo)
}
}
}
case vo.VariableTypeList:
tInfo.Type = vo.DataTypeArray
subVAny := b.Schema
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subNInfo, err := VariableToNamedTypeInfo(subV)
if err != nil {
return nil, err
}
tInfo.ElemTypeInfo = subNInfo
default:
return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
}
return tInfo, nil
}
func VariableToNamedTypeInfo(v *vo.Variable) (*vo.NamedTypeInfo, error) {
nInfo := &vo.NamedTypeInfo{
Required: v.Required,
Name: v.Name,
Desc: v.Description,
}
switch v.Type {
case vo.VariableTypeString:
switch v.AssistType {
case vo.AssistTypeTime:
nInfo.Type = vo.DataTypeTime
case vo.AssistTypeNotSet:
nInfo.Type = vo.DataTypeString
default:
fileType, ok := assistTypeToFileType(v.AssistType)
if ok {
nInfo.Type = vo.DataTypeFile
nInfo.FileType = &fileType
} else {
return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
}
}
case vo.VariableTypeInteger:
nInfo.Type = vo.DataTypeInteger
case vo.VariableTypeFloat:
nInfo.Type = vo.DataTypeNumber
case vo.VariableTypeBoolean:
nInfo.Type = vo.DataTypeBoolean
case vo.VariableTypeObject:
nInfo.Type = vo.DataTypeObject
if v.Schema != nil {
nInfo.Properties = make([]*vo.NamedTypeInfo, 0)
for _, subVAny := range v.Schema.([]any) {
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := VariableToNamedTypeInfo(subV)
if err != nil {
return nil, err
}
nInfo.Properties = append(nInfo.Properties, subTInfo)
}
}
case vo.VariableTypeList:
nInfo.Type = vo.DataTypeArray
subVAny := v.Schema
subV, err := vo.ParseVariable(subVAny)
if err != nil {
return nil, err
}
subTInfo, err := VariableToNamedTypeInfo(subV)
if err != nil {
return nil, err
}
nInfo.ElemTypeInfo = subTInfo
default:
return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
}
return nInfo, nil
}

View File

@@ -24,8 +24,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
@@ -123,7 +126,7 @@ func (cv *CanvasValidator) ValidateConnections(ctx context.Context) (issues []*I
return issues, nil return issues, nil
} }
func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issue, err error) { func (cv *CanvasValidator) CheckRefVariable(_ context.Context) (issues []*Issue, err error) {
issues = make([]*Issue, 0) issues = make([]*Issue, 0)
var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error
checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error { checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error {
@@ -237,7 +240,7 @@ func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issu
return issues, nil return issues, nil
} }
func (cv *CanvasValidator) ValidateNestedFlows(ctx context.Context) (issues []*Issue, err error) { func (cv *CanvasValidator) ValidateNestedFlows(_ context.Context) (issues []*Issue, err error) {
issues = make([]*Issue, 0) issues = make([]*Issue, 0)
for nodeID, node := range cv.reachability.reachableNodes { for nodeID, node := range cv.reachability.reachableNodes {
if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 { if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 {
@@ -265,13 +268,13 @@ func (cv *CanvasValidator) CheckGlobalVariables(ctx context.Context) (issues []*
nVars := make([]*nodeVars, 0) nVars := make([]*nodeVars, 0)
for _, node := range cv.cfg.Canvas.Nodes { for _, node := range cv.cfg.Canvas.Nodes {
if node.Type == vo.BlockTypeBotComment { if node.Type == entity.NodeTypeComment.IDStr() {
continue continue
} }
if node.Type == vo.BlockTypeBotAssignVariable { if node.Type == entity.NodeTypeVariableAssigner.IDStr() {
v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)} v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)}
for _, p := range node.Data.Inputs.InputParameters { for _, p := range node.Data.Inputs.InputParameters {
v.vars[p.Name], err = adaptor.CanvasBlockInputToTypeInfo(p.Left) v.vars[p.Name], err = convert.CanvasBlockInputToTypeInfo(p.Left)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -338,7 +341,7 @@ func (cv *CanvasValidator) CheckSubWorkFlowTerminatePlanType(ctx context.Context
var collectSubWorkFlowNodes func(nodes []*vo.Node) var collectSubWorkFlowNodes func(nodes []*vo.Node)
collectSubWorkFlowNodes = func(nodes []*vo.Node) { collectSubWorkFlowNodes = func(nodes []*vo.Node) {
for _, n := range nodes { for _, n := range nodes {
if n.Type == vo.BlockTypeBotSubWorkflow { if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
subWfMap = append(subWfMap, n) subWfMap = append(subWfMap, n)
wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64) wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
if err != nil { if err != nil {
@@ -465,62 +468,28 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
selectorPorts := make(map[string]map[string]bool) selectorPorts := make(map[string]map[string]bool)
for nodeID, node := range nodeMap { for nodeID, node := range nodeMap {
switch node.Type { if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
case vo.BlockTypeCondition: node.Data.Inputs.SettingOnError.ProcessType != nil &&
branches := node.Data.Inputs.Branches *node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
if _, exists := selectorPorts[nodeID]; !exists { if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool) selectorPorts[nodeID] = make(map[string]bool)
} }
selectorPorts[nodeID]["false"] = true selectorPorts[nodeID][schema.PortBranchError] = true
for index := range branches { selectorPorts[nodeID][schema.PortDefault] = true
if index == 0 {
selectorPorts[nodeID]["true"] = true
} else {
selectorPorts[nodeID][fmt.Sprintf("true_%v", index)] = true
}
}
case vo.BlockTypeBotIntent:
intents := node.Data.Inputs.Intents
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
for index := range intents {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
}
selectorPorts[nodeID]["default"] = true
if node.Data.Inputs.SettingOnError != nil && node.Data.Inputs.SettingOnError.ProcessType != nil &&
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
selectorPorts[nodeID]["branch_error"] = true
}
case vo.BlockTypeQuestion:
if node.Data.Inputs.QA.AnswerType == vo.QAAnswerTypeOption {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
for index := range node.Data.Inputs.QA.Options {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
}
}
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", 0)] = true
}
}
default:
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
node.Data.Inputs.SettingOnError.ProcessType != nil &&
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
selectorPorts[nodeID]["branch_error"] = true
selectorPorts[nodeID]["default"] = true
} else {
outDegree[node.ID] = 0
}
} }
ba, ok := nodes.GetBranchAdaptor(entity.IDStrToNodeType(node.Type))
if ok {
expects := ba.ExpectPorts(ctx, node)
if len(expects) > 0 {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
for _, e := range expects {
selectorPorts[nodeID][e] = true
}
}
}
} }
for _, edge := range c.Edges { for _, edge := range c.Edges {
@@ -544,8 +513,8 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
for nodeID, node := range nodeMap { for nodeID, node := range nodeMap {
nodeName := node.Data.Meta.Title nodeName := node.Data.Meta.Title
switch node.Type { switch et := entity.IDStrToNodeType(node.Type); et {
case vo.BlockTypeBotStart: case entity.NodeTypeEntry:
if outDegree[nodeID] == 0 { if outDegree[nodeID] == 0 {
issues = append(issues, &Issue{ issues = append(issues, &Issue{
NodeErr: &NodeErr{ NodeErr: &NodeErr{
@@ -555,13 +524,9 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
Message: `node "start" not connected`, Message: `node "start" not connected`,
}) })
} }
case vo.BlockTypeBotEnd: case entity.NodeTypeExit:
default: default:
if ports, isSelector := selectorPorts[nodeID]; isSelector { if ports, isSelector := selectorPorts[nodeID]; isSelector {
selectorIssues := &Issue{NodeErr: &NodeErr{
NodeID: node.ID,
NodeName: nodeName,
}}
message := "" message := ""
for port := range ports { for port := range ports {
if portOutDegree[nodeID][port] == 0 { if portOutDegree[nodeID][port] == 0 {
@@ -569,12 +534,15 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
} }
} }
if len(message) > 0 { if len(message) > 0 {
selectorIssues.Message = message selectorIssues := &Issue{NodeErr: &NodeErr{
NodeID: node.ID,
NodeName: nodeName,
}, Message: message}
issues = append(issues, selectorIssues) issues = append(issues, selectorIssues)
} }
} else { } else {
// Break, continue without checking out degrees // Break, continue without checking out degrees
if node.Type == vo.BlockTypeBotBreak || node.Type == vo.BlockTypeBotContinue { if et == entity.NodeTypeBreak || et == entity.NodeTypeContinue {
continue continue
} }
if outDegree[nodeID] == 0 { if outDegree[nodeID] == 0 {
@@ -585,7 +553,6 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
}, },
Message: fmt.Sprintf(`node "%v" not connected`, nodeName), Message: fmt.Sprintf(`node "%v" not connected`, nodeName),
}) })
} }
} }
} }
@@ -602,7 +569,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
return nil, err return nil, err
} }
startNode, endNode, err := findStartAndEndNodes(c.Nodes) startNode, _, err := findStartAndEndNodes(c.Nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -612,7 +579,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID) edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID)
} }
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode, endNode) reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -635,12 +602,12 @@ func processNestedReachability(c *vo.Canvas, r *reachability) error {
Nodes: append([]*vo.Node{ Nodes: append([]*vo.Node{
{ {
ID: node.ID, ID: node.ID,
Type: vo.BlockTypeBotStart, Type: entity.NodeTypeEntry.IDStr(),
Data: node.Data, Data: node.Data,
}, },
{ {
ID: node.ID, ID: node.ID,
Type: vo.BlockTypeBotEnd, Type: entity.NodeTypeExit.IDStr(),
}, },
}, node.Blocks...), }, node.Blocks...),
Edges: node.Edges, Edges: node.Edges,
@@ -663,9 +630,9 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
for _, node := range nodes { for _, node := range nodes {
switch node.Type { switch node.Type {
case vo.BlockTypeBotStart: case entity.NodeTypeEntry.IDStr():
startNode = node startNode = node
case vo.BlockTypeBotEnd: case entity.NodeTypeExit.IDStr():
endNode = node endNode = node
} }
} }
@@ -680,7 +647,7 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
return startNode, endNode, nil return startNode, endNode, nil
} }
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node, endNode *vo.Node) (map[string]*vo.Node, error) { func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node) (map[string]*vo.Node, error) {
result := make(map[string]*vo.Node) result := make(map[string]*vo.Node)
result[startNode.ID] = startNode result[startNode.ID] = startNode

View File

@@ -1,181 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
func (s *NodeSchema) OutputPortCount() (int, bool) {
var hasExceptionPort bool
if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
hasExceptionPort = true
}
switch s.Type {
case entity.NodeTypeSelector:
return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
case entity.NodeTypeQuestionAnswer:
if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
} else {
return 2, hasExceptionPort
}
}
return 1, hasExceptionPort
case entity.NodeTypeIntentDetector:
intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
return len(intents) + 1, hasExceptionPort
default:
return 1, hasExceptionPort
}
}
type BranchMapping struct {
Normal []map[string]bool
Exception map[string]bool
}
const (
DefaultBranch = "default"
BranchFmt = "branch_%d"
)
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
if bMapping == nil {
return nil, errors.New("no branch mapping")
}
endNodes := make(map[string]bool)
for i := range bMapping.Normal {
for k := range bMapping.Normal[i] {
endNodes[k] = true
}
}
if bMapping.Exception != nil {
for k := range bMapping.Exception {
endNodes[k] = true
}
}
switch s.Type {
case entity.NodeTypeSelector:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
choice := in[selector.SelectKey].(int)
if choice < 0 || choice > len(bMapping.Normal) {
return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
}
choices := make(map[string]bool, len((bMapping.Normal)[choice]))
for k := range (bMapping.Normal)[choice] {
choices[k] = true
}
return choices, nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
case entity.NodeTypeQuestionAnswer:
conf := s.Configs.(map[string]any)
if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
if !ok {
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
}
if optionID.(string) == "other" {
return (bMapping.Normal)[len(bMapping.Normal)-1], nil
}
if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0
return (bMapping.Normal)[0], nil
}
optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
if !ok {
return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
}
if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
}
return (bMapping.Normal)[optionIDInt], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}
return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
case entity.NodeTypeIntentDetector:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bMapping.Exception, nil
}
classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
if !ok {
return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
}
// Intent detector the node default branch uses classificationId=0. But currently scene, the implementation uses default as the last element of the array.
// Therefore, when classificationId=0, it needs to be converted into the node corresponding to the last index of the array.
// Other options also need to reduce the index by 1.
id, err := cast.ToInt64E(classificationId)
if err != nil {
return nil, err
}
realID := id - 1
if realID >= int64(len(bMapping.Normal)) {
return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
}
if realID < 0 {
realID = int64(len(bMapping.Normal)) - 1
}
return (bMapping.Normal)[realID], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
default:
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bMapping.Exception, nil
}
return (bMapping.Normal)[0], nil
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}
}

View File

@@ -1,194 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
type selectorCallbackField struct {
Key string `json:"key"`
Type vo.DataType `json:"type"`
Value any `json:"value"`
}
type selectorCondition struct {
Left selectorCallbackField `json:"left"`
Operator vo.OperatorType `json:"operator"`
Right *selectorCallbackField `json:"right,omitempty"`
}
type selectorBranch struct {
Conditions []*selectorCondition `json:"conditions"`
Logic vo.LogicType `json:"logic"`
Name string `json:"name"`
}
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
return func(_ context.Context, in map[string]any) (map[string]any, error) {
config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
count := len(config)
output := make([]*selectorBranch, count)
for _, source := range s.InputSources {
targetPath := source.Path
if len(targetPath) == 2 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: []*selectorCondition{
{
Operator: config[index].Single.ToCanvasOperatorType(),
},
},
Logic: selector.ClauseRelationAND.ToVOLogicType(),
}
}
if targetPath[1] == selector.LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := sc.Hierarchy[s.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
}
parentNode := sc.GetNode(parentNodeKey)
output[index].Conditions[0].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: "",
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else if targetPath[1] == selector.RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[0].Right = &selectorCallbackField{
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: rightV,
}
}
} else if len(targetPath) == 3 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
multi := config[index].Multi
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: make([]*selectorCondition, len(multi.Clauses)),
Logic: multi.Relation.ToVOLogicType(),
}
}
clauseIndexStr := targetPath[1]
clauseIndex, err := strconv.Atoi(clauseIndexStr)
if err != nil {
return nil, err
}
clause := multi.Clauses[clauseIndex]
if output[index].Conditions[clauseIndex] == nil {
output[index].Conditions[clauseIndex] = &selectorCondition{
Operator: clause.ToCanvasOperatorType(),
}
}
if targetPath[2] == selector.LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := sc.Hierarchy[s.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
}
parentNode := sc.GetNode(parentNodeKey)
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: "",
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else if targetPath[2] == selector.RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: rightV,
}
}
}
}
return map[string]any{"branches": output}, nil
}
}

View File

@@ -31,7 +31,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
) )
@@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
rootHandler := execute.NewRootWorkflowHandler( rootHandler := execute.NewRootWorkflowHandler(
wb, wb,
executeID, executeID,
workflowSC.requireCheckPoint, workflowSC.RequireCheckpoint(),
eventChan, eventChan,
resumedEvent, resumedEvent,
exeCfg, exeCfg,
@@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
var nodeOpt einoCompose.Option var nodeOpt einoCompose.Option
if ns.Type == entity.NodeTypeExit { if ns.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs))) ptr.Of(ns.Configs.(*exit.Config).TerminatePlan))
} else if ns.Type != entity.NodeTypeLambda { } else if ns.Type != entity.NodeTypeLambda {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil) nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
} }
@@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
} }
} }
if workflowSC.requireCheckPoint { if workflowSC.RequireCheckpoint() {
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10))) opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
} }
@@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context, func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
parentHandler *execute.WorkflowHandler, parentHandler *execute.WorkflowHandler,
ns *NodeSchema, ns *schema2.NodeSchema,
pathPrefix ...string) (opts []einoCompose.Option, err error) { pathPrefix ...string) (opts []einoCompose.Option, err error) {
var ( var (
resumeEvent = r.interruptEvent resumeEvent = r.interruptEvent
@@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
var nodeOpt einoCompose.Option var nodeOpt einoCompose.Option
if subNS.Type == entity.NodeTypeExit { if subNS.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs))) ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan))
} else { } else {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil) nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
} }
@@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
return opts, nil return opts, nil
} }
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event, func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
sw *schema.StreamWriter[*entity.Message]) ( sw *schema.StreamWriter[*entity.Message]) (
opts []einoCompose.Option, err error) { opts []einoCompose.Option, err error) {
// this is a LLM node. // this is a LLM node.
@@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
panic("impossible: llmToolCallbackOptions is called on a non-LLM node") panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
} }
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs) cfg := ns.Configs.(*llm.Config)
fcParams := cfg.FCParam
if fcParams != nil { if fcParams != nil {
if fcParams.WorkflowFCParam != nil { if fcParams.WorkflowFCParam != nil {
// TODO: try to avoid getting the workflow tool all over again // TODO: try to avoid getting the workflow tool all over again
@@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo) toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler) opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key)) opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt) opts = append(opts, opt)
} }
} }
@@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo) toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler) opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key)) opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt) opts = append(opts, opt)
} }
} }

View File

@@ -25,12 +25,13 @@ import (
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
// outputValueFiller will fill the output value with nil if the key is not present in the output map. // outputValueFiller will fill the output value with nil if the key is not present in the output map.
// if a node emits stream as output, the node needs to handle these absent keys in stream themselves. // if a node emits stream as output, the node needs to handle these absent keys in stream themselves.
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) { func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) {
if len(s.OutputTypes) == 0 { if len(s.OutputTypes) == 0 {
return func(ctx context.Context, output map[string]any) (map[string]any, error) { return func(ctx context.Context, output map[string]any) (map[string]any, error) {
return output, nil return output, nil
@@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map. // inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves. // if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) { func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) {
if len(s.InputTypes) == 0 { if len(s.InputTypes) == 0 {
return func(ctx context.Context, input map[string]any) (map[string]any, error) { return func(ctx context.Context, input map[string]any) (map[string]any, error) {
return input, nil return input, nil
@@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri
} }
} }
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context, func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context,
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] { input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) { fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
newI := make(map[string]any) newI := make(map[string]any)

View File

@@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func TestNodeSchema_OutputValueFiller(t *testing.T) { func TestNodeSchema_OutputValueFiller(t *testing.T) {
@@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &NodeSchema{ s := &schema.NodeSchema{
OutputTypes: tt.fields.Outputs, OutputTypes: tt.fields.Outputs,
} }
got, err := s.outputValueFiller()(context.Background(), tt.fields.In) got, err := outputValueFiller(s)(context.Background(), tt.fields.In)
if len(tt.wantErr) > 0 { if len(tt.wantErr) > 0 {
assert.Error(t, err) assert.Error(t, err)

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Node struct {
Lambda *compose.Lambda
}
// New instantiates the actual node type from NodeSchema.
func New(ctx context.Context, s *schema.NodeSchema,
inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node
sc *schema.WorkflowSchema, // the workflow this NodeSchema is in
deps *dependencyInfo, // the dependency for this node pre-calculated by workflow engine
) (_ *Node, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
}
}()
var fullSources map[string]*schema.SourceInfo
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
if fullSources, err = GetFullSources(s, sc, deps); err != nil {
return nil, err
}
s.FullSources = fullSources
}
// if NodeSchema's Configs implements NodeBuilder, will use it to build the node
nb, ok := s.Configs.(schema.NodeBuilder)
if ok {
opts := []schema.BuildOption{
schema.WithWorkflowSchema(sc),
schema.WithInnerWorkflow(inner),
}
// build the actual InvokableNode, etc.
n, err := nb.Build(ctx, s, opts...)
if err != nil {
return nil, err
}
// wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda
return toNode(s, n), nil
}
switch s.Type {
case entity.NodeTypeLambda:
if s.Lambda == nil {
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
}
return &Node{Lambda: s.Lambda}, nil
case entity.NodeTypeSubWorkflow:
subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint())
if err != nil {
return nil, err
}
return toNode(s, subWorkflow), nil
default:
panic(fmt.Sprintf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type))
}
}
func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) {
var opts []WorkflowOption
opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID))
if requireCheckpoint {
opts = append(opts, WithParentRequireCheckpoint())
}
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
}
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
if err != nil {
return nil, err
}
return &subworkflow.SubWorkflow{
Runner: wf.Runner,
}, nil
}

View File

@@ -33,6 +33,8 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct {
maxRetry int64 maxRetry int64
errProcessType vo.ErrorProcessType errProcessType vo.ErrorProcessType
dataOnErr func(ctx context.Context) map[string]any dataOnErr func(ctx context.Context) map[string]any
callbackEnabled bool
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
streamPreProcessors []func(ctx context.Context, streamPreProcessors []func(ctx context.Context,
@@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct {
init []func(context.Context) (context.Context, error) init []func(context.Context) (context.Context, error)
i compose.Invoke[map[string]any, map[string]any, O] i compose.Invoke[map[string]any, map[string]any, O]
s compose.Stream[map[string]any, map[string]any, O] s compose.Stream[map[string]any, map[string]any, O]
c compose.Collect[map[string]any, map[string]any, O]
t compose.Transform[map[string]any, map[string]any, O] t compose.Transform[map[string]any, map[string]any, O]
} }
func newNodeRunConfig[O any](ns *NodeSchema, func newNodeRunConfig[O any](ns *schema2.NodeSchema,
i compose.Invoke[map[string]any, map[string]any, O], i compose.Invoke[map[string]any, map[string]any, O],
s compose.Stream[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O],
c compose.Collect[map[string]any, map[string]any, O],
t compose.Transform[map[string]any, map[string]any, O], t compose.Transform[map[string]any, map[string]any, O],
opts *newNodeOptions) *nodeRunConfig[O] { opts *newNodeOptions) *nodeRunConfig[O] {
meta := entity.NodeMetaByNodeType(ns.Type) meta := entity.NodeMetaByNodeType(ns.Type)
@@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema,
keyFinishedMarkerTrimmer(), keyFinishedMarkerTrimmer(),
} }
if meta.PreFillZero { if meta.PreFillZero {
preProcessors = append(preProcessors, ns.inputValueFiller()) preProcessors = append(preProcessors, inputValueFiller(ns))
} }
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error) var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
if meta.PostFillNil { if meta.PostFillNil {
postProcessors = append(postProcessors, ns.outputValueFiller()) postProcessors = append(postProcessors, outputValueFiller(ns))
} }
streamPreProcessors := []func(ctx context.Context, streamPreProcessors := []func(ctx context.Context,
@@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema,
}, },
} }
if meta.PreFillZero { if meta.PreFillZero {
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller()) streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns))
}
if meta.UseCtxCache {
opts.init = append([]func(ctx context.Context) (context.Context, error){
func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
},
}, opts.init...)
} }
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) { opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
@@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema,
maxRetry: maxRetry, maxRetry: maxRetry,
errProcessType: errProcessType, errProcessType: errProcessType,
dataOnErr: dataOnErr, dataOnErr: dataOnErr,
callbackEnabled: meta.CallbackEnabled,
preProcessors: preProcessors, preProcessors: preProcessors,
postProcessors: postProcessors, postProcessors: postProcessors,
streamPreProcessors: streamPreProcessors, streamPreProcessors: streamPreProcessors,
@@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema,
init: opts.init, init: opts.init,
i: i, i: i,
s: s, s: s,
c: c,
t: t, t: t,
} }
} }
func newNodeRunConfigWOOpt(ns *NodeSchema, func newNodeRunConfigWOOpt(ns *schema2.NodeSchema,
i compose.InvokeWOOpt[map[string]any, map[string]any], i compose.InvokeWOOpt[map[string]any, map[string]any],
s compose.StreamWOOpt[map[string]any, map[string]any], s compose.StreamWOOpt[map[string]any, map[string]any],
c compose.CollectWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any], t compose.TransformWOOpts[map[string]any, map[string]any],
opts *newNodeOptions) *nodeRunConfig[any] { opts *newNodeOptions) *nodeRunConfig[any] {
var ( var (
iWO compose.Invoke[map[string]any, map[string]any, any] iWO compose.Invoke[map[string]any, map[string]any, any]
sWO compose.Stream[map[string]any, map[string]any, any] sWO compose.Stream[map[string]any, map[string]any, any]
cWO compose.Collect[map[string]any, map[string]any, any]
tWO compose.Transform[map[string]any, map[string]any, any] tWO compose.Transform[map[string]any, map[string]any, any]
) )
@@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema,
} }
} }
if c != nil {
cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) {
return c(ctx, in)
}
}
if t != nil { if t != nil {
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) { tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
return t(ctx, input) return t(ctx, input)
} }
} }
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts) return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts)
} }
type newNodeOptions struct { type newNodeOptions struct {
@@ -180,57 +199,100 @@ type newNodeOptions struct {
init []func(context.Context) (context.Context, error) init []func(context.Context) (context.Context, error)
} }
type newNodeOption func(*newNodeOptions) func toNode(ns *schema2.NodeSchema, r any) *Node {
iWOpt, _ := r.(nodes.InvokableNodeWOpt)
sWOpt, _ := r.(nodes.StreamableNodeWOpt)
cWOpt, _ := r.(nodes.CollectableNodeWOpt)
tWOpt, _ := r.(nodes.TransformableNodeWOpt)
iWOOpt, _ := r.(nodes.InvokableNode)
sWOOpt, _ := r.(nodes.StreamableNode)
cWOOpt, _ := r.(nodes.CollectableNode)
tWOOpt, _ := r.(nodes.TransformableNode)
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption { var wOpt, wOOpt bool
return func(opts *newNodeOptions) { if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil {
opts.callbackInputConverter = f wOpt = true
} }
} if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil {
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption { wOOpt = true
return func(opts *newNodeOptions) { }
opts.callbackOutputConverter = f
if wOpt && wOOpt {
panic("a node's different streaming methods needs to be consistent: " +
"they should ALL have NodeOption or None should have them")
}
if !wOpt && !wOOpt {
panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode")
} }
}
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.init = append(opts.init, f)
}
}
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{} options := &newNodeOptions{}
for _, opt := range opts { ci, ok := r.(nodes.CallbackInputConverted)
opt(options) if ok {
options.callbackInputConverter = ci.ToCallbackInput
} }
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode() co, ok := r.(nodes.CallbackOutputConverted)
} if ok {
options.callbackOutputConverter = co.ToCallbackOutput
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
} }
return newNodeRunConfig(ns, i, nil, nil, options).toNode() init, ok := r.(nodes.Initializer)
} if ok {
options.init = append(options.init, init.Init)
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
} }
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
}
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node { if wOpt {
options := &newNodeOptions{} var (
for _, opt := range opts { i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption]
opt(options) s compose.Stream[map[string]any, map[string]any, nodes.NodeOption]
c compose.Collect[map[string]any, map[string]any, nodes.NodeOption]
t compose.Transform[map[string]any, map[string]any, nodes.NodeOption]
)
if iWOpt != nil {
i = iWOpt.Invoke
}
if sWOpt != nil {
s = sWOpt.Stream
}
if cWOpt != nil {
c = cWOpt.Collect
}
if tWOpt != nil {
t = tWOpt.Transform
}
return newNodeRunConfig(ns, i, s, c, t, options).toNode()
} }
return newNodeRunConfig(ns, i, s, nil, options).toNode()
var (
i compose.InvokeWOOpt[map[string]any, map[string]any]
s compose.StreamWOOpt[map[string]any, map[string]any]
c compose.CollectWOOpt[map[string]any, map[string]any]
t compose.TransformWOOpts[map[string]any, map[string]any]
)
if iWOOpt != nil {
i = iWOOpt.Invoke
}
if sWOOpt != nil {
s = sWOOpt.Stream
}
if cWOOpt != nil {
c = cWOOpt.Collect
}
if tWOOpt != nil {
t = tWOOpt.Transform
}
return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode()
} }
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) { func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
@@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.
func (nc *nodeRunConfig[O]) toNode() *Node { func (nc *nodeRunConfig[O]) toNode() *Node {
var opts []compose.LambdaOpt var opts []compose.LambdaOpt
opts = append(opts, compose.WithLambdaType(string(nc.nodeType))) opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
opts = append(opts, compose.WithLambdaCallbackEnable(true))
if nc.callbackEnabled {
opts = append(opts, compose.WithLambdaCallbackEnable(true))
}
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...) l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err)) panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
@@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C
} }
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) { func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
if !r.callbackEnabled {
return ctx, nil
}
if r.callbackInputConverter != nil { if r.callbackInputConverter != nil {
convertedInput, err := r.callbackInputConverter(ctx, input) convertedInput, err := r.callbackInputConverter(ctx, input)
if err != nil { if err != nil {
@@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) ( func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
context.Context, *schema.StreamReader[map[string]any], error) { context.Context, *schema.StreamReader[map[string]any], error) {
if !r.callbackEnabled {
return ctx, input, nil
}
if r.callbackInputConverter != nil { if r.callbackInputConverter != nil {
copied := input.Copy(2) copied := input.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) { realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
@@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
} }
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error { func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault { if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
output["isSuccess"] = true output["isSuccess"] = true
} }
if !r.callbackEnabled {
return nil
}
if r.callbackOutputConverter != nil { if r.callbackOutputConverter != nil {
convertedOutput, err := r.callbackOutputConverter(ctx, output) convertedOutput, err := r.callbackOutputConverter(ctx, output)
if err != nil { if err != nil {
@@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) ( func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
*schema.StreamReader[map[string]any], error) { *schema.StreamReader[map[string]any], error) {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault { if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}}) flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output}) output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
} }
if !r.callbackEnabled {
return output, nil
}
if r.callbackOutputConverter != nil { if r.callbackOutputConverter != nil {
copied := output.Copy(2) copied := output.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) { realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
@@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) { func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
if r.interrupted { if r.interrupted {
if r.callbackEnabled { _ = callbacks.OnError(ctx, err)
_ = callbacks.OnError(ctx, err)
}
return nil, false return nil, false
} }
@@ -653,22 +696,20 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
msg := sErr.Msg() msg := sErr.Msg()
switch r.errProcessType { switch r.errProcessType {
case vo.ErrorProcessTypeDefault: case vo.ErrorProcessTypeReturnDefaultData:
d := r.dataOnErr(ctx) d := r.dataOnErr(ctx)
d["errorBody"] = map[string]any{ d["errorBody"] = map[string]any{
"errorMessage": msg, "errorMessage": msg,
"errorCode": code, "errorCode": code,
} }
d["isSuccess"] = false d["isSuccess"] = false
if r.callbackEnabled { sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sErr = sErr.ChangeErrLevel(vo.LevelWarn) sOutput := &nodes.StructuredCallbackOutput{
sOutput := &nodes.StructuredCallbackOutput{ Output: d,
Output: d, RawOutput: d,
RawOutput: d, Error: sErr,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
} }
_ = callbacks.OnEnd(ctx, sOutput)
return d, true return d, true
case vo.ErrorProcessTypeExceptionBranch: case vo.ErrorProcessTypeExceptionBranch:
s := make(map[string]any) s := make(map[string]any)
@@ -677,20 +718,16 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
"errorCode": code, "errorCode": code,
} }
s["isSuccess"] = false s["isSuccess"] = false
if r.callbackEnabled { sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sErr = sErr.ChangeErrLevel(vo.LevelWarn) sOutput := &nodes.StructuredCallbackOutput{
sOutput := &nodes.StructuredCallbackOutput{ Output: s,
Output: s, RawOutput: s,
RawOutput: s, Error: sErr,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
} }
_ = callbacks.OnEnd(ctx, sOutput)
return s, true return s, true
default: default:
if r.callbackEnabled { _ = callbacks.OnError(ctx, sErr)
_ = callbacks.OnError(ctx, sErr)
}
return nil, false return nil, false
} }
} }

View File

@@ -1,580 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
)
type NodeSchema struct {
Key vo.NodeKey `json:"key"`
Name string `json:"name"`
Type entity.NodeType `json:"type"`
// Configs are node specific configurations with pre-defined config key and config value.
// Will not participate in request-time field mapping, nor as node's static values.
// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
Configs any `json:"configs,omitempty"`
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
Lambda *compose.Lambda // not serializable, used for internal test.
}
type ExceptionConfig struct {
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
}
type StreamConfig struct {
// whether this node has the ability to produce genuine streaming output.
// not include nodes that only passes stream down as they receives them
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
// whether this node prioritize streaming input over none-streaming input.
// not include nodes that can accept both and does not have preference.
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
type Node struct {
Lambda *compose.Lambda
}
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
}
if err != nil {
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
}
}()
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
return nil, err
}
}
switch s.Type {
case entity.NodeTypeLambda:
if s.Lambda == nil {
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
}
return &Node{Lambda: s.Lambda}, nil
case entity.NodeTypeLLM:
conf, err := s.ToLLMConfig(ctx)
if err != nil {
return nil, err
}
l, err := llm.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
case entity.NodeTypeSelector:
conf := s.ToSelectorConfig()
sl, err := selector.NewSelector(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
case entity.NodeTypeBatch:
if inner == nil {
return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
}
conf, err := s.ToBatchConfig(inner)
if err != nil {
return nil, err
}
b, err := batch.NewBatch(ctx, conf)
if err != nil {
return nil, err
}
return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
case entity.NodeTypeVariableAggregator:
conf, err := s.ToVariableAggregatorConfig()
if err != nil {
return nil, err
}
va, err := variableaggregator.NewVariableAggregator(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, va.Invoke, va.Transform,
withCallbackInputConverter(va.ToCallbackInput),
withCallbackOutputConverter(va.ToCallbackOutput),
withInit(va.Init)), nil
case entity.NodeTypeTextProcessor:
conf, err := s.ToTextProcessorConfig()
if err != nil {
return nil, err
}
tp, err := textprocessor.NewTextProcessor(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, tp.Invoke), nil
case entity.NodeTypeHTTPRequester:
conf, err := s.ToHTTPRequesterConfig()
if err != nil {
return nil, err
}
hr, err := httprequester.NewHTTPRequester(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
case entity.NodeTypeContinue:
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
return map[string]any{}, nil
}
return invokableNode(s, i), nil
case entity.NodeTypeBreak:
b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
if err != nil {
return nil, err
}
return invokableNode(s, b.DoBreak), nil
case entity.NodeTypeVariableAssigner:
handler := variable.GetVariableHandler()
conf, err := s.ToVariableAssignerConfig(handler)
if err != nil {
return nil, err
}
va, err := variableassigner.NewVariableAssigner(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, va.Assign), nil
case entity.NodeTypeVariableAssignerWithinLoop:
conf, err := s.ToVariableAssignerInLoopConfig()
if err != nil {
return nil, err
}
va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, va.Assign), nil
case entity.NodeTypeLoop:
conf, err := s.ToLoopConfig(inner)
if err != nil {
return nil, err
}
l, err := loop.NewLoop(ctx, conf)
if err != nil {
return nil, err
}
return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
case entity.NodeTypeQuestionAnswer:
conf, err := s.ToQAConfig(ctx)
if err != nil {
return nil, err
}
qA, err := qa.NewQuestionAnswer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
case entity.NodeTypeInputReceiver:
conf, err := s.ToInputReceiverConfig()
if err != nil {
return nil, err
}
inputR, err := receiver.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
case entity.NodeTypeOutputEmitter:
conf, err := s.ToOutputEmitterConfig(sc)
if err != nil {
return nil, err
}
e, err := emitter.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
case entity.NodeTypeEntry:
conf, err := s.ToEntryConfig(ctx)
if err != nil {
return nil, err
}
e, err := entry.NewEntry(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, e.Invoke), nil
case entity.NodeTypeExit:
terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
if terminalPlan == vo.ReturnVariables {
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
if in == nil {
return map[string]any{}, nil
}
return in, nil
}
return invokableNode(s, i), nil
}
conf, err := s.ToOutputEmitterConfig(sc)
if err != nil {
return nil, err
}
e, err := emitter.New(ctx, conf)
if err != nil {
return nil, err
}
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
case entity.NodeTypeDatabaseCustomSQL:
conf, err := s.ToDatabaseCustomSQLConfig()
if err != nil {
return nil, err
}
sqlER, err := database.NewCustomSQL(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, sqlER.Execute), nil
case entity.NodeTypeDatabaseQuery:
conf, err := s.ToDatabaseQueryConfig()
if err != nil {
return nil, err
}
query, err := database.NewQuery(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
case entity.NodeTypeDatabaseInsert:
conf, err := s.ToDatabaseInsertConfig()
if err != nil {
return nil, err
}
insert, err := database.NewInsert(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
case entity.NodeTypeDatabaseUpdate:
conf, err := s.ToDatabaseUpdateConfig()
if err != nil {
return nil, err
}
update, err := database.NewUpdate(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
case entity.NodeTypeDatabaseDelete:
conf, err := s.ToDatabaseDeleteConfig()
if err != nil {
return nil, err
}
del, err := database.NewDelete(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
case entity.NodeTypeKnowledgeIndexer:
conf, err := s.ToKnowledgeIndexerConfig()
if err != nil {
return nil, err
}
w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, w.Store), nil
case entity.NodeTypeKnowledgeRetriever:
conf, err := s.ToKnowledgeRetrieveConfig()
if err != nil {
return nil, err
}
r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Retrieve), nil
case entity.NodeTypeKnowledgeDeleter:
conf, err := s.ToKnowledgeDeleterConfig()
if err != nil {
return nil, err
}
r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Delete), nil
case entity.NodeTypeCodeRunner:
conf, err := s.ToCodeRunnerConfig()
if err != nil {
return nil, err
}
r, err := code.NewCodeRunner(ctx, conf)
if err != nil {
return nil, err
}
initFn := func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
}
return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
case entity.NodeTypePlugin:
conf, err := s.ToPluginConfig()
if err != nil {
return nil, err
}
r, err := plugin.NewPlugin(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Invoke), nil
case entity.NodeTypeCreateConversation:
conf, err := s.ToCreateConversationConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewCreateConversation(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Create), nil
case entity.NodeTypeMessageList:
conf, err := s.ToMessageListConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewMessageList(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.List), nil
case entity.NodeTypeClearMessage:
conf, err := s.ToClearMessageConfig()
if err != nil {
return nil, err
}
r, err := conversation.NewClearMessage(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Clear), nil
case entity.NodeTypeIntentDetector:
conf, err := s.ToIntentDetectorConfig(ctx)
if err != nil {
return nil, err
}
r, err := intentdetector.NewIntentDetector(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, r.Invoke), nil
case entity.NodeTypeSubWorkflow:
conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
if err != nil {
return nil, err
}
r, err := subworkflow.NewSubWorkflow(ctx, conf)
if err != nil {
return nil, err
}
return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
case entity.NodeTypeJsonSerialization:
conf, err := s.ToJsonSerializationConfig()
if err != nil {
return nil, err
}
js, err := json.NewJsonSerializer(ctx, conf)
if err != nil {
return nil, err
}
return invokableNode(s, js.Invoke), nil
case entity.NodeTypeJsonDeserialization:
conf, err := s.ToJsonDeserializationConfig()
if err != nil {
return nil, err
}
jd, err := json.NewJsonDeserializer(ctx, conf)
if err != nil {
return nil, err
}
initFn := func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
}
return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
default:
panic("not implemented")
}
}
func (s *NodeSchema) IsEnableUserQuery() bool {
if s == nil {
return false
}
if s.Type != entity.NodeTypeEntry {
return false
}
if len(s.OutputSources) == 0 {
return false
}
for _, source := range s.OutputSources {
fieldPath := source.Path
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
return true
}
}
return false
}
func (s *NodeSchema) IsEnableChatHistory() bool {
if s == nil {
return false
}
switch s.Type {
case entity.NodeTypeLLM:
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
return llmParam.EnableChatHistory
case entity.NodeTypeIntentDetector:
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
return llmParam.EnableChatHistory
default:
return false
}
}
func (s *NodeSchema) IsRefGlobalVariable() bool {
for _, source := range s.InputSources {
if source.IsRefGlobalVariable() {
return true
}
}
for _, source := range s.OutputSources {
if source.IsRefGlobalVariable() {
return true
}
}
return false
}
func (s *NodeSchema) requireCheckpoint() bool {
if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver {
return true
}
if s.Type == entity.NodeTypeLLM {
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
if fcParams != nil && fcParams.WorkflowFCParam != nil {
return true
}
}
if s.Type == entity.NodeTypeSubWorkflow {
s.SubWorkflowSchema.Init()
if s.SubWorkflowSchema.requireCheckPoint {
return true
}
}
return false
}

View File

@@ -32,8 +32,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -46,9 +48,9 @@ type State struct {
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"` InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"` NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"` ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"` SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"` GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"` ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"` LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
@@ -71,8 +73,8 @@ func init() {
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage") _ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state") _ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info") _ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info") _ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type") _ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path") _ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic") _ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan") _ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
@@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
return s.GroupChoices[nodeKey] return s.GroupChoices[nodeKey]
} }
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) { func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
choices, ok := s.GroupChoices[nodeKey] choices, ok := s.GroupChoices[nodeKey]
if !ok { if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey) return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
} }
choice, ok := choices[group] choice, ok := choices[group]
if !ok { if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group) return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
} }
if choice == -1 { // this group picks none of the elements if choice == -1 { // this group picks none of the elements
return nodes.FieldNotStream, nil return schema2.FieldNotStream, nil
} }
sInfos, ok := s.SourceInfos[nodeKey] sInfos, ok := s.SourceInfos[nodeKey]
if !ok { if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey) return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
} }
groupInfo, ok := sInfos[group] groupInfo, ok := sInfos[group]
if !ok { if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group) return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
} }
if groupInfo.SubSources == nil { if groupInfo.SubSources == nil {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey) return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
} }
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)] subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
if !ok { if !ok {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice) return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
} }
if subInfo.FieldType != nodes.FieldMaybeStream { if subInfo.FieldType != schema2.FieldMaybeStream {
return subInfo.FieldType, nil return subInfo.FieldType, nil
} }
@@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0]) return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
} }
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) { func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
result := make(map[string]nodes.FieldStreamType) result := make(map[string]schema2.FieldStreamType)
choices, ok := s.GroupChoices[nodeKey] choices, ok := s.GroupChoices[nodeKey]
if !ok { if !ok {
return result, nil return result, nil
@@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent), InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState), NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
ExecutedNodes: make(map[vo.NodeKey]bool), ExecutedNodes: make(map[vo.NodeKey]bool),
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo), SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
GroupChoices: make(map[vo.NodeKey]map[string]int), GroupChoices: make(map[vo.NodeKey]map[string]int),
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent), ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
LLMToResumeData: make(map[vo.NodeKey]string), LLMToResumeData: make(map[vo.NodeKey]string),
@@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
} }
} }
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt { func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var ( var (
handlers []compose.StatePreHandler[map[string]any, *State] handlers []compose.StatePreHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State] streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
@@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
} }
return in, nil return in, nil
}) })
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop { } else if entity.NodeMetaByNodeType(s.Type).IsComposite {
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) { handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
state.Inputs[s.Key] = in state.Inputs[s.Key] = in
@@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
} }
if len(handlers) > 0 || !stream { if len(handlers) > 0 || !stream {
handlerForVars := s.statePreHandlerForVars() handlerForVars := statePreHandlerForVars(s)
if handlerForVars != nil { if handlerForVars != nil {
handlers = append(handlers, handlerForVars) handlers = append(handlers, handlerForVars)
} }
@@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
if s.Type == entity.NodeTypeVariableAggregator { if s.Type == entity.NodeTypeVariableAggregator {
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) { streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs) state.SourceInfos[s.Key] = s.FullSources
return in, nil return in, nil
}) })
} }
handlerForVars := s.streamStatePreHandlerForVars() handlerForVars := streamStatePreHandlerForVars(s)
if handlerForVars != nil { if handlerForVars != nil {
streamHandlers = append(streamHandlers, handlerForVars) streamHandlers = append(streamHandlers, handlerForVars)
} }
@@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
return nil return nil
} }
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] { func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input // checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
var vars []*vo.FieldInfo var vars []*vo.FieldInfo
for _, input := range s.InputSources { for _, input := range s.InputSources {
@@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
} }
} }
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] { func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input // checkout the node's inputs, if it has any variables, get the variables and merge them with the input
var vars []*vo.FieldInfo var vars []*vo.FieldInfo
for _, input := range s.InputSources { for _, input := range s.InputSources {
@@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
} }
} }
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] { func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// if it does not have source info, do not add this pre handler // if it does not have source info, do not add this pre handler
if s.Configs == nil { if s.Configs == nil {
return nil return nil
@@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter: case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
return nil return nil
case entity.NodeTypeExit: case entity.NodeTypeExit:
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs) terminatePlan := s.Configs.(*exit.Config).TerminatePlan
if terminatePlan != vo.ReturnVariables { if terminatePlan != vo.ReturnVariables {
return nil return nil
} }
@@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation. // all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
} }
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs) sourceInfo := s.FullSources
if len(sourceInfo) == 0 { if len(sourceInfo) == 0 {
return nil return nil
} }
@@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
var ( var (
anyStream bool anyStream bool
checker func(source *nodes.SourceInfo) bool checker func(source *schema2.SourceInfo) bool
) )
checker = func(source *nodes.SourceInfo) bool { checker = func(source *schema2.SourceInfo) bool {
if source.FieldType != nodes.FieldNotStream { if source.FieldType != schema2.FieldNotStream {
return true return true
} }
for _, subSource := range source.SubSources { for _, subSource := range source.SubSources {
@@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) { return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
resolved := map[string]resolvedStreamSource{} resolved := map[string]resolvedStreamSource{}
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) { resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
if source.IsIntermediate { if source.IsIntermediate {
result = &resolvedStreamSource{ result = &resolvedStreamSource{
intermediate: true, intermediate: true,
@@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
} }
streamType := source.FieldType streamType := source.FieldType
if streamType == nodes.FieldMaybeStream { if streamType == schema2.FieldMaybeStream {
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0]) streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if streamType == nodes.FieldNotStream { if streamType == schema2.FieldNotStream {
return nil, nil return nil, nil
} }
@@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
} }
} }
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt { func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var ( var (
handlers []compose.StatePostHandler[map[string]any, *State] handlers []compose.StatePostHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State] streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
@@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil return out, nil
}) })
forVars := s.streamStatePostHandlerForVars() forVars := streamStatePostHandlerForVars(s)
if forVars != nil { if forVars != nil {
streamHandlers = append(streamHandlers, forVars) streamHandlers = append(streamHandlers, forVars)
} }
@@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil return out, nil
}) })
forVars := s.statePostHandlerForVars() forVars := statePostHandlerForVars(s)
if forVars != nil { if forVars != nil {
handlers = append(handlers, forVars) handlers = append(handlers, forVars)
} }
@@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return compose.WithStatePostHandler(handler) return compose.WithStatePostHandler(handler)
} }
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] { func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variable, // checkout the node's output sources, if it has any variable,
// use the state's variableHandler to get the variables and set them to the output // use the state's variableHandler to get the variables and set them to the output
var vars []*vo.FieldInfo var vars []*vo.FieldInfo
@@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
} }
} }
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] { func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output // checkout the node's output sources, if it has any variables, get the variables and merge them with the output
var vars []*vo.FieldInfo var vars []*vo.FieldInfo
for _, output := range s.OutputSources { for _, output := range s.OutputSources {

View File

@@ -21,19 +21,20 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
// SetFullSources calculates REAL input sources for a node. // GetFullSources calculates REAL input sources for a node.
// It may be different from a NodeSchema's InputSources because of the following reasons: // It may be different from a NodeSchema's InputSources because of the following reasons:
// 1. a inner node under composite node may refer to a field from a node in its parent workflow, // 1. a inner node under composite node may refer to a field from a node in its parent workflow,
// this is instead routed to and sourced from the inner workflow's start node. // this is instead routed to and sourced from the inner workflow's start node.
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow. // 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources. // 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error { func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) (
fullSource := make(map[string]*nodes.SourceInfo) map[string]*schema.SourceInfo, error) {
fullSource := make(map[string]*schema.SourceInfo)
var fieldInfos []vo.FieldInfo var fieldInfos []vo.FieldInfo
for _, s := range dep.staticValues { for _, s := range dep.staticValues {
fieldInfos = append(fieldInfos, vo.FieldInfo{ fieldInfos = append(fieldInfos, vo.FieldInfo{
@@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
tInfo = tInfo.Properties[path[j]] tInfo = tInfo.Properties[path[j]]
} }
if current, ok := currentSource[path[j]]; !ok { if current, ok := currentSource[path[j]]; !ok {
currentSource[path[j]] = &nodes.SourceInfo{ currentSource[path[j]] = &schema.SourceInfo{
IsIntermediate: true, IsIntermediate: true,
FieldType: nodes.FieldNotStream, FieldType: schema.FieldNotStream,
TypeInfo: tInfo, TypeInfo: tInfo,
SubSources: make(map[string]*nodes.SourceInfo), SubSources: make(map[string]*schema.SourceInfo),
} }
} else if !current.IsIntermediate { } else if !current.IsIntermediate {
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1]) return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
} }
currentSource = currentSource[path[j]].SubSources currentSource = currentSource[path[j]].SubSources
@@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
// static values or variables // static values or variables
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" { if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
currentSource[lastPath] = &nodes.SourceInfo{ currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false, IsIntermediate: false,
FieldType: nodes.FieldNotStream, FieldType: schema.FieldNotStream,
TypeInfo: tInfo, TypeInfo: tInfo,
} }
continue continue
@@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
fromNodeKey := fInfo.Source.Ref.FromNodeKey fromNodeKey := fInfo.Source.Ref.FromNodeKey
var ( var (
streamType nodes.FieldStreamType streamType schema.FieldStreamType
err error err error
) )
if len(fromNodeKey) > 0 { if len(fromNodeKey) > 0 {
if fromNodeKey == compose.START { if fromNodeKey == compose.START {
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
} else { } else {
fromNode, ok := allNS[fromNodeKey] fromNode := sc.GetNode(fromNodeKey)
if !ok { if fromNode == nil {
return fmt.Errorf("node %s not found", fromNodeKey) return nil, fmt.Errorf("node %s not found", fromNodeKey)
} }
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS) streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil { if err != nil {
return err return nil, err
} }
} }
} }
currentSource[lastPath] = &nodes.SourceInfo{ currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false, IsIntermediate: false,
FieldType: streamType, FieldType: streamType,
FromNodeKey: fromNodeKey, FromNodeKey: fromNodeKey,
@@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
} }
} }
s.Configs.(map[string]any)["FullSources"] = fullSource return fullSource, nil
return nil
}
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
if s.Type == entity.NodeTypeExit {
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
if len(path) == 1 && path[0] == "output" {
return nodes.FieldIsStream, nil
}
}
return nodes.FieldNotStream, nil
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
subSC := s.SubWorkflowSchema
subExit := subSC.GetNode(entity.ExitNodeKey)
subStreamType, err := subExit.IsStreamingField(path, nil)
if err != nil {
return nodes.FieldNotStream, err
}
return subStreamType, nil
} else if s.Type == entity.NodeTypeVariableAggregator {
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range s.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return nodes.FieldNotStream, nil
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode, ok := allNS[fromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range s.InputSources {
if fInfo.Path[0] == path[0] { // belong to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
if err != nil {
return nodes.FieldNotStream, err
}
if subStreamType == nodes.FieldMaybeStream {
return nodes.FieldMaybeStream, nil
} else if subStreamType == nodes.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return nodes.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return nodes.FieldNotStream, nil
}
return nodes.FieldMaybeStream, nil
}
}
if s.Type != entity.NodeTypeLLM {
return nodes.FieldNotStream, nil
}
if len(path) != 1 {
return nodes.FieldNotStream, nil
}
outputs := s.OutputTypes
if len(outputs) != 1 && len(outputs) != 2 {
return nodes.FieldNotStream, nil
}
var outputKey string
for key, output := range outputs {
if output.Type != vo.DataTypeString {
return nodes.FieldNotStream, nil
}
if key != "reasoning_content" {
if len(outputKey) > 0 {
return nodes.FieldNotStream, nil
}
outputKey = key
}
}
field := path[0]
if field == "reasoning_content" || field == outputKey {
return nodes.FieldIsStream, nil
}
return nodes.FieldNotStream, nil
} }

View File

@@ -28,6 +28,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func TestBatch(t *testing.T) { func TestBatch(t *testing.T) {
@@ -52,7 +55,7 @@ func TestBatch(t *testing.T) {
return in, nil return in, nil
} }
lambdaNode1 := &compose2.NodeSchema{ lambdaNode1 := &schema.NodeSchema{
Key: "lambda", Key: "lambda",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1), Lambda: compose.InvokableLambda(lambda1),
@@ -86,7 +89,7 @@ func TestBatch(t *testing.T) {
}, },
}, },
} }
lambdaNode2 := &compose2.NodeSchema{ lambdaNode2 := &schema.NodeSchema{
Key: "index", Key: "index",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2), Lambda: compose.InvokableLambda(lambda2),
@@ -103,7 +106,7 @@ func TestBatch(t *testing.T) {
}, },
} }
lambdaNode3 := &compose2.NodeSchema{ lambdaNode3 := &schema.NodeSchema{
Key: "consumer", Key: "consumer",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3), Lambda: compose.InvokableLambda(lambda3),
@@ -135,23 +138,22 @@ func TestBatch(t *testing.T) {
}, },
} }
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
ns := &compose2.NodeSchema{ ns := &schema.NodeSchema{
Key: "batch_node_key", Key: "batch_node_key",
Type: entity.NodeTypeBatch, Type: entity.NodeTypeBatch,
Configs: &batch.Config{},
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"array_1"}, Path: compose.FieldPath{"array_1"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_1"}, FromPath: compose.FieldPath{"array_1"},
}, },
}, },
@@ -160,7 +162,7 @@ func TestBatch(t *testing.T) {
Path: compose.FieldPath{"array_2"}, Path: compose.FieldPath{"array_2"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_2"}, FromPath: compose.FieldPath{"array_2"},
}, },
}, },
@@ -214,11 +216,11 @@ func TestBatch(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -246,18 +248,18 @@ func TestBatch(t *testing.T) {
return map[string]any{"success": true}, nil return map[string]any{"success": true}, nil
} }
parentLambdaNode := &compose2.NodeSchema{ parentLambdaNode := &schema.NodeSchema{
Key: "parent_predecessor_1", Key: "parent_predecessor_1",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(parentLambda), Lambda: compose.InvokableLambda(parentLambda),
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
parentLambdaNode, parentLambdaNode,
ns, ns,
exit, exitN,
lambdaNode1, lambdaNode1,
lambdaNode2, lambdaNode2,
lambdaNode3, lambdaNode3,
@@ -267,7 +269,7 @@ func TestBatch(t *testing.T) {
"index": "batch_node_key", "index": "batch_node_key",
"consumer": "batch_node_key", "consumer": "batch_node_key",
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: entity.EntryNodeKey, FromNode: entity.EntryNodeKey,
ToNode: "parent_predecessor_1", ToNode: "parent_predecessor_1",

View File

@@ -40,7 +40,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/internal/testutil" "github.com/coze-dev/coze-studio/backend/internal/testutil"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
@@ -108,22 +112,20 @@ func TestLLM(t *testing.T) {
} }
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
llmNode := &compose2.NodeSchema{ llmNode := &schema2.NodeSchema{
Key: "llm_node_key", Key: "llm_node_key",
Type: entity.NodeTypeLLM, Type: entity.NodeTypeLLM,
Configs: map[string]any{ Configs: &llm.Config{
"SystemPrompt": "{{sys_prompt}}", SystemPrompt: "{{sys_prompt}}",
"UserPrompt": "{{query}}", UserPrompt: "{{query}}",
"OutputFormat": llm.FormatText, OutputFormat: llm.FormatText,
"LLMParams": &model.LLMParams{ LLMParams: &model.LLMParams{
ModelName: modelName, ModelName: modelName,
}, },
}, },
@@ -132,7 +134,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"sys_prompt"}, Path: compose.FieldPath{"sys_prompt"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"sys_prompt"}, FromPath: compose.FieldPath{"sys_prompt"},
}, },
}, },
@@ -141,7 +143,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"query"}, Path: compose.FieldPath{"query"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"}, FromPath: compose.FieldPath{"query"},
}, },
}, },
@@ -162,11 +164,11 @@ func TestLLM(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -181,20 +183,20 @@ func TestLLM(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
llmNode, llmNode,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: llmNode.Key, ToNode: llmNode.Key,
}, },
{ {
FromNode: llmNode.Key, FromNode: llmNode.Key,
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -228,27 +230,20 @@ func TestLLM(t *testing.T) {
} }
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
llmNode := &compose2.NodeSchema{ llmNode := &schema2.NodeSchema{
Key: "llm_node_key", Key: "llm_node_key",
Type: entity.NodeTypeLLM, Type: entity.NodeTypeLLM,
Configs: map[string]any{ Configs: &llm.Config{
"SystemPrompt": "you are a helpful assistant", SystemPrompt: "you are a helpful assistant",
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?", UserPrompt: "what's the largest country in the world and it's area size in square kilometers?",
"OutputFormat": llm.FormatJSON, OutputFormat: llm.FormatJSON,
"IgnoreException": true, LLMParams: &model.LLMParams{
"DefaultOutput": map[string]any{
"country_name": "unknown",
"area_size": int64(0),
},
"LLMParams": &model.LLMParams{
ModelName: modelName, ModelName: modelName,
}, },
}, },
@@ -264,11 +259,11 @@ func TestLLM(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -292,20 +287,20 @@ func TestLLM(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
llmNode, llmNode,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: llmNode.Key, ToNode: llmNode.Key,
}, },
{ {
FromNode: llmNode.Key, FromNode: llmNode.Key,
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -337,22 +332,20 @@ func TestLLM(t *testing.T) {
} }
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
llmNode := &compose2.NodeSchema{ llmNode := &schema2.NodeSchema{
Key: "llm_node_key", Key: "llm_node_key",
Type: entity.NodeTypeLLM, Type: entity.NodeTypeLLM,
Configs: map[string]any{ Configs: &llm.Config{
"SystemPrompt": "you are a helpful assistant", SystemPrompt: "you are a helpful assistant",
"UserPrompt": "list the top 5 largest countries in the world", UserPrompt: "list the top 5 largest countries in the world",
"OutputFormat": llm.FormatMarkdown, OutputFormat: llm.FormatMarkdown,
"LLMParams": &model.LLMParams{ LLMParams: &model.LLMParams{
ModelName: modelName, ModelName: modelName,
}, },
}, },
@@ -363,11 +356,11 @@ func TestLLM(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -382,20 +375,20 @@ func TestLLM(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
llmNode, llmNode,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: llmNode.Key, ToNode: llmNode.Key,
}, },
{ {
FromNode: llmNode.Key, FromNode: llmNode.Key,
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -456,22 +449,20 @@ func TestLLM(t *testing.T) {
} }
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
openaiNode := &compose2.NodeSchema{ openaiNode := &schema2.NodeSchema{
Key: "openai_llm_node_key", Key: "openai_llm_node_key",
Type: entity.NodeTypeLLM, Type: entity.NodeTypeLLM,
Configs: map[string]any{ Configs: &llm.Config{
"SystemPrompt": "you are a helpful assistant", SystemPrompt: "you are a helpful assistant",
"UserPrompt": "plan a 10 day family visit to China.", UserPrompt: "plan a 10 day family visit to China.",
"OutputFormat": llm.FormatText, OutputFormat: llm.FormatText,
"LLMParams": &model.LLMParams{ LLMParams: &model.LLMParams{
ModelName: modelName, ModelName: modelName,
}, },
}, },
@@ -482,14 +473,14 @@ func TestLLM(t *testing.T) {
}, },
} }
deepseekNode := &compose2.NodeSchema{ deepseekNode := &schema2.NodeSchema{
Key: "deepseek_llm_node_key", Key: "deepseek_llm_node_key",
Type: entity.NodeTypeLLM, Type: entity.NodeTypeLLM,
Configs: map[string]any{ Configs: &llm.Config{
"SystemPrompt": "you are a helpful assistant", SystemPrompt: "you are a helpful assistant",
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.", UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
"OutputFormat": llm.FormatText, OutputFormat: llm.FormatText,
"LLMParams": &model.LLMParams{ LLMParams: &model.LLMParams{
ModelName: modelName, ModelName: modelName,
}, },
}, },
@@ -503,12 +494,11 @@ func TestLLM(t *testing.T) {
}, },
} }
emitterNode := &compose2.NodeSchema{ emitterNode := &schema2.NodeSchema{
Key: "emitter_node_key", Key: "emitter_node_key",
Type: entity.NodeTypeOutputEmitter, Type: entity.NodeTypeOutputEmitter,
Configs: map[string]any{ Configs: &emitter.Config{
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix", Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
"Mode": nodes.Streaming,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -542,7 +532,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"inputObj"}, Path: compose.FieldPath{"inputObj"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"inputObj"}, FromPath: compose.FieldPath{"inputObj"},
}, },
}, },
@@ -551,7 +541,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"input2"}, Path: compose.FieldPath{"input2"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"input2"}, FromPath: compose.FieldPath{"input2"},
}, },
}, },
@@ -559,11 +549,11 @@ func TestLLM(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.UseAnswerContent, TerminatePlan: vo.UseAnswerContent,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -596,17 +586,17 @@ func TestLLM(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
openaiNode, openaiNode,
deepseekNode, deepseekNode,
emitterNode, emitterNode,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: openaiNode.Key, ToNode: openaiNode.Key,
}, },
{ {
@@ -614,7 +604,7 @@ func TestLLM(t *testing.T) {
ToNode: emitterNode.Key, ToNode: emitterNode.Key,
}, },
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: deepseekNode.Key, ToNode: deepseekNode.Key,
}, },
{ {
@@ -623,7 +613,7 @@ func TestLLM(t *testing.T) {
}, },
{ {
FromNode: emitterNode.Key, FromNode: emitterNode.Key,
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }

View File

@@ -26,15 +26,20 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
_continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
) )
func TestLoop(t *testing.T) { func TestLoop(t *testing.T) {
t.Run("by iteration", func(t *testing.T) { t.Run("by iteration", func(t *testing.T) {
// start-> loop_node_key[innerNode->continue] -> end // start-> loop_node_key[innerNode->continue] -> end
innerNode := &compose2.NodeSchema{ innerNode := &schema.NodeSchema{
Key: "innerNode", Key: "innerNode",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -54,31 +59,30 @@ func TestLoop(t *testing.T) {
}, },
} }
continueNode := &compose2.NodeSchema{ continueNode := &schema.NodeSchema{
Key: "continueNode", Key: "continueNode",
Type: entity.NodeTypeContinue, Type: entity.NodeTypeContinue,
Configs: &_continue.Config{},
} }
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
loopNode := &compose2.NodeSchema{ loopNode := &schema.NodeSchema{
Key: "loop_node_key", Key: "loop_node_key",
Type: entity.NodeTypeLoop, Type: entity.NodeTypeLoop,
Configs: map[string]any{ Configs: &loop.Config{
"LoopType": loop.ByIteration, LoopType: loop.ByIteration,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{loop.Count}, Path: compose.FieldPath{loop.Count},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"count"}, FromPath: compose.FieldPath{"count"},
}, },
}, },
@@ -97,11 +101,11 @@ func TestLoop(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -116,11 +120,11 @@ func TestLoop(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
loopNode, loopNode,
exit, exitN,
innerNode, innerNode,
continueNode, continueNode,
}, },
@@ -128,7 +132,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key", "innerNode": "loop_node_key",
"continueNode": "loop_node_key", "continueNode": "loop_node_key",
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: "innerNode", ToNode: "innerNode",
@@ -142,12 +146,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -168,7 +172,7 @@ func TestLoop(t *testing.T) {
t.Run("infinite", func(t *testing.T) { t.Run("infinite", func(t *testing.T) {
// start-> loop_node_key[innerNode->break] -> end // start-> loop_node_key[innerNode->break] -> end
innerNode := &compose2.NodeSchema{ innerNode := &schema.NodeSchema{
Key: "innerNode", Key: "innerNode",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -188,24 +192,23 @@ func TestLoop(t *testing.T) {
}, },
} }
breakNode := &compose2.NodeSchema{ breakNode := &schema.NodeSchema{
Key: "breakNode", Key: "breakNode",
Type: entity.NodeTypeBreak, Type: entity.NodeTypeBreak,
Configs: &_break.Config{},
} }
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
loopNode := &compose2.NodeSchema{ loopNode := &schema.NodeSchema{
Key: "loop_node_key", Key: "loop_node_key",
Type: entity.NodeTypeLoop, Type: entity.NodeTypeLoop,
Configs: map[string]any{ Configs: &loop.Config{
"LoopType": loop.Infinite, LoopType: loop.Infinite,
}, },
OutputSources: []*vo.FieldInfo{ OutputSources: []*vo.FieldInfo{
{ {
@@ -220,11 +223,11 @@ func TestLoop(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -239,11 +242,11 @@ func TestLoop(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
loopNode, loopNode,
exit, exitN,
innerNode, innerNode,
breakNode, breakNode,
}, },
@@ -251,7 +254,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key", "innerNode": "loop_node_key",
"breakNode": "loop_node_key", "breakNode": "loop_node_key",
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: "innerNode", ToNode: "innerNode",
@@ -265,12 +268,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -290,14 +293,14 @@ func TestLoop(t *testing.T) {
t.Run("by array", func(t *testing.T) { t.Run("by array", func(t *testing.T) {
// start-> loop_node_key[innerNode->variable_assign] -> end // start-> loop_node_key[innerNode->variable_assign] -> end
innerNode := &compose2.NodeSchema{ innerNode := &schema.NodeSchema{
Key: "innerNode", Key: "innerNode",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
item1 := in["item1"].(string) item1 := in["item1"].(string)
item2 := in["item2"].(string) item2 := in["item2"].(string)
count := in["count"].(int) count := in["count"].(int)
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil return map[string]any{"total": count + len(item1) + len(item2)}, nil
}), }),
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -330,16 +333,18 @@ func TestLoop(t *testing.T) {
}, },
} }
assigner := &compose2.NodeSchema{ assigner := &schema.NodeSchema{
Key: "assigner", Key: "assigner",
Type: entity.NodeTypeVariableAssignerWithinLoop, Type: entity.NodeTypeVariableAssignerWithinLoop,
Configs: []*variableassigner.Pair{ Configs: &variableassigner.InLoopConfig{
{ Pairs: []*variableassigner.Pair{
Left: vo.Reference{ {
FromPath: compose.FieldPath{"count"}, Left: vo.Reference{
VariableType: ptr.Of(vo.ParentIntermediate), FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"total"},
}, },
Right: compose.FieldPath{"total"},
}, },
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
@@ -355,19 +360,17 @@ func TestLoop(t *testing.T) {
}, },
} }
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -382,12 +385,13 @@ func TestLoop(t *testing.T) {
}, },
} }
loopNode := &compose2.NodeSchema{ loopNode := &schema.NodeSchema{
Key: "loop_node_key", Key: "loop_node_key",
Type: entity.NodeTypeLoop, Type: entity.NodeTypeLoop,
Configs: map[string]any{ Configs: &loop.Config{
"LoopType": loop.ByArray, LoopType: loop.ByArray,
"IntermediateVars": map[string]*vo.TypeInfo{ InputArrays: []string{"items1", "items2"},
IntermediateVars: map[string]*vo.TypeInfo{
"count": { "count": {
Type: vo.DataTypeInteger, Type: vo.DataTypeInteger,
}, },
@@ -408,7 +412,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items1"}, Path: compose.FieldPath{"items1"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items1"}, FromPath: compose.FieldPath{"items1"},
}, },
}, },
@@ -417,7 +421,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items2"}, Path: compose.FieldPath{"items2"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items2"}, FromPath: compose.FieldPath{"items2"},
}, },
}, },
@@ -442,11 +446,11 @@ func TestLoop(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
loopNode, loopNode,
exit, exitN,
innerNode, innerNode,
assigner, assigner,
}, },
@@ -454,7 +458,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key", "innerNode": "loop_node_key",
"assigner": "loop_node_key", "assigner": "loop_node_key",
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: "innerNode", ToNode: "innerNode",
@@ -468,12 +472,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "loop_node_key", ToNode: "loop_node_key",
}, },
{ {
FromNode: "loop_node_key", FromNode: "loop_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }

View File

@@ -43,8 +43,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint" "github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage" storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
@@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) {
mockey.Mock(workflow.GetRepository).Return(repo).Build() mockey.Mock(workflow.GetRepository).Return(repo).Build()
t.Run("answer directly, no structured output", func(t *testing.T) { t.Run("answer directly, no structured output", func(t *testing.T) {
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{}, }
}}
ns := &compose2.NodeSchema{ ns := &schema2.NodeSchema{
Key: "qa_node_key", Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer, Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{ Configs: &qa.Config{
"QuestionTpl": "{{input}}", QuestionTpl: "{{input}}",
"AnswerType": qa.AnswerDirectly, AnswerType: qa.AnswerDirectly,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"input"}, Path: compose.FieldPath{"input"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"}, FromPath: compose.FieldPath{"query"},
}, },
}, },
@@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
ns, ns,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "qa_node_key", ToNode: "qa_node_key",
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1) mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
ns := &compose2.NodeSchema{ ns := &schema2.NodeSchema{
Key: "qa_node_key", Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer, Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{ Configs: &qa.Config{
"QuestionTpl": "{{input}}", QuestionTpl: "{{input}}",
"AnswerType": qa.AnswerByChoices, AnswerType: qa.AnswerByChoices,
"ChoiceType": qa.FixedChoices, ChoiceType: qa.FixedChoices,
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"}, FixedChoices: []string{"{{choice1}}", "{{choice2}}"},
"LLMParams": &model.LLMParams{}, LLMParams: &model.LLMParams{},
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"input"}, Path: compose.FieldPath{"input"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"}, FromPath: compose.FieldPath{"query"},
}, },
}, },
@@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice1"}, Path: compose.FieldPath{"choice1"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice1"}, FromPath: compose.FieldPath{"choice1"},
}, },
}, },
@@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice2"}, Path: compose.FieldPath{"choice2"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice2"}, FromPath: compose.FieldPath{"choice2"},
}, },
}, },
@@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
lambda := &compose2.NodeSchema{ lambda := &schema2.NodeSchema{
Key: "lambda", Key: "lambda",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) {
}), }),
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
ns, ns,
exit, exitN,
lambda, lambda,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "qa_node_key", ToNode: "qa_node_key",
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"), FromPort: ptr.Of("branch_0"),
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
FromPort: ptr.Of("branch_1"), FromPort: ptr.Of("branch_1"),
}, },
{ {
@@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) {
}, },
{ {
FromNode: "lambda", FromNode: "lambda",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init() ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws) wf, err := compose2.NewWorkflow(context.Background(), ws)
@@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) {
}) })
t.Run("answer with dynamic choices", func(t *testing.T) { t.Run("answer with dynamic choices", func(t *testing.T) {
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
ns := &compose2.NodeSchema{ ns := &schema2.NodeSchema{
Key: "qa_node_key", Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer, Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{ Configs: &qa.Config{
"QuestionTpl": "{{input}}", QuestionTpl: "{{input}}",
"AnswerType": qa.AnswerByChoices, AnswerType: qa.AnswerByChoices,
"ChoiceType": qa.DynamicChoices, ChoiceType: qa.DynamicChoices,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"input"}, Path: compose.FieldPath{"input"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"}, FromPath: compose.FieldPath{"query"},
}, },
}, },
@@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{qa.DynamicChoicesKey}, Path: compose.FieldPath{qa.DynamicChoicesKey},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choices"}, FromPath: compose.FieldPath{"choices"},
}, },
}, },
@@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
lambda := &compose2.NodeSchema{ lambda := &schema2.NodeSchema{
Key: "lambda", Key: "lambda",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) { Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) {
}), }),
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
ns, ns,
exit, exitN,
lambda, lambda,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "qa_node_key", ToNode: "qa_node_key",
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"), FromPort: ptr.Of("branch_0"),
}, },
{ {
FromNode: "lambda", FromNode: "lambda",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
@@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init() ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws) wf, err := compose2.NewWorkflow(context.Background(), ws)
@@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1) mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
} }
entry := &compose2.NodeSchema{ entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
ns := &compose2.NodeSchema{ ns := &schema2.NodeSchema{
Key: "qa_node_key", Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer, Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{ Configs: &qa.Config{
"QuestionTpl": "{{input}}", QuestionTpl: "{{input}}",
"AnswerType": qa.AnswerDirectly, AnswerType: qa.AnswerDirectly,
"ExtractFromAnswer": true, ExtractFromAnswer: true,
"AdditionalSystemPromptTpl": "{{prompt}}", AdditionalSystemPromptTpl: "{{prompt}}",
"MaxAnswerCount": 2, MaxAnswerCount: 2,
"LLMParams": &model.LLMParams{}, LLMParams: &model.LLMParams{},
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"input"}, Path: compose.FieldPath{"input"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"}, FromPath: compose.FieldPath{"query"},
}, },
}, },
@@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"prompt"}, Path: compose.FieldPath{"prompt"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"prompt"}, FromPath: compose.FieldPath{"prompt"},
}, },
}, },
@@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
exit := &compose2.NodeSchema{ exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema2.NodeSchema{
entry, entryN,
ns, ns,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema2.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "qa_node_key", ToNode: "qa_node_key",
}, },
{ {
FromNode: "qa_node_key", FromNode: "qa_node_key",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }

View File

@@ -26,26 +26,28 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
) )
func TestAddSelector(t *testing.T) { func TestAddSelector(t *testing.T) {
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end // start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{}, }
}}
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) {
}, nil }, nil
} }
lambdaNode1 := &compose2.NodeSchema{ lambdaNode1 := &schema.NodeSchema{
Key: "lambda1", Key: "lambda1",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1), Lambda: compose.InvokableLambda(lambda1),
@@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) {
}, nil }, nil
} }
LambdaNode2 := &compose2.NodeSchema{ LambdaNode2 := &schema.NodeSchema{
Key: "lambda2", Key: "lambda2",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2), Lambda: compose.InvokableLambda(lambda2),
@@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) {
}, nil }, nil
} }
lambdaNode3 := &compose2.NodeSchema{ lambdaNode3 := &schema.NodeSchema{
Key: "lambda3", Key: "lambda3",
Type: entity.NodeTypeLambda, Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3), Lambda: compose.InvokableLambda(lambda3),
} }
ns := &compose2.NodeSchema{ ns := &schema.NodeSchema{
Key: "selector", Key: "selector",
Type: entity.NodeTypeSelector, Type: entity.NodeTypeSelector,
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{ Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{
{ {
Single: ptr.Of(selector.OperatorEqual), Single: ptr.Of(selector.OperatorEqual),
}, },
@@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"0", selector.LeftKey}, Path: compose.FieldPath{"0", selector.LeftKey},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key1"}, FromPath: compose.FieldPath{"key1"},
}, },
}, },
@@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.LeftKey}, Path: compose.FieldPath{"1", "0", selector.LeftKey},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key2"}, FromPath: compose.FieldPath{"key2"},
}, },
}, },
@@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.RightKey}, Path: compose.FieldPath{"1", "0", selector.RightKey},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key3"}, FromPath: compose.FieldPath{"key3"},
}, },
}, },
@@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "1", selector.LeftKey}, Path: compose.FieldPath{"1", "1", selector.LeftKey},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key4"}, FromPath: compose.FieldPath{"key4"},
}, },
}, },
@@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
ns, ns,
lambdaNode1, lambdaNode1,
LambdaNode2, LambdaNode2,
lambdaNode3, lambdaNode3,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "selector", ToNode: "selector",
}, },
{ {
@@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) {
}, },
{ {
FromNode: "selector", FromNode: "selector",
ToNode: exit.Key, ToNode: exitN.Key,
FromPort: ptr.Of("default"), FromPort: ptr.Of("default"),
}, },
{ {
FromNode: "lambda1", FromNode: "lambda1",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
{ {
FromNode: "lambda2", FromNode: "lambda2",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
{ {
FromNode: "lambda3", FromNode: "lambda3",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
branches, err := schema.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init() ws.Init()
ctx := context.Background() ctx := context.Background()
@@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) {
} }
func TestVariableAggregator(t *testing.T) { func TestVariableAggregator(t *testing.T) {
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) {
}, },
} }
ns := &compose2.NodeSchema{ ns := &schema.NodeSchema{
Key: "va", Key: "va",
Type: entity.NodeTypeVariableAggregator, Type: entity.NodeTypeVariableAggregator,
Configs: map[string]any{ Configs: &variableaggregator.Config{
"MergeStrategy": variableaggregator.FirstNotNullValue, MergeStrategy: variableaggregator.FirstNotNullValue,
"GroupToLen": map[string]int{ GroupLen: map[string]int{
"Group1": 1, "Group1": 1,
"Group2": 1, "Group2": 1,
}, },
"GroupOrder": []string{ GroupOrder: []string{
"Group1", "Group1",
"Group2", "Group2",
}, },
@@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group1", "0"}, Path: compose.FieldPath{"Group1", "0"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"}, FromPath: compose.FieldPath{"Str1"},
}, },
}, },
@@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group2", "0"}, Path: compose.FieldPath{"Group2", "0"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Int1"}, FromPath: compose.FieldPath{"Int1"},
}, },
}, },
@@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
entry, entryN,
ns, ns,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "va", ToNode: "va",
}, },
{ {
FromNode: "va", FromNode: "va",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) {
func TestTextProcessor(t *testing.T) { func TestTextProcessor(t *testing.T) {
t.Run("split", func(t *testing.T) { t.Run("split", func(t *testing.T) {
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) {
}, },
} }
ns := &compose2.NodeSchema{ ns := &schema.NodeSchema{
Key: "tp", Key: "tp",
Type: entity.NodeTypeTextProcessor, Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{ Configs: &textprocessor.Config{
"Type": textprocessor.SplitText, Type: textprocessor.SplitText,
"Separators": []string{"|"}, Separators: []string{"|"},
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"String"}, Path: compose.FieldPath{"String"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str"}, FromPath: compose.FieldPath{"Str"},
}, },
}, },
@@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
ns, ns,
entry, entryN,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "tp", ToNode: "tp",
}, },
{ {
FromNode: "tp", FromNode: "tp",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }
@@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) {
}) })
t.Run("concat", func(t *testing.T) { t.Run("concat", func(t *testing.T) {
entry := &compose2.NodeSchema{ entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey, Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry, Type: entity.NodeTypeEntry,
Configs: map[string]any{ Configs: &entry.Config{},
"DefaultValues": map[string]any{},
},
} }
exit := &compose2.NodeSchema{ exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey, Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit, Type: entity.NodeTypeExit,
Configs: map[string]any{ Configs: &exit.Config{
"TerminalPlan": vo.ReturnVariables, TerminatePlan: vo.ReturnVariables,
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
@@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) {
}, },
} }
ns := &compose2.NodeSchema{ ns := &schema.NodeSchema{
Key: "tp", Key: "tp",
Type: entity.NodeTypeTextProcessor, Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{ Configs: &textprocessor.Config{
"Type": textprocessor.ConcatText, Type: textprocessor.ConcatText,
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}", Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
"ConcatChar": "\t", ConcatChar: "\t",
}, },
InputSources: []*vo.FieldInfo{ InputSources: []*vo.FieldInfo{
{ {
Path: compose.FieldPath{"String1"}, Path: compose.FieldPath{"String1"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"}, FromPath: compose.FieldPath{"Str1"},
}, },
}, },
@@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String2"}, Path: compose.FieldPath{"String2"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str2"}, FromPath: compose.FieldPath{"Str2"},
}, },
}, },
@@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String3"}, Path: compose.FieldPath{"String3"},
Source: vo.FieldSource{ Source: vo.FieldSource{
Ref: &vo.Reference{ Ref: &vo.Reference{
FromNodeKey: entry.Key, FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str3"}, FromPath: compose.FieldPath{"Str3"},
}, },
}, },
@@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) {
}, },
} }
ws := &compose2.WorkflowSchema{ ws := &schema.WorkflowSchema{
Nodes: []*compose2.NodeSchema{ Nodes: []*schema.NodeSchema{
ns, ns,
entry, entryN,
exit, exitN,
}, },
Connections: []*compose2.Connection{ Connections: []*schema.Connection{
{ {
FromNode: entry.Key, FromNode: entryN.Key,
ToNode: "tp", ToNode: "tp",
}, },
{ {
FromNode: "tp", FromNode: "tp",
ToNode: exit.Key, ToNode: exitN.Key,
}, },
}, },
} }

View File

@@ -1,652 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strconv"
"time"
einomodel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
return &entry.Config{
DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
OutputTypes: s.OutputTypes,
}, nil
}
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
llmConf := &llm.Config{
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs),
OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
InputFields: s.InputTypes,
OutputFields: s.OutputTypes,
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
if llmParams == nil {
return nil, fmt.Errorf("llm node llmParams is required")
}
var (
err error
chatModel, fallbackM einomodel.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo llm.ModelWithInfo
)
chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
metaConfigs := s.ExceptionConfigs
if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
if backupModelParams != nil {
fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
}
if fallbackM == nil {
modelWithInfo = llm.NewModel(chatModel, info)
} else {
modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
llmConf.ChatModel = modelWithInfo
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
llmConf.Tools = append(llmConf.Tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if llmConf.ToolsReturnDirectly == nil {
llmConf.ToolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
llmConf.ToolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
PluginEntity: crossplugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
llmConf.Tools = inInvokableTools
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
cfg := &llm.KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := totRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
cfg.RetrievalStrategy = &llm.RetrievalStrategy{
RetrievalStrategy: &crossknowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
llmConf.KnowledgeRecallConfig = cfg
}
}
return llmConf, nil
}
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
return &selector.Config{
Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs),
}
}
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
for i, oneConf := range conf {
if oneConf.Single != nil {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey})
if ok {
out = append(out, selector.Operants{Left: left, Right: right})
} else {
out = append(out, selector.Operants{Left: left})
}
} else if oneConf.Multi != nil {
multiClause := make([]*selector.Operants, 0)
for j := range oneConf.Multi.Clauses {
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey})
if !ok {
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
}
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey})
if ok {
multiClause = append(multiClause, &selector.Operants{Left: left, Right: right})
} else {
multiClause = append(multiClause, &selector.Operants{Left: left})
}
}
out = append(out, selector.Operants{Multi: multiClause})
} else {
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
}
}
return out, nil
}
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
conf := &batch.Config{
BatchNodeKey: s.Key,
InnerWorkflow: inner,
Outputs: s.OutputSources,
}
for key, tInfo := range s.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
conf.InputArrays = append(conf.InputArrays, key)
}
return conf, nil
}
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
return &variableaggregator.Config{
MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy),
GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
NodeKey: s.Key,
InputSources: s.InputSources,
GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs),
}, nil
}
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
converted = make(map[string]map[int]any)
for k, value := range in {
m, ok := value.(map[string]any)
if !ok {
panic(errors.New("value is not a map[string]any"))
}
converted[k] = make(map[int]any, len(m))
for i, sv := range m {
index, err := strconv.Atoi(i)
if err != nil {
panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err))
}
converted[k][index] = sv
}
}
return converted
}
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
defer func() {
if r := recover(); r != nil {
err = safego.NewPanicErr(r, debug.Stack())
}
}()
return s.variableAggregatorInputConverter(input), nil
}
return schema.StreamReaderWithConvert(in, converter)
}
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
return &textprocessor.Config{
Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type),
Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)),
ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)),
Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}, nil
}
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
return &json.SerializationConfig{
InputTypes: s.InputTypes,
}, nil
}
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
return &json.DeserializationConfig{
OutputFields: s.OutputTypes,
}, nil
}
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
return &httprequester.Config{
URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
Method: mustGetKey[string]("Method", s.Configs),
Timeout: mustGetKey[time.Duration]("Timeout", s.Configs),
RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs),
MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
}, nil
}
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
return &variableassigner.Config{
Pairs: s.Configs.([]*variableassigner.Pair),
Handler: handler,
}, nil
}
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
return &variableassigner.Config{
Pairs: s.Configs.([]*variableassigner.Pair),
}, nil
}
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
conf := &loop.Config{
LoopNodeKey: s.Key,
LoopType: mustGetKey[loop.Type]("LoopType", s.Configs),
IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
Outputs: s.OutputSources,
Inner: inner,
}
for key, tInfo := range s.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars
continue
}
conf.InputArrays = append(conf.InputArrays, key)
}
return conf, nil
}
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
conf := &qa.Config{
QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs),
AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs),
ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs),
AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
OutputFields: s.OutputTypes,
NodeKey: s.Key,
}
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
if llmParams != nil {
m, _, err := model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
conf.Model = m
}
return conf, nil
}
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
return &receiver.Config{
OutputTypes: s.OutputTypes,
NodeKey: s.Key,
OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
}, nil
}
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
conf := &emitter.Config{
Template: getKeyOrZero[string]("Template", s.Configs),
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
}
return conf, nil
}
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
return &database.CustomSQLConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs),
OutputConfig: s.OutputTypes,
CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
return &database.QueryConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs),
OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Limit: mustGetKey[int64]("Limit", s.Configs),
Op: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
return &database.InsertConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
OutputConfig: s.OutputTypes,
Inserter: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
return &database.DeleteConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Deleter: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
return &database.UpdateConfig{
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
OutputConfig: s.OutputTypes,
Updater: crossdatabase.GetDatabaseOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
return &knowledge.IndexerConfig{
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
return &knowledge.RetrieveConfig{
KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
Retriever: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
return &knowledge.DeleterConfig{
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
}, nil
}
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
return &plugin.Config{
PluginID: mustGetKey[int64]("PluginID", s.Configs),
ToolID: mustGetKey[int64]("ToolID", s.Configs),
PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
PluginService: crossplugin.GetPluginService(),
}, nil
}
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
return &code.Config{
Code: mustGetKey[string]("Code", s.Configs),
Language: mustGetKey[coderunner.Language]("Language", s.Configs),
OutputConfig: s.OutputTypes,
Runner: crosscode.GetCodeRunner(),
}, nil
}
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
return &conversation.CreateConversationConfig{
Creator: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
return &conversation.ClearMessageConfig{
Clearer: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
return &conversation.MessageListConfig{
Lister: crossconversation.ConversationManagerImpl,
}, nil
}
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
cfg := &intentdetector.Config{
Intents: mustGetKey[[]string]("Intents", s.Configs),
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs),
}
llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
m, _, err := model.GetManager().GetModel(ctx, llmParams)
if err != nil {
return nil, err
}
cfg.ChatModel = m
return cfg, nil
}
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
var opts []WorkflowOption
opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
if requireCheckpoint {
opts = append(opts, WithParentRequireCheckpoint())
}
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
}
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
if err != nil {
return nil, err
}
return &subworkflow.Config{
Runner: wf.Runner,
}, nil
}
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
switch s {
case 0:
return crossknowledge.SearchTypeSemantic, nil
case 1:
return crossknowledge.SearchTypeHybrid, nil
case 20:
return crossknowledge.SearchTypeFullText, nil
default:
return "", fmt.Errorf("invalid retrieval search type %v", s)
}
}

View File

@@ -1,107 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"reflect"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
func getKeyOrZero[T any](key string, cfg any) T {
var zero T
if cfg == nil {
return zero
}
m, ok := cfg.(map[string]any)
if !ok {
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
}
if len(m) == 0 {
return zero
}
if v, ok := m[key]; ok {
return v.(T)
}
return zero
}
func mustGetKey[T any](key string, cfg any) T {
if cfg == nil {
panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
}
m, ok := cfg.(map[string]any)
if !ok {
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
}
if _, ok := m[key]; !ok {
panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
}
v, ok := m[key].(T)
if !ok {
panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key])))
}
return v
}
func (s *NodeSchema) SetConfigKV(key string, value any) {
if s.Configs == nil {
s.Configs = make(map[string]any)
}
s.Configs.(map[string]any)[key] = value
}
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
if s.InputTypes == nil {
s.InputTypes = make(map[string]*vo.TypeInfo)
}
s.InputTypes[key] = t
}
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
s.InputSources = append(s.InputSources, info...)
}
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
if s.OutputTypes == nil {
s.OutputTypes = make(map[string]*vo.TypeInfo)
}
s.OutputTypes[key] = t
}
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
s.OutputSources = append(s.OutputSources, info...)
}
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
if s.Type != entity.NodeTypeSubWorkflow {
return 0, "", false
}
return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true
}

View File

@@ -29,6 +29,8 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
) )
@@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
*workflow *workflow
hierarchy map[vo.NodeKey]vo.NodeKey hierarchy map[vo.NodeKey]vo.NodeKey
connections []*Connection connections []*schema.Connection
requireCheckpoint bool requireCheckpoint bool
entry *compose.WorkflowNode entry *compose.WorkflowNode
inner bool inner bool
@@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
input map[string]*vo.TypeInfo input map[string]*vo.TypeInfo
output map[string]*vo.TypeInfo output map[string]*vo.TypeInfo
terminatePlan vo.TerminatePlan terminatePlan vo.TerminatePlan
schema *WorkflowSchema schema *schema.WorkflowSchema
} }
type workflowOptions struct { type workflowOptions struct {
@@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
} }
} }
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) { func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
sc.Init() sc.Init()
wf := &Workflow{ wf := &Workflow{
@@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
schema: sc, schema: sc,
} }
wf.streamRun = sc.requireStreaming wf.streamRun = sc.RequireStreaming()
wf.requireCheckpoint = sc.requireCheckPoint wf.requireCheckpoint = sc.RequireCheckpoint()
wfOpts := &workflowOptions{} wfOpts := &workflowOptions{}
for _, opt := range opts { for _, opt := range opts {
@@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
processedNodeKey[child.Key] = struct{}{} processedNodeKey[child.Key] = struct{}{}
} }
} }
// add all nodes other than composite nodes and their children // add all nodes other than composite nodes and their children
for _, ns := range sc.Nodes { for _, ns := range sc.Nodes {
if _, ok := processedNodeKey[ns.Key]; !ok { if _, ok := processedNodeKey[ns.Key]; !ok {
@@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
} }
if ns.Type == entity.NodeTypeExit { if ns.Type == entity.NodeTypeExit {
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs) wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
} }
} }
@@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10))) compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
} }
fanInConfigs := sc.fanInMergeConfigs() fanInConfigs := sc.FanInMergeConfigs()
if len(fanInConfigs) > 0 { if len(fanInConfigs) > 0 {
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs)) compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
} }
@@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
carryOvers map[vo.NodeKey][]*compose.FieldMapping carryOvers map[vo.NodeKey][]*compose.FieldMapping
} }
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error { func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
_, err := w.addNodeInternal(ctx, ns, nil) _, err := w.addNodeInternal(ctx, ns, nil)
return err return err
} }
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error { func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
inner, err := w.getInnerWorkflow(ctx, cNode) inner, err := w.getInnerWorkflow(ctx, cNode)
if err != nil { if err != nil {
return err return err
@@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
return err return err
} }
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) { func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
return w.addNodeInternal(ctx, cNode, nil) return w.addNodeInternal(ctx, cNode, nil)
} }
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) { func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
key := ns.Key key := ns.Key
var deps *dependencyInfo var deps *dependencyInfo
@@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
innerWorkflow = inner.inner innerWorkflow = inner.inner
} }
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps) ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
var opts []compose.GraphAddNodeOpt var opts []compose.GraphAddNodeOpt
opts = append(opts, compose.WithNodeName(string(ns.Key))) opts = append(opts, compose.WithNodeName(string(ns.Key)))
preHandler := ns.StatePreHandler(w.streamRun) preHandler := statePreHandler(ns, w.streamRun)
if preHandler != nil { if preHandler != nil {
opts = append(opts, preHandler) opts = append(opts, preHandler)
} }
postHandler := ns.StatePostHandler(w.streamRun) postHandler := statePostHandler(ns, w.streamRun)
if postHandler != nil { if postHandler != nil {
opts = append(opts, postHandler) opts = append(opts, postHandler)
} }
@@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
w.entry = wNode w.entry = wNode
} }
outputPortCount, hasExceptionPort := ns.OutputPortCount() b := w.schema.GetBranch(ns.Key)
if outputPortCount > 1 || hasExceptionPort { if b != nil {
bMapping, err := w.resolveBranch(key, outputPortCount) if b.OnlyException() {
if err != nil { _ = w.AddBranch(string(key), b.GetExceptionBranch())
return nil, err } else {
} bb, ok := ns.Configs.(schema.BranchBuilder)
if !ok {
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
}
branch, err := ns.GetBranch(bMapping) br, err := b.GetFullBranch(ctx, bb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = w.AddBranch(string(key), branch) _ = w.AddBranch(string(key), br)
}
} }
return deps.inputsForParent, nil return deps.inputsForParent, nil
@@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
return w.workflow.Compile(ctx, opts...) return w.workflow.Compile(ctx, opts...)
} }
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) { func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*NodeSchema) innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
for _, n := range cNode.Children { for _, n := range cNode.Children {
innerNodes[n.Key] = n innerNodes[n.Key] = n
} }
// trim the connections, only keep the connections that are related to the inner workflow // trim the connections, only keep the connections that are related to the inner workflow
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes // ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
innerConnections := make([]*Connection, 0) innerConnections := make([]*schema.Connection, 0)
for i := range w.schema.Connections { for i := range w.schema.Connections {
conn := w.schema.Connections[i] conn := w.schema.Connections[i]
if _, ok := innerNodes[conn.FromNode]; ok { if _, ok := innerNodes[conn.FromNode]; ok {
@@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c. // For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array, // As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
// we will not try to drill, instead, just take value using a.b.c. // we will not try to drill, instead, just take value using a.b.c.
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error { func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
for nKey, fms := range d.inputs { for nKey, fms := range d.inputs {
if nKey == compose.START { // reference to START node would NEVER need to do array drill down if nKey == compose.START { // reference to START node would NEVER need to do array drill down
continue continue
@@ -638,55 +643,6 @@ type variableInfo struct {
toPath compose.FieldPath toPath compose.FieldPath
} }
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
m := make([]map[string]bool, portCount)
var exception map[string]bool
for _, conn := range w.connections {
if conn.FromNode != n {
continue
}
if conn.FromPort == nil {
continue
}
if *conn.FromPort == "default" { // default condition
if m[portCount-1] == nil {
m[portCount-1] = make(map[string]bool)
}
m[portCount-1][string(conn.ToNode)] = true
} else if *conn.FromPort == "branch_error" {
if exception == nil {
exception = make(map[string]bool)
}
exception[string(conn.ToNode)] = true
} else {
if !strings.HasPrefix(*conn.FromPort, "branch_") {
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
}
index := (*conn.FromPort)[7:]
i, err := strconv.Atoi(index)
if err != nil {
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
}
if i < 0 || i >= portCount {
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
}
if m[i] == nil {
m[i] = make(map[string]bool)
}
m[i][string(conn.ToNode)] = true
}
}
return &BranchMapping{
Normal: m,
Exception: exception,
}, nil
}
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) { func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
var ( var (
inputs = make(map[vo.NodeKey][]*compose.FieldMapping) inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
@@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping) inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
) )
connMap := make(map[vo.NodeKey]Connection) connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections { for _, conn := range w.connections {
if conn.ToNode != n { if conn.ToNode != n {
continue continue
@@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue continue
} }
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok { if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency if _, ok := connMap[fromNode]; ok { // direct dependency
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 { if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
if inputFull == nil { if inputFull == nil {
@@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path)) compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
} }
} }
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok { } else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
firstNodesInInnerWorkflow := true firstNodesInInnerWorkflow := true
for _, conn := range connMap { for _, conn := range connMap {
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) { if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
// there is another node 'conn.FromNode' that connects to this node, while also at the same level // there is another node 'conn.FromNode' that connects to this node, while also at the same level
firstNodesInInnerWorkflow = false firstNodesInInnerWorkflow = false
break break
@@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue continue
} }
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) { if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
fromNodeKey = compose.START fromNodeKey = compose.START
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) { } else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
continue continue
} }
@@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
variableInfos []*variableInfo variableInfos []*variableInfo
) )
connMap := make(map[vo.NodeKey]Connection) connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections { for _, conn := range w.connections {
if conn.ToNode != n { if conn.ToNode != n {
continue continue
} }
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) { if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
continue continue
} }
@@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
swp.Source.Ref.FromPath, swp.Path) swp.Source.Ref.FromPath, swp.Path)
} }
if ok := isParentOf(w.hierarchy, n, fromNode); ok { if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency if _, ok := connMap[fromNode]; ok { // direct dependency
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...))) inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
} else { // indirect dependency } else { // indirect dependency

View File

@@ -23,9 +23,10 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow" workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) ( func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
*Workflow, error) { *Workflow, error) {
sc.Init() sc.Init()
ns := sc.GetNode(nodeKey) ns := sc.GetNode(nodeKey)
@@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod
schema: sc, schema: sc,
fromNode: true, fromNode: true,
streamRun: false, // single node run can only invoke streamRun: false, // single node run can only invoke
requireCheckpoint: sc.requireCheckPoint, requireCheckpoint: sc.RequireCheckpoint(),
input: ns.InputTypes, input: ns.InputTypes,
output: ns.OutputTypes, output: ns.OutputTypes,
terminatePlan: vo.ReturnVariables, terminatePlan: vo.ReturnVariables,

View File

@@ -32,6 +32,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -42,7 +43,7 @@ type WorkflowRunner struct {
basic *entity.WorkflowBasic basic *entity.WorkflowBasic
input string input string
resumeReq *entity.ResumeRequest resumeReq *entity.ResumeRequest
schema *WorkflowSchema schema *schema2.WorkflowSchema
streamWriter *schema.StreamWriter[*entity.Message] streamWriter *schema.StreamWriter[*entity.Message]
config vo.ExecuteConfig config vo.ExecuteConfig
@@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp
} }
} }
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner { func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
options := &workflowRunOptions{} options := &workflowRunOptions{}
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)

View File

@@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -41,7 +42,7 @@ type invokableWorkflow struct {
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error) invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
terminatePlan vo.TerminatePlan terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow wfEntity *entity.Workflow
sc *WorkflowSchema sc *schema2.WorkflowSchema
repo wf.Repository repo wf.Repository
} }
@@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error), invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
terminatePlan vo.TerminatePlan, terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow, wfEntity *entity.Workflow,
sc *WorkflowSchema, sc *schema2.WorkflowSchema,
repo wf.Repository, repo wf.Repository,
) wf.ToolFromWorkflow { ) wf.ToolFromWorkflow {
return &invokableWorkflow{ return &invokableWorkflow{
@@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", err return "", err
} }
var entryNode *NodeSchema var entryNode *schema2.NodeSchema
for _, node := range i.sc.Nodes { for _, node := range i.sc.Nodes {
if node.Type == entity.NodeTypeEntry { if node.Type == entity.NodeTypeEntry {
entryNode = node entryNode = node
@@ -190,7 +191,7 @@ type streamableWorkflow struct {
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error) stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
terminatePlan vo.TerminatePlan terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow wfEntity *entity.Workflow
sc *WorkflowSchema sc *schema2.WorkflowSchema
repo wf.Repository repo wf.Repository
} }
@@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error), stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
terminatePlan vo.TerminatePlan, terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow, wfEntity *entity.Workflow,
sc *WorkflowSchema, sc *schema2.WorkflowSchema,
repo wf.Repository, repo wf.Repository,
) wf.ToolFromWorkflow { ) wf.ToolFromWorkflow {
return &streamableWorkflow{ return &streamableWorkflow{
@@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, err return nil, err
} }
var entryNode *NodeSchema var entryNode *schema2.NodeSchema
for _, node := range s.sc.Nodes { for _, node := range s.sc.Nodes {
if node.Type == entity.NodeTypeEntry { if node.Type == entity.NodeTypeEntry {
entryNode = node entryNode = node

View File

@@ -30,50 +30,108 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
) )
type Batch struct { type Batch struct {
config *Config outputs map[string]*vo.FieldSource
outputs map[string]*vo.FieldSource innerWorkflow compose.Runnable[map[string]any, map[string]any]
key vo.NodeKey
inputArrays []string
} }
type Config struct { type Config struct{}
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
InputArrays []string `json:"input_arrays"` func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
Outputs []*vo.FieldInfo `json:"outputs"` if n.Parent() != nil {
} return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
if config == nil {
return nil, errors.New("config is required")
} }
if len(config.InputArrays) == 0 { ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeBatch,
Name: n.Data.Meta.Title,
Configs: c,
}
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
compose.FieldPath{MaxBatchSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(batchSizeField...)
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
compose.FieldPath{ConcurrentSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(concurrentSizeField...)
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
if err != nil {
return nil, err
}
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
if err != nil {
return nil, err
}
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
var inputArrays []string
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
inputArrays = append(inputArrays, key)
}
if len(inputArrays) == 0 {
return nil, errors.New("need to have at least one incoming array for batch") return nil, errors.New("need to have at least one incoming array for batch")
} }
if len(config.Outputs) == 0 { if len(ns.OutputSources) == 0 {
return nil, errors.New("need to have at least one output variable for batch") return nil, errors.New("need to have at least one output variable for batch")
} }
b := &Batch{ bo := schema.GetBuildOptions(opts...)
config: config, if bo.Inner == nil {
outputs: make(map[string]*vo.FieldSource), return nil, errors.New("need to have inner workflow for batch")
} }
for i := range config.Outputs { b := &Batch{
source := config.Outputs[i] outputs: make(map[string]*vo.FieldSource),
innerWorkflow: bo.Inner,
key: ns.Key,
inputArrays: inputArrays,
}
for i := range ns.OutputSources {
source := ns.OutputSources[i]
path := source.Path path := source.Path
if len(path) != 1 { if len(path) != 1 {
return nil, fmt.Errorf("invalid path %q", path) return nil, fmt.Errorf("invalid path %q", path)
} }
// from which inner node's which field does the batch's output fields come from
b.outputs[path[0]] = &source.Source b.outputs[path[0]] = &source.Source
} }
@@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
return out return out
} }
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) ( func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) { out map[string]any, err error) {
arrays := make(map[string]any, len(b.config.InputArrays)) arrays := make(map[string]any, len(b.inputArrays))
minLen := math.MaxInt64 minLen := math.MaxInt64
for _, arrayKey := range b.config.InputArrays { for _, arrayKey := range b.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok { if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey) return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
} }
} }
input[string(b.config.BatchNodeKey)+"#index"] = int64(i) input[string(b.key)+"#index"] = int64(i)
items := make(map[string]any) items := make(map[string]any)
for arrayKey, array := range arrays { for arrayKey, array := range arrays {
ele := reflect.ValueOf(array).Index(i).Interface() ele := reflect.ValueOf(array).Index(i).Interface()
items[arrayKey] = []any{ele} items[arrayKey] = []any{ele}
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey currentKey := string(b.key) + "#" + arrayKey
// Recursively expand map[string]any elements // Recursively expand map[string]any elements
var expand func(prefix string, val interface{}) var expand func(prefix string, val interface{})
@@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil return nil
} }
options := &nodes.NestedWorkflowOptions{} options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
for _, opt := range opts {
opt(options)
}
var existingCState *nodes.NestedWorkflowState var existingCState *nodes.NestedWorkflowState
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error { err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error var e error
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey) existingCState, _, e = getter.GetNestedWorkflowState(b.key)
if e != nil { if e != nil {
return e return e
} }
@@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
mu.Unlock() mu.Unlock()
if subCheckpointID != "" { if subCheckpointID != "" {
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s", logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
i, b.config.BatchNodeKey, subCheckpointID) i, b.key, subCheckpointID)
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID)) ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
} }
@@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow // if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
// the output then needs to be concatenated. // the output then needs to be concatenated.
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...) taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
if err != nil { if err != nil {
info, ok := compose.ExtractInterruptInfo(err) info, ok := compose.ExtractInterruptInfo(err)
if !ok { if !ok {
@@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{ iEvent := &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey, NodeKey: b.key,
NodeType: entity.NodeTypeBatch, NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
} }
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil { if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e return e
} }
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent) return setter.SetInterruptEvent(b.key, iEvent)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil, compose.InterruptAndRerun return nil, compose.InterruptAndRerun
} else { } else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil { if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e return e
} }
@@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// although this invocation does not have new interruptions, // although this invocation does not have new interruptions,
// this batch node previously have interrupts yet to be resumed. // this batch node previously have interrupts yet to be resumed.
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed. // we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{ return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey, NodeKey: b.key,
NodeType: entity.NodeTypeBatch, NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: existingCState.Index2InterruptInfo, NestedInterruptInfo: existingCState.Index2InterruptInfo,
}) })
@@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 { if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+ logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo)) "nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
} }
@@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
} }
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(b.config.InputArrays)) trimmed := make(map[string]any, len(b.inputArrays))
for _, arrayKey := range b.config.InputArrays { for _, arrayKey := range b.inputArrays {
if v, ok := in[arrayKey]; ok { if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v trimmed[arrayKey] = v
} }

View File

@@ -25,6 +25,10 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
@@ -113,50 +117,77 @@ var pythonThirdPartyWhitelist = map[string]struct{}{
} }
type Config struct { type Config struct {
Code string Code string
Language coderunner.Language Language coderunner.Language
OutputConfig map[string]*vo.TypeInfo
Runner coderunner.Runner Runner coderunner.Runner
} }
type CodeRunner struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *Config ns := &schema.NodeSchema{
importError error Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeCodeRunner,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
code := inputs.Code
c.Code = code
language, err := convertCodeLanguage(inputs.Language)
if err != nil {
return nil, err
}
c.Language = language
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) { func convertCodeLanguage(l int64) (coderunner.Language, error) {
if cfg == nil { switch l {
return nil, errors.New("cfg is required") case 5:
return coderunner.JavaScript, nil
case 3:
return coderunner.Python, nil
default:
return "", fmt.Errorf("invalid language: %d", l)
} }
}
if cfg.Language == "" { func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return nil, errors.New("language is required")
}
if cfg.Code == "" { if c.Language != coderunner.Python {
return nil, errors.New("code is required")
}
if cfg.Language != coderunner.Python {
return nil, errors.New("only support python language") return nil, errors.New("only support python language")
} }
if len(cfg.OutputConfig) == 0 { importErr := validatePythonImports(c.Code)
return nil, errors.New("output config is required")
}
if cfg.Runner == nil { return &Runner{
return nil, errors.New("run coder is required") code: c.Code,
} language: c.Language,
outputConfig: ns.OutputTypes,
importErr := validatePythonImports(cfg.Code) runner: code2.GetCodeRunner(),
importError: importErr,
return &CodeRunner{
config: cfg,
importError: importErr,
}, nil }, nil
} }
type Runner struct {
outputConfig map[string]*vo.TypeInfo
code string
language coderunner.Language
runner coderunner.Runner
importError error
}
func validatePythonImports(code string) error { func validatePythonImports(code string) error {
imports := parsePythonImports(code) imports := parsePythonImports(code)
importErrors := make([]string, 0) importErrors := make([]string, 0)
@@ -191,11 +222,11 @@ func validatePythonImports(code string) error {
return nil return nil
} }
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) { func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
if c.importError != nil { if c.importError != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error())) return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
} }
response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input}) response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input})
if err != nil { if err != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error())) return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
} }
@@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
result := response.Result result := response.Result
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result) ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig) output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig)
if err != nil { if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error())) return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
} }
@@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
} }
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey) rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
if !ok { if !ok {
return nil, errors.New("raw output config is required") return nil, errors.New("raw output config is required")

View File

@@ -75,30 +75,29 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context() ctx := t.Context()
c := &CodeRunner{ c := &Runner{
config: &Config{ language: coderunner.Python,
Language: coderunner.Python, code: codeTpl,
Code: codeTpl, outputConfig: map[string]*vo.TypeInfo{
OutputConfig: map[string]*vo.TypeInfo{ "key0": {Type: vo.DataTypeInteger},
"key0": {Type: vo.DataTypeInteger}, "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key31": {Type: vo.DataTypeString},
"key31": &vo.TypeInfo{Type: vo.DataTypeString}, "key32": {Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString}, "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key341": {Type: vo.DataTypeString},
"key341": &vo.TypeInfo{Type: vo.DataTypeString}, "key342": {Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString}, }},
}},
},
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
}, },
Runner: mockRunner, },
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
}, },
runner: mockRunner,
} }
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123", "input": "1123",
}) })
@@ -145,38 +144,36 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context() ctx := t.Context()
c := &CodeRunner{ c := &Runner{
config: &Config{ code: codeTpl,
Code: codeTpl, language: coderunner.Python,
Language: coderunner.Python, outputConfig: map[string]*vo.TypeInfo{
OutputConfig: map[string]*vo.TypeInfo{ "key0": {Type: vo.DataTypeInteger},
"key0": {Type: vo.DataTypeInteger}, "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}}, "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key31": {Type: vo.DataTypeString},
"key31": &vo.TypeInfo{Type: vo.DataTypeString}, "key32": {Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString}, "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key341": {Type: vo.DataTypeString},
"key341": &vo.TypeInfo{Type: vo.DataTypeString}, "key342": {Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
}}, }},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ }},
"key31": &vo.TypeInfo{Type: vo.DataTypeString}, "key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key32": &vo.TypeInfo{Type: vo.DataTypeString}, "key31": {Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key32": {Type: vo.DataTypeString},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key341": &vo.TypeInfo{Type: vo.DataTypeString}, "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key342": &vo.TypeInfo{Type: vo.DataTypeString}, "key341": {Type: vo.DataTypeString},
}, "key342": {Type: vo.DataTypeString},
}},
}, },
}},
}, },
Runner: mockRunner,
}, },
runner: mockRunner,
} }
ret, err := c.RunCode(ctx, map[string]any{ ret, err := c.Invoke(ctx, map[string]any{
"input": "1123", "input": "1123",
}) })
@@ -219,30 +216,28 @@ async def main(args:Args)->Output:
} }
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil) mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
c := &CodeRunner{ c := &Runner{
config: &Config{ code: codeTpl,
Code: codeTpl, language: coderunner.Python,
Language: coderunner.Python, outputConfig: map[string]*vo.TypeInfo{
OutputConfig: map[string]*vo.TypeInfo{ "key0": {Type: vo.DataTypeInteger},
"key0": {Type: vo.DataTypeInteger}, "key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key31": {Type: vo.DataTypeString},
"key31": &vo.TypeInfo{Type: vo.DataTypeString}, "key32": {Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString}, "key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, "key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "key341": {Type: vo.DataTypeString},
"key341": &vo.TypeInfo{Type: vo.DataTypeString}, "key342": {Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString}, "key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}}, }},
}}, },
},
},
}, },
Runner: mockRunner,
}, },
runner: mockRunner,
} }
ret, err := c.RunCode(ctx, map[string]any{ ret, err := c.Invoke(ctx, map[string]any{
"input": "1123", "input": "1123",
}) })

View File

@@ -0,0 +1,236 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) (err error) {
selectParam := n.Data.Inputs.SelectParam
if selectParam != nil {
err = applyDBConditionToSchema(ns, selectParam.Condition, n.Parent())
if err != nil {
return err
}
}
insertParam := n.Data.Inputs.InsertParam
if insertParam != nil {
err = applyInsetFieldInfoToSchema(ns, insertParam.FieldInfo, n.Parent())
if err != nil {
return err
}
}
deleteParam := n.Data.Inputs.DeleteParam
if deleteParam != nil {
err = applyDBConditionToSchema(ns, &deleteParam.Condition, n.Parent())
if err != nil {
return err
}
}
updateParam := n.Data.Inputs.UpdateParam
if updateParam != nil {
err = applyDBConditionToSchema(ns, &updateParam.Condition, n.Parent())
if err != nil {
return err
}
err = applyInsetFieldInfoToSchema(ns, updateParam.FieldInfo, n.Parent())
if err != nil {
return err
}
}
return nil
}
func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error {
if condition.ConditionList == nil {
return nil
}
for idx, params := range condition.ConditionList {
var right *vo.Param
for _, param := range params {
if param == nil {
continue
}
if param.Name == "right" {
right = param
break
}
}
if right == nil {
continue
}
name := fmt.Sprintf("__condition_right_%d", idx)
tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input)
if err != nil {
return err
}
ns.SetInputType(name, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode)
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
return nil
}
func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error {
if len(fieldInfo) == 0 {
return nil
}
for _, params := range fieldInfo {
// Each FieldInfo is list params, containing two elements.
// The first is to set the name of the field and the second is the corresponding value.
p0 := params[0]
p1 := params[1]
name := p0.Input.Value.Content.(string) // must string type
tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input)
if err != nil {
return err
}
name = "__setting_field_" + name
ns.SetInputType(name, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode)
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
return nil
}
func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) {
clauseGroup := &database.ClauseGroup{}
if len(condition.ConditionList) == 1 {
params := condition.ConditionList[0]
clause, err := buildClauseFromParams(params)
if err != nil {
return nil, err
}
clauseGroup.Single = clause
} else {
relation, err := convertLogicTypeToRelation(condition.Logic)
if err != nil {
return nil, err
}
clauseGroup.Multi = &database.MultiClause{
Clauses: make([]*database.Clause, 0, len(condition.ConditionList)),
Relation: relation,
}
for i := range condition.ConditionList {
params := condition.ConditionList[i]
clause, err := buildClauseFromParams(params)
if err != nil {
return nil, err
}
clauseGroup.Multi.Clauses = append(clauseGroup.Multi.Clauses, clause)
}
}
return clauseGroup, nil
}
func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
var left, operation *vo.Param
for _, p := range params {
if p == nil {
continue
}
if p.Name == "left" {
left = p
continue
}
if p.Name == "operation" {
operation = p
continue
}
}
if left == nil {
return nil, fmt.Errorf("left clause is required")
}
if operation == nil {
return nil, fmt.Errorf("operation clause is required")
}
operator, err := operationToOperator(operation.Input.Value.Content.(string))
if err != nil {
return nil, err
}
clause := &database.Clause{
Left: left.Input.Value.Content.(string),
Operator: operator,
}
return clause, nil
}
func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) {
switch logicType {
case vo.DatabaseLogicAnd:
return database.ClauseRelationAND, nil
case vo.DatabaseLogicOr:
return database.ClauseRelationOR, nil
default:
return "", fmt.Errorf("logic type %v is invalid", logicType)
}
}
func operationToOperator(s string) (database.Operator, error) {
switch s {
case "EQUAL":
return database.OperatorEqual, nil
case "NOT_EQUAL":
return database.OperatorNotEqual, nil
case "GREATER_THAN":
return database.OperatorGreater, nil
case "LESS_THAN":
return database.OperatorLesser, nil
case "GREATER_EQUAL":
return database.OperatorGreaterOrEqual, nil
case "LESS_EQUAL":
return database.OperatorLesserOrEqual, nil
case "IN":
return database.OperatorIn, nil
case "NOT_IN":
return database.OperatorNotIn, nil
case "IS_NULL":
return database.OperatorIsNull, nil
case "IS_NOT_NULL":
return database.OperatorIsNotNull, nil
case "LIKE":
return database.OperatorLike, nil
case "NOT_LIKE":
return database.OperatorNotLike, nil
}
return "", fmt.Errorf("not a valid Operation string")
}

View File

@@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.
return ret, nil return ret, nil
} }
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) { func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
var ( var (
rightValue any rightValue any
ok bool ok bool
@@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa
return conditionGroup, nil return conditionGroup, nil
} }
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) { func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input) conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fields := parseToInput(input) fields := parseToInput(input)
inventory := &UpdateInventory{ inventory := &updateInventory{
ConditionGroup: conditionGroup, ConditionGroup: conditionGroup,
Fields: fields, Fields: fields,
} }

View File

@@ -19,48 +19,89 @@ package database
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv"
"strings" "strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
type CustomSQLConfig struct { type CustomSQLConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
SQLTemplate string SQLTemplate string
OutputConfig map[string]*vo.TypeInfo
CustomSQLExecutor database.DatabaseOperator
} }
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) { func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if cfg == nil { ns := &schema.NodeSchema{
return nil, errors.New("config is required") Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseCustomSQL,
Name: n.Data.Meta.Title,
Configs: c,
} }
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
c.DatabaseInfoID = dsID
sql := n.Data.Inputs.SQL
if len(sql) == 0 {
return nil, fmt.Errorf("sql is requird")
}
c.SQLTemplate = sql
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0") return nil, errors.New("database info id is required and greater than 0")
} }
if cfg.SQLTemplate == "" { if c.SQLTemplate == "" {
return nil, errors.New("sql template is required") return nil, errors.New("sql template is required")
} }
if cfg.CustomSQLExecutor == nil {
return nil, errors.New("custom sqler is required")
}
return &CustomSQL{ return &CustomSQL{
config: cfg, databaseInfoID: c.DatabaseInfoID,
sqlTemplate: c.SQLTemplate,
outputTypes: ns.OutputTypes,
customSQLExecutor: database.GetDatabaseOperator(),
}, nil }, nil
} }
type CustomSQL struct { type CustomSQL struct {
config *CustomSQLConfig databaseInfoID int64
sqlTemplate string
outputTypes map[string]*vo.TypeInfo
customSQLExecutor database.DatabaseOperator
} }
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) { func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
req := &database.CustomSQLRequest{ req := &database.CustomSQLRequest{
DatabaseInfoID: c.config.DatabaseInfoID, DatabaseInfoID: c.databaseInfoID,
IsDebugRun: isDebugExecute(ctx), IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx), UserID: getExecUserID(ctx),
} }
@@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
} }
templateSQL := "" templateSQL := ""
templateParts := nodes.ParseTemplate(c.config.SQLTemplate) templateParts := nodes.ParseTemplate(c.sqlTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts)) sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil") var nilError = errors.New("field is nil")
for _, templatePart := range templateParts { for _, templatePart := range templateParts {
@@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1) templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL req.SQL = templateSQL
req.Params = sqlParams req.Params = sqlParams
response, err := c.config.CustomSQLExecutor.Execute(ctx, req) response, err := c.customSQLExecutor.Execute(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret, err := responseFormatted(c.config.OutputConfig, response) ret, err := responseFormatted(c.outputTypes, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -28,6 +28,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type mockCustomSQLer struct { type mockCustomSQLer struct {
@@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C
m.validate(request) m.validate(request)
r := &database.Response{ r := &database.Response{
Objects: []database.Object{ Objects: []database.Object{
database.Object{ {
"v1": "v1_ret", "v1": "v1_ret",
"v2": "v2_ret", "v2": "v2_ret",
}, },
@@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) { validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID) assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{ ps := []database.SQLParam{
database.SQLParam{Value: "v1_value"}, {Value: "v1_value"},
database.SQLParam{Value: "v2_value"}, {Value: "v2_value"},
database.SQLParam{Value: "v3_value"}, {Value: "v3_value"},
} }
assert.Equal(t, ps, req.Params) assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL) assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
@@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes() mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
cfg := &CustomSQLConfig{ cfg := &CustomSQLConfig{
DatabaseInfoID: 111, DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`", SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
CustomSQLExecutor: mockDatabaseOperator, }
OutputConfig: map[string]*vo.TypeInfo{
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{ "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString}, "v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString}, "v2": {Type: vo.DataTypeString},
}}}, }}},
"rowNum": {Type: vo.DataTypeInteger}, "rowNum": {Type: vo.DataTypeInteger},
}, },
} })
cl := &CustomSQL{ assert.NoError(t, err)
config: cfg,
}
ret, err := cl.Execute(t.Context(), map[string]any{ ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{
"v1": "v1_value", "v1": "v1_value",
"v2": "v2_value", "v2": "v2_value",
"v3": "v3_value", "v3": "v3_value",

View File

@@ -20,61 +20,102 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type DeleteConfig struct { type DeleteConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
ClauseGroup *database.ClauseGroup ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Deleter database.DatabaseOperator
}
type Delete struct {
config *DeleteConfig
} }
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) { func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if cfg == nil { ns := &schema.NodeSchema{
return nil, errors.New("config is required") Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseDelete,
Name: n.Data.Meta.Title,
Configs: d,
} }
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
d.DatabaseInfoID = dsID
deleteParam := n.Data.Inputs.DeleteParam
clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition)
if err != nil {
return nil, err
}
d.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if d.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0") return nil, errors.New("database info id is required and greater than 0")
} }
if cfg.ClauseGroup == nil { if d.ClauseGroup == nil {
return nil, errors.New("clauseGroup is required") return nil, errors.New("clauseGroup is required")
} }
if cfg.Deleter == nil {
return nil, errors.New("deleter is required")
}
return &Delete{ return &Delete{
config: cfg, databaseInfoID: d.DatabaseInfoID,
clauseGroup: d.ClauseGroup,
outputTypes: ns.OutputTypes,
deleter: database.GetDatabaseOperator(),
}, nil }, nil
} }
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) { type Delete struct {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in) databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
deleter database.DatabaseOperator
}
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
request := &database.DeleteRequest{ request := &database.DeleteRequest{
DatabaseInfoID: d.config.DatabaseInfoID, DatabaseInfoID: d.databaseInfoID,
ConditionGroup: conditionGroup, ConditionGroup: conditionGroup,
IsDebugRun: isDebugExecute(ctx), IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx), UserID: getExecUserID(ctx),
} }
response, err := d.config.Deleter.Delete(ctx, request) response, err := d.deleter.Delete(ctx, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret, err := responseFormatted(d.config.OutputConfig, response) ret, err := responseFormatted(d.outputTypes, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any,
} }
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in) conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri
} }
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) { func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
databaseID := d.config.DatabaseInfoID databaseID := d.databaseInfoID
result := make(map[string]any) result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@@ -20,54 +20,84 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type InsertConfig struct { type InsertConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
OutputConfig map[string]*vo.TypeInfo
Inserter database.DatabaseOperator
} }
type Insert struct { func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *InsertConfig ns := &schema.NodeSchema{
} Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseInsert,
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) { Name: n.Data.Meta.Title,
if cfg == nil { Configs: i,
return nil, errors.New("config is required")
} }
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
i.DatabaseInfoID = dsID
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if i.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0") return nil, errors.New("database info id is required and greater than 0")
} }
if cfg.Inserter == nil {
return nil, errors.New("inserter is required")
}
return &Insert{ return &Insert{
config: cfg, databaseInfoID: i.DatabaseInfoID,
outputTypes: ns.OutputTypes,
inserter: database.GetDatabaseOperator(),
}, nil }, nil
} }
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) { type Insert struct {
databaseInfoID int64
outputTypes map[string]*vo.TypeInfo
inserter database.DatabaseOperator
}
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fields := parseToInput(input) fields := parseToInput(input)
req := &database.InsertRequest{ req := &database.InsertRequest{
DatabaseInfoID: is.config.DatabaseInfoID, DatabaseInfoID: is.databaseInfoID,
Fields: fields, Fields: fields,
IsDebugRun: isDebugExecute(ctx), IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx), UserID: getExecUserID(ctx),
} }
response, err := is.config.Inserter.Insert(ctx, req) response, err := is.inserter.Insert(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret, err := responseFormatted(is.config.OutputConfig, response) ret, err := responseFormatted(is.outputTypes, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]
} }
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) { func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
databaseID := is.config.DatabaseInfoID databaseID := is.databaseInfoID
fs := parseToInput(input) fs := parseToInput(input)
result := make(map[string]any) result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@@ -20,68 +20,137 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type QueryConfig struct { type QueryConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
QueryFields []string QueryFields []string
OrderClauses []*database.OrderClause OrderClauses []*database.OrderClause
OutputConfig map[string]*vo.TypeInfo
ClauseGroup *database.ClauseGroup ClauseGroup *database.ClauseGroup
Limit int64 Limit int64
Op database.DatabaseOperator
} }
type Query struct { func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *QueryConfig ns := &schema.NodeSchema{
} Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseQuery,
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) { Name: n.Data.Meta.Title,
if cfg == nil { Configs: q,
return nil, errors.New("config is required")
} }
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
q.DatabaseInfoID = dsID
selectParam := n.Data.Inputs.SelectParam
q.Limit = selectParam.Limit
queryFields := make([]string, 0)
for _, v := range selectParam.FieldList {
queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10))
}
q.QueryFields = queryFields
orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList))
for _, o := range selectParam.OrderByList {
orderClauses = append(orderClauses, &database.OrderClause{
FieldID: strconv.FormatInt(o.FieldID, 10),
IsAsc: o.IsAsc,
})
}
q.OrderClauses = orderClauses
clauseGroup := &database.ClauseGroup{}
if selectParam.Condition != nil {
clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition)
if err != nil {
return nil, err
}
}
q.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if q.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0") return nil, errors.New("database info id is required and greater than 0")
} }
if cfg.Limit == 0 { if q.Limit == 0 {
return nil, errors.New("limit is required and greater than 0") return nil, errors.New("limit is required and greater than 0")
} }
if cfg.Op == nil { return &Query{
return nil, errors.New("op is required") databaseInfoID: q.DatabaseInfoID,
} queryFields: q.QueryFields,
orderClauses: q.OrderClauses,
return &Query{config: cfg}, nil outputTypes: ns.OutputTypes,
clauseGroup: q.ClauseGroup,
limit: q.Limit,
op: database.GetDatabaseOperator(),
}, nil
} }
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) { type Query struct {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in) databaseInfoID int64
queryFields []string
orderClauses []*database.OrderClause
outputTypes map[string]*vo.TypeInfo
clauseGroup *database.ClauseGroup
limit int64
op database.DatabaseOperator
}
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req := &database.QueryRequest{ req := &database.QueryRequest{
DatabaseInfoID: ds.config.DatabaseInfoID, DatabaseInfoID: ds.databaseInfoID,
OrderClauses: ds.config.OrderClauses, OrderClauses: ds.orderClauses,
SelectFields: ds.config.QueryFields, SelectFields: ds.queryFields,
Limit: ds.config.Limit, Limit: ds.limit,
IsDebugRun: isDebugExecute(ctx), IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx), UserID: getExecUserID(ctx),
} }
req.ConditionGroup = conditionGroup req.ConditionGroup = conditionGroup
response, err := ds.config.Op.Query(ctx, req) response, err := ds.op.Query(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret, err := responseFormatted(ds.config.OutputConfig, response) ret, err := responseFormatted(ds.outputTypes, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool {
} }
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) { func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in) conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return toDatabaseQueryCallbackInput(ds.config, conditionGroup) return ds.toDatabaseQueryCallbackInput(conditionGroup)
} }
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) { func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
result := make(map[string]any) result := make(map[string]any)
databaseID := config.DatabaseInfoID databaseID := ds.databaseInfoID
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["selectParam"] = map[string]any{} result["selectParam"] = map[string]any{}
@@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
FieldID string `json:"fieldId"` FieldID string `json:"fieldId"`
IsDistinct bool `json:"isDistinct"` IsDistinct bool `json:"isDistinct"`
} }
fieldList := make([]Field, 0, len(config.QueryFields)) fieldList := make([]Field, 0, len(ds.queryFields))
for _, f := range config.QueryFields { for _, f := range ds.queryFields {
fieldList = append(fieldList, Field{FieldID: f}) fieldList = append(fieldList, Field{FieldID: f})
} }
type Order struct { type Order struct {
@@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
} }
OrderList := make([]Order, 0) OrderList := make([]Order, 0)
for _, c := range config.OrderClauses { for _, c := range ds.orderClauses {
OrderList = append(OrderList, Order{ OrderList = append(OrderList, Order{
FieldID: c.FieldID, FieldID: c.FieldID,
IsAsc: c.IsAsc, IsAsc: c.IsAsc,
@@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
result["selectParam"] = map[string]any{ result["selectParam"] = map[string]any{
"condition": condition, "condition": condition,
"fieldList": fieldList, "fieldList": fieldList,
"limit": config.Limit, "limit": ds.limit,
"orderByList": OrderList, "orderByList": OrderList,
} }
return result, nil return result, nil
} }
type ConditionItem struct { type ConditionItem struct {
@@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) {
return "AND", nil return "AND", nil
default: default:
return "", fmt.Errorf("unknown clause relation %v", rel) return "", fmt.Errorf("unknown clause relation %v", rel)
} }
} }

View File

@@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type mockDsSelect struct { type mockDsSelect struct {
@@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) {
}, },
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"}, QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{ Limit: 10,
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) { mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
@@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()) mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
config: cfg, OutputTypes: map[string]*vo.TypeInfo{
} "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]interface{}{ in := map[string]interface{}{
"__condition_right_0": 1, "__condition_right_0": 1,
} }
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"]) assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"]) assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
@@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) {
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"}, QueryFields: []string{"v1", "v2"},
Limit: 10,
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
objects := make([]database.Object, 0) objects := make([]database.Object, 0)
@@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
config: cfg, OutputTypes: map[string]*vo.TypeInfo{
} "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{ in := map[string]any{
"__condition_right_0": 1, "__condition_right_0": 1,
"__condition_right_1": 2, "__condition_right_1": 2,
} }
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"]) assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
@@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) {
}, },
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"}, QueryFields: []string{"v1", "v2"},
Limit: 10,
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
objects := make([]database.Object, 0) objects := make([]database.Object, 0)
objects = append(objects, database.Object{ objects = append(objects, database.Object{
@@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
config: cfg, OutputTypes: map[string]*vo.TypeInfo{
} "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{ in := map[string]any{
"__condition_right_0": 1, "__condition_right_0": 1,
} }
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(result) fmt.Println(result)
assert.Equal(t, map[string]any{ assert.Equal(t, map[string]any{
@@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) {
}, },
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"}, QueryFields: []string{"v1", "v2"},
Limit: 10,
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
objects := make([]database.Object, 0) objects := make([]database.Object, 0)
objects = append(objects, database.Object{ objects = append(objects, database.Object{
@@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
config: cfg, OutputTypes: map[string]*vo.TypeInfo{
} "outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{"__condition_right_0": 1} in := map[string]any{"__condition_right_0": 1}
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
fmt.Println(result) fmt.Println(result)
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"]) assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
@@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) {
}, },
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"}, QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
Limit: 10,
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
objects := make([]database.Object, 0) objects := make([]database.Object, 0)
@@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{ ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
config: cfg, OutputTypes: map[string]*vo.TypeInfo{
} "outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{ in := map[string]any{
"__condition_right_0": 1, "__condition_right_0": 1,
} }
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
object := result["outputList"].([]any)[0].(database.Object) object := result["outputList"].([]any)[0].(database.Object)
@@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) {
}, },
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}}, OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"}, QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{ Limit: 10,
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
} }
objects := make([]database.Object, 0) objects := make([]database.Object, 0)
@@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl) mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes() mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg, ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
} OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{ in := map[string]any{
"__condition_right_0": 1, "__condition_right_0": 1,
} }
result, err := ds.Query(t.Context(), in) result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{ assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
"v1": "1", "v1": "1",

View File

@@ -20,47 +20,93 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type UpdateConfig struct { type UpdateConfig struct {
DatabaseInfoID int64 DatabaseInfoID int64
ClauseGroup *database.ClauseGroup ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo }
Updater database.DatabaseOperator
func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseUpdate,
Name: n.Data.Meta.Title,
Configs: u,
}
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
u.DatabaseInfoID = dsID
updateParam := n.Data.Inputs.UpdateParam
if updateParam == nil {
return nil, fmt.Errorf("update param is requird")
}
clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition)
if err != nil {
return nil, err
}
u.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if u.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if u.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
return &Update{
databaseInfoID: u.DatabaseInfoID,
clauseGroup: u.ClauseGroup,
outputTypes: ns.OutputTypes,
updater: database.GetDatabaseOperator(),
}, nil
} }
type Update struct { type Update struct {
config *UpdateConfig databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
updater database.DatabaseOperator
} }
type UpdateInventory struct {
type updateInventory struct {
ConditionGroup *database.ConditionGroup ConditionGroup *database.ConditionGroup
Fields map[string]any Fields map[string]any
} }
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) { func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
if cfg == nil { inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in)
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
if cfg.Updater == nil {
return nil, errors.New("updater is required")
}
return &Update{config: cfg}, nil
}
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
} }
req := &database.UpdateRequest{ req := &database.UpdateRequest{
DatabaseInfoID: u.config.DatabaseInfoID, DatabaseInfoID: u.databaseInfoID,
ConditionGroup: inventory.ConditionGroup, ConditionGroup: inventory.ConditionGroup,
Fields: fields, Fields: fields,
IsDebugRun: isDebugExecute(ctx), IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx), UserID: getExecUserID(ctx),
} }
response, err := u.config.Updater.Update(ctx, req) response, err := u.updater.Update(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret, err := responseFormatted(u.config.OutputConfig, response) ret, err := responseFormatted(u.outputTypes, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
} }
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in) inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return u.toDatabaseUpdateCallbackInput(inventory) return u.toDatabaseUpdateCallbackInput(inventory)
} }
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) { func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) {
databaseID := u.config.DatabaseInfoID databaseID := u.databaseInfoID
result := make(map[string]any) result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)} result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["updateParam"] = map[string]any{} result["updateParam"] = map[string]any{}
@@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[
"condition": condition, "condition": condition,
"fieldInfo": fieldInfo, "fieldInfo": fieldInfo,
} }
return result, nil
return result, nil
} }

View File

@@ -18,7 +18,6 @@ package emitter
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"strings" "strings"
@@ -26,28 +25,77 @@ import (
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
) )
type OutputEmitter struct { type OutputEmitter struct {
cfg *Config Template string
FullSources map[string]*schema2.SourceInfo
} }
type Config struct { type Config struct {
Template string Template string
FullSources map[string]*nodes.SourceInfo
} }
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
if cfg == nil { ns := &schema2.NodeSchema{
return nil, errors.New("config is required") Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeOutputEmitter,
Name: n.Data.Meta.Title,
Configs: c,
} }
content := n.Data.Inputs.Content
streamingOutput := n.Data.Inputs.StreamingOutput
if streamingOutput {
ns.StreamConfigs = &schema2.StreamConfig{
RequireStreamingInput: true,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
RequireStreamingInput: false,
}
}
if content != nil {
if content.Type != vo.VariableTypeString {
return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
}
if content.Value.Type != vo.BlockInputValueTypeLiteral {
return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
}
if content.Value.Content == nil {
c.Template = ""
} else {
template, ok := content.Value.Content.(string)
if !ok {
return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
}
c.Template = template
}
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
return &OutputEmitter{ return &OutputEmitter{
cfg: cfg, Template: c.Template,
FullSources: ns.FullSources,
}, nil }, nil
} }
@@ -59,10 +107,10 @@ type cachedVal struct {
type cacheStore struct { type cacheStore struct {
store map[string]*cachedVal store map[string]*cachedVal
infos map[string]*nodes.SourceInfo infos map[string]*schema2.SourceInfo
} }
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore { func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
return &cacheStore{ return &cacheStore{
store: make(map[string]*cachedVal), store: make(map[string]*cachedVal),
infos: infos, infos: infos,
@@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
} }
if !sInfo.IsIntermediate { // this is not an intermediate object container if !sInfo.IsIntermediate { // this is not an intermediate object container
isStream := sInfo.FieldType == nodes.FieldIsStream isStream := sInfo.FieldType == schema2.FieldIsStream
if !isStream { if !isStream {
_, ok := c.store[k] _, ok := c.store[k]
if !ok { if !ok {
@@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
func (c *cacheStore) finished(k string) bool { func (c *cacheStore) finished(k string) bool {
cached, ok := c.store[k] cached, ok := c.store[k]
if !ok { if !ok {
return c.infos[k].FieldType == nodes.FieldSkipped return c.infos[k].FieldType == schema2.FieldSkipped
} }
if cached.finished { if cached.finished {
@@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
return true return true
} }
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo, func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
actualPath []string, actualPath []string,
) { ) {
rootCached, ok := c.store[part.Root] rootCached, ok := c.store[part.Root]
@@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
hasErr bool, partFinished bool) { hasErr bool, partFinished bool) {
cachedRoot, subCache, sourceInfo, _ := c.find(part) cachedRoot, subCache, sourceInfo, _ := c.find(part)
if cachedRoot != nil && subCache != nil { if cachedRoot != nil && subCache != nil {
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream { if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
hasErr = renderAndSend(part, part.Root, cachedRoot, sw) hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
if hasErr { if hasErr {
return true, false return true, false
@@ -315,14 +363,14 @@ func merge(a, b any) any {
const outputKey = "output" const outputKey = "output"
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) { func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources) resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sr, sw := schema.Pipe[map[string]any](0) sr, sw := schema.Pipe[map[string]any](0)
parts := nodes.ParseTemplate(e.cfg.Template) parts := nodes.ParseTemplate(e.Template)
safego.Go(ctx, func() { safego.Go(ctx, func() {
hasErr := false hasErr := false
defer func() { defer func() {
@@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
shouldChangePart = true shouldChangePart = true
} }
} else { } else {
if sourceInfo.FieldType == nodes.FieldIsStream { if sourceInfo.FieldType == schema2.FieldIsStream {
currentV := v currentV := v
for i := 0; i < len(actualPath)-1; i++ { for i := 0; i < len(actualPath)-1; i++ {
currentM, ok := currentV.(map[string]any) currentM, ok := currentV.(map[string]any)
@@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
return sr, nil return sr, nil
} }
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) { func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources) s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,41 +20,74 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type Config struct { type Config struct {
DefaultValues map[string]any DefaultValues map[string]any
OutputTypes map[string]*vo.TypeInfo
} }
type Entry struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
cfg *Config if n.Parent() != nil {
defaultValues map[string]any return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID)
}
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
} }
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if n.ID != entity.EntryNodeKey {
return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID)
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Name: n.Data.Meta.Title,
Type: entity.NodeTypeEntry,
}
defaultValues := make(map[string]any, len(n.Data.Outputs))
for _, v := range n.Data.Outputs {
variable, err := vo.ParseVariable(v)
if err != nil {
return nil, err
}
if variable.DefaultValue != nil {
defaultValues[variable.Name] = variable.DefaultValue
}
}
c.DefaultValues = defaultValues
ns.Configs = c
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Entry{ return &Entry{
cfg: cfg,
defaultValues: defaultValues, defaultValues: defaultValues,
outputTypes: ns.OutputTypes,
}, nil }, nil
}
type Entry struct {
defaultValues map[string]any
outputTypes map[string]*vo.TypeInfo
} }
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) { func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
for k, v := range e.defaultValues { for k, v := range e.defaultValues {
if val, ok := in[k]; ok { if val, ok := in[k]; ok {
tInfo := e.cfg.OutputTypes[k] tInfo := e.outputTypes[k]
switch tInfo.Type { switch tInfo.Type {
case vo.DataTypeString: case vo.DataTypeString:
if len(val.(string)) == 0 { if len(val.(string)) == 0 {

View File

@@ -0,0 +1,113 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package exit
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Config struct {
Template string
TerminatePlan vo.TerminatePlan
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID)
}
if n.ID != entity.ExitNodeKey {
return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID)
}
ns := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Name: n.Data.Meta.Title,
Configs: c,
}
var (
content *vo.BlockInput
streamingOutput bool
)
if n.Data.Inputs.OutputEmitter != nil {
content = n.Data.Inputs.Content
streamingOutput = n.Data.Inputs.StreamingOutput
}
if streamingOutput {
ns.StreamConfigs = &schema.StreamConfig{
RequireStreamingInput: true,
}
} else {
ns.StreamConfigs = &schema.StreamConfig{
RequireStreamingInput: false,
}
}
if content != nil {
if content.Type != vo.VariableTypeString {
return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
}
if content.Value.Type != vo.BlockInputValueTypeLiteral {
return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
}
c.Template = content.Value.Content.(string)
}
if n.Data.Inputs.TerminatePlan == nil {
return nil, fmt.Errorf("exit node requires a TerminatePlan")
}
c.TerminatePlan = *n.Data.Inputs.TerminatePlan
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.TerminatePlan == vo.ReturnVariables {
return &Exit{}, nil
}
return &emitter.OutputEmitter{
Template: c.Template,
FullSources: ns.FullSources,
}, nil
}
type Exit struct{}
func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
if in == nil {
return map[string]any{}, nil
}
return in, nil
}

View File

@@ -0,0 +1,340 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httprequester
import (
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`)
func extractBracesContent(s string) []string {
matches := extractBracesRegexp.FindAllStringSubmatch(s, -1)
var result []string
for _, match := range matches {
if len(match) >= 2 {
result = append(result, match[1])
}
}
return result
}
type ImplicitNodeDependency struct {
NodeID string
FieldPath compose.FieldPath
TypeInfo *vo.TypeInfo
}
func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) {
dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes))
url := node.Data.Inputs.APIInfo.URL
urlVars := extractBracesContent(url)
hasReferred := make(map[string]bool)
extractDependenciesFromVars := func(vars []string) error {
for _, v := range vars {
if strings.HasPrefix(v, "block_output_") {
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
if len(paths) < 2 {
return fmt.Errorf("invalid block_output_ variable: %s", v)
}
if hasReferred[v] {
continue
}
hasReferred[v] = true
dependencies = append(dependencies, &ImplicitNodeDependency{
NodeID: paths[0],
FieldPath: paths[1:],
})
}
}
return nil
}
err := extractDependenciesFromVars(urlVars)
if err != nil {
return nil, err
}
if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) {
jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
err = extractDependenciesFromVars(jsonVars)
if err != nil {
return nil, err
}
}
if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) {
rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
err = extractDependenciesFromVars(rawTextVars)
if err != nil {
return nil, err
}
}
var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node
nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node {
for i := range nodes {
if nodes[i].ID == nodeID {
return nodes[i]
}
if len(nodes[i].Blocks) > 0 {
if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil {
return n
}
}
}
return nil
}
for _, ds := range dependencies {
fNode := nodeFinder(canvas.Nodes, ds.NodeID)
if fNode == nil {
continue
}
tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs))
for _, vAny := range fNode.Data.Outputs {
v, err := vo.ParseVariable(vAny)
if err != nil {
return nil, err
}
tInfo, err := convert.CanvasVariableToTypeInfo(v)
if err != nil {
return nil, err
}
tInfoMap[v.Name] = tInfo
}
tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap)
if !ok {
return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath)
}
ds.TypeInfo = tInfo
}
return dependencies, nil
}
func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) {
if len(properties) == 0 {
if tInfo, ok := tInfoMap[root]; ok {
return tInfo, true
}
return nil, false
}
tInfo, ok := tInfoMap[root]
if !ok {
return nil, false
}
return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties)
}
var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`)
func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) {
inputs := n.Data.Inputs
implicitPathVars := make(map[string]bool)
addImplicitVarsSources := func(prefix string, vars []string) error {
for _, v := range vars {
if strings.HasPrefix(v, "block_output_") {
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
if len(paths) < 2 {
return fmt.Errorf("invalid implicit var : %s", v)
}
for _, dep := range implicitNodeDependencies {
if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") {
pathValue := prefix + crypto.MD5HexValue(v)
if _, visited := implicitPathVars[pathValue]; visited {
continue
}
implicitPathVars[pathValue] = true
ns.SetInputType(pathValue, dep.TypeInfo)
ns.AddInputSource(&vo.FieldInfo{
Path: []string{pathValue},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: vo.NodeKey(dep.NodeID),
FromPath: dep.FieldPath,
},
},
})
}
}
}
if strings.HasPrefix(v, "global_variable_") {
matches := globalVariableRegex.FindStringSubmatch(v)
if len(matches) < 2 {
continue
}
var varType vo.GlobalVarType
if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) {
varType = vo.GlobalAPP
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) {
varType = vo.GlobalUser
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) {
varType = vo.GlobalSystem
} else {
return fmt.Errorf("invalid global variable type: %s", v)
}
source := vo.FieldSource{
Ref: &vo.Reference{
VariableType: &varType,
FromPath: []string{matches[1]},
},
}
ns.AddInputSource(&vo.FieldInfo{
Path: []string{prefix + crypto.MD5HexValue(v)},
Source: source,
})
}
}
return nil
}
urlVars := extractBracesContent(inputs.APIInfo.URL)
err = addImplicitVarsSources("__apiInfo_url_", urlVars)
if err != nil {
return err
}
err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent())
if err != nil {
return err
}
err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent())
if err != nil {
return err
}
if inputs.Auth != nil && inputs.Auth.AuthOpen {
authData := inputs.Auth.AuthData
const bearerTokenKey = "__auth_authData_bearerTokenData_token"
if inputs.Auth.AuthType == "BEARER_AUTH" {
bearTokenParam := authData.BearerTokenData[0]
tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input)
if err != nil {
return err
}
ns.SetInputType(bearerTokenKey, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
if inputs.Auth.AuthType == "CUSTOM_AUTH" {
const (
customDataDataKey = "__auth_authData_customData_data_Key"
customDataDataValue = "__auth_authData_customData_data_Value"
)
dataParams := authData.CustomData.Data
keyParam := dataParams[0]
keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataKey, keyTypeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
valueParam := dataParams[1]
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataValue, valueTypeInfo)
sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
}
switch BodyType(inputs.Body.BodyType) {
case BodyTypeFormData:
err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent())
if err != nil {
return err
}
case BodyTypeFormURLEncoded:
err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent())
if err != nil {
return err
}
case BodyTypeBinary:
const fileURLName = "__body_bodyData_binary_fileURL"
fileURLInput := inputs.Body.BodyData.Binary.FileURL
ns.SetInputType(fileURLName, &vo.TypeInfo{
Type: vo.DataTypeString,
})
sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
case BodyTypeJSON:
jsonVars := extractBracesContent(inputs.Body.BodyData.Json)
err = addImplicitVarsSources("__body_bodyData_json_", jsonVars)
if err != nil {
return err
}
case BodyTypeRawText:
rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText)
err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars)
if err != nil {
return err
}
}
return nil
}
func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error {
for i := range params {
param := params[i]
name := param.Name
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return err
}
fieldName := prefix + crypto.MD5HexValue(name)
ns.SetInputType(fieldName, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{fieldName}, parentNode)
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
return nil
}

View File

@@ -31,9 +31,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -129,7 +134,7 @@ type Request struct {
FileURL *string FileURL *string
} }
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`) var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`)
type MD5FieldMapping struct { type MD5FieldMapping struct {
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
@@ -184,49 +189,188 @@ type Config struct {
Timeout time.Duration Timeout time.Duration
RetryTimes uint64 RetryTimes uint64
IgnoreException bool
DefaultOutput map[string]any
MD5FieldMapping MD5FieldMapping
} }
type HTTPRequester struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
client *http.Client options := nodes.GetAdaptOptions(opts...)
config *Config if options.Canvas == nil {
} return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node")
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
} }
if len(cfg.Method) == 0 { implicitDeps, err := extractImplicitDependency(n, options.Canvas)
if err != nil {
return nil, err
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeHTTPRequester,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
md5FieldMapping := &MD5FieldMapping{}
method := inputs.APIInfo.Method
c.Method = method
reqURL := inputs.APIInfo.URL
c.URLConfig = URLConfig{
Tpl: strings.TrimSpace(reqURL),
}
urlVars := extractBracesContent(reqURL)
md5FieldMapping.SetURLFields(urlVars...)
md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string {
return a.Name
})...)
md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string {
return a.Name
})...)
if inputs.Auth != nil && inputs.Auth.AuthOpen {
auth := &AuthenticationConfig{}
ty, err := convertAuthType(inputs.Auth.AuthType)
if err != nil {
return nil, err
}
auth.Type = ty
location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo)
if err != nil {
return nil, err
}
auth.Location = location
c.AuthConfig = auth
}
bodyConfig := BodyConfig{}
bodyConfig.BodyType = BodyType(inputs.Body.BodyType)
switch BodyType(inputs.Body.BodyType) {
case BodyTypeJSON:
jsonTpl := inputs.Body.BodyData.Json
bodyConfig.TextJsonConfig = &TextJsonConfig{
Tpl: jsonTpl,
}
jsonVars := extractBracesContent(jsonTpl)
md5FieldMapping.SetBodyFields(jsonVars...)
case BodyTypeFormData:
bodyConfig.FormDataConfig = &FormDataConfig{
FileTypeMapping: map[string]bool{},
}
formDataVars := make([]string, 0)
for i := range inputs.Body.BodyData.FormData.Data {
p := inputs.Body.BodyData.FormData.Data[i]
formDataVars = append(formDataVars, p.Name)
if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime {
bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true
}
}
md5FieldMapping.SetBodyFields(formDataVars...)
case BodyTypeRawText:
TextTpl := inputs.Body.BodyData.RawText
bodyConfig.TextPlainConfig = &TextPlainConfig{
Tpl: TextTpl,
}
textPlainVars := extractBracesContent(TextTpl)
md5FieldMapping.SetBodyFields(textPlainVars...)
case BodyTypeFormURLEncoded:
formURLEncodedVars := make([]string, 0)
for _, p := range inputs.Body.BodyData.FormURLEncoded {
formURLEncodedVars = append(formURLEncodedVars, p.Name)
}
md5FieldMapping.SetBodyFields(formURLEncodedVars...)
}
c.BodyConfig = bodyConfig
c.MD5FieldMapping = *md5FieldMapping
if inputs.Setting != nil {
c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second
c.RetryTimes = uint64(inputs.Setting.RetryTimes)
}
if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func convertAuthType(auth string) (AuthType, error) {
switch auth {
case "CUSTOM_AUTH":
return Custom, nil
case "BEARER_AUTH":
return BearToken, nil
default:
return AuthType(0), fmt.Errorf("invalid auth type")
}
}
func convertLocation(l string) (Location, error) {
switch l {
case "header":
return Header, nil
case "query":
return QueryParam, nil
default:
return 0, fmt.Errorf("invalid location")
}
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if len(c.Method) == 0 {
return nil, fmt.Errorf("method is requried") return nil, fmt.Errorf("method is requried")
} }
hg := &HTTPRequester{} hg := &HTTPRequester{
urlConfig: c.URLConfig,
method: c.Method,
retryTimes: c.RetryTimes,
authConfig: c.AuthConfig,
bodyConfig: c.BodyConfig,
md5FieldMapping: c.MD5FieldMapping,
}
client := http.DefaultClient client := http.DefaultClient
if cfg.Timeout > 0 { if c.Timeout > 0 {
client.Timeout = cfg.Timeout client.Timeout = c.Timeout
} }
hg.client = client hg.client = client
hg.config = cfg
return hg, nil return hg, nil
} }
type HTTPRequester struct {
client *http.Client
urlConfig URLConfig
authConfig *AuthenticationConfig
bodyConfig BodyConfig
method string
retryTimes uint64
md5FieldMapping MD5FieldMapping
}
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) { func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
var ( var (
req = &Request{} req = &Request{}
method = hg.config.Method method = hg.method
retryTimes = hg.config.RetryTimes retryTimes = hg.retryTimes
body io.ReadCloser body io.ReadCloser
contentType string contentType string
response *http.Response response *http.Response
) )
req, err = hg.config.parserToRequest(input) req, err = hg.parserToRequest(input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
Header: http.Header{}, Header: http.Header{},
} }
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars) httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
params.Set(key, value) params.Set(key, value)
} }
if hg.config.AuthConfig != nil { if hg.authConfig != nil {
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params) httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
httpRequest.URL = u httpRequest.URL = u
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req) body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) {
} }
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) { func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
var ( var request = &Request{}
request = &Request{}
config = hg.config request, err := hg.parserToRequest(input)
)
request, err := hg.config.parserToRequest(input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result := make(map[string]any) result := make(map[string]any)
result["method"] = config.Method result["method"] = hg.method
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars) u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
} }
result["header"] = headers result["header"] = headers
result["auth"] = nil result["auth"] = nil
if config.AuthConfig != nil { if hg.authConfig != nil {
if config.AuthConfig.Type == Custom { if hg.authConfig.Type == Custom {
result["auth"] = map[string]interface{}{ result["auth"] = map[string]interface{}{
"Key": request.Authentication.Key, "Key": request.Authentication.Key,
"Value": request.Authentication.Value, "Value": request.Authentication.Value,
} }
} else if config.AuthConfig.Type == BearToken { } else if hg.authConfig.Type == BearToken {
result["auth"] = map[string]interface{}{ result["auth"] = map[string]interface{}{
"token": request.Authentication.Token, "token": request.Authentication.Token,
} }
@@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
} }
result["body"] = nil result["body"] = nil
switch config.BodyConfig.BodyType { switch hg.bodyConfig.BodyType {
case BodyTypeJSON: case BodyTypeJSON:
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars) js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
} }
result["body"] = ret result["body"] = ret
case BodyTypeRawText: case BodyTypeRawText:
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars) tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -569,7 +711,7 @@ const (
bodyBinaryFileURLPrefix = "binary_fileURL" bodyBinaryFileURLPrefix = "binary_fileURL"
) )
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) { func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) {
request := &Request{ request := &Request{
URLVars: make(map[string]any), URLVars: make(map[string]any),
Headers: make(map[string]string), Headers: make(map[string]string),
@@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
for key, value := range input { for key, value := range input {
if strings.HasPrefix(key, apiInfoURLPrefix) { if strings.HasPrefix(key, apiInfoURLPrefix) {
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix) urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok { if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok {
if strings.HasPrefix(urlKey, "global_variable_") { if strings.HasPrefix(urlKey, "global_variable_") {
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2") urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
} }
@@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
} }
if strings.HasPrefix(key, headersPrefix) { if strings.HasPrefix(key, headersPrefix) {
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix) headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok { if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok {
request.Headers[headerKey] = value.(string) request.Headers[headerKey] = value.(string)
} }
} }
if strings.HasPrefix(key, paramsPrefix) { if strings.HasPrefix(key, paramsPrefix) {
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix) paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok { if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok {
request.Params[paramKey] = value.(string) request.Params[paramKey] = value.(string)
} }
} }
@@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
bodyKey := strings.TrimPrefix(key, bodyDataPrefix) bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
if strings.HasPrefix(bodyKey, bodyJsonPrefix) { if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix) jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok { if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok {
if strings.HasPrefix(jsonKey, "global_variable_") { if strings.HasPrefix(jsonKey, "global_variable_") {
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2") jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
} }
@@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
} }
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) { if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix) formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok { if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok {
request.FormDataVars[formDataKey] = value.(string) request.FormDataVars[formDataKey] = value.(string)
} }
@@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) { if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix) formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok { if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
request.FormURLEncodedVars[formURLEncodeKey] = value.(string) request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
} }
} }
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) { if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix) rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok { if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok {
if strings.HasPrefix(rawTextKey, "global_variable_") { if strings.HasPrefix(rawTextKey, "global_variable_") {
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2") rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
} }

View File

@@ -28,6 +28,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto" "github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
) )
@@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) {
}, },
}, },
} }
hg, err := NewHTTPRequester(context.Background(), cfg) hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
m := map[string]any{ m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1", "__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
@@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) {
"__params_" + crypto.MD5HexValue("p2"): "v2", "__params_" + crypto.MD5HexValue("p2"): "v2",
} }
result, err := hg.Invoke(context.Background(), m) result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"]) assert.Equal(t, int64(200), result["statusCode"])
@@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) {
} }
// Create an HTTPRequest instance // Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg) hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
m := map[string]any{ m := map[string]any{
@@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL, "__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
} }
result, err := hg.Invoke(context.Background(), m) result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"]) assert.Equal(t, int64(200), result["statusCode"])
@@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) {
}, },
}, },
} }
hg, err := NewHTTPRequester(context.Background(), cfg) hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
m := map[string]any{ m := map[string]any{
@@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2", "__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
} }
result, err := hg.Invoke(context.Background(), m) result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"]) assert.Equal(t, int64(200), result["statusCode"])
@@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) {
} }
// Create an HTTPRequest instance // Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg) hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
m := map[string]any{ m := map[string]any{
@@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2", "__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
} }
result, err := hg.Invoke(context.Background(), m) result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"]) assert.Equal(t, int64(200), result["statusCode"])
@@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) {
} }
// Create an HTTPRequest instance // Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg) hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
m := map[string]any{ m := map[string]any{
@@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL, "__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
} }
result, err := hg.Invoke(context.Background(), m) result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"]) assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"]) assert.Equal(t, int64(200), result["statusCode"])

View File

@@ -18,26 +18,167 @@ package intentdetector
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/spf13/cast" "github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
type Config struct { type Config struct {
Intents []string Intents []string
SystemPrompt string SystemPrompt string
IsFastMode bool IsFastMode bool
ChatModel model.BaseChatModel LLMParams *model.LLMParams
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeIntentDetector,
Name: n.Data.Meta.Title,
Configs: c,
}
param := n.Data.Inputs.LLMParam
if param == nil {
return nil, fmt.Errorf("intent detector node's llmParam is nil")
}
llmParam, ok := param.(vo.IntentDetectorLLMParam)
if !ok {
return nil, fmt.Errorf("llm node's llmParam must be LLMParam, got %v", llmParam)
}
paramBytes, err := sonic.Marshal(param)
if err != nil {
return nil, err
}
var intentDetectorConfig = &vo.IntentDetectorLLMConfig{}
err = sonic.Unmarshal(paramBytes, &intentDetectorConfig)
if err != nil {
return nil, err
}
modelLLMParams := &model.LLMParams{}
modelLLMParams.ModelType = int64(intentDetectorConfig.ModelType)
modelLLMParams.ModelName = intentDetectorConfig.ModelName
modelLLMParams.TopP = intentDetectorConfig.TopP
modelLLMParams.Temperature = intentDetectorConfig.Temperature
modelLLMParams.MaxTokens = intentDetectorConfig.MaxTokens
modelLLMParams.ResponseFormat = model.ResponseFormat(intentDetectorConfig.ResponseFormat)
modelLLMParams.SystemPrompt = intentDetectorConfig.SystemPrompt.Value.Content.(string)
c.LLMParams = modelLLMParams
c.SystemPrompt = modelLLMParams.SystemPrompt
var intents = make([]string, 0, len(n.Data.Inputs.Intents))
for _, it := range n.Data.Inputs.Intents {
intents = append(intents, it.Name)
}
c.Intents = intents
if n.Data.Inputs.Mode == "top_speed" {
c.IsFastMode = true
}
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if !c.IsFastMode && c.LLMParams == nil {
return nil, errors.New("config chat model is required")
}
if len(c.Intents) == 0 {
return nil, errors.New("config intents is required")
}
m, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
chain := compose.NewChain[map[string]any, *schema.Message]()
spt := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
intents, err := toIntentString(c.Intents)
if err != nil {
return nil, err
}
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": intents,
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(m).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
isFastMode: c.IsFastMode,
systemPrompt: c.SystemPrompt,
runner: r,
}, nil
}
func (c *Config) BuildBranch(_ context.Context) (
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
classificationId, ok := nodeOutput[classificationID]
if !ok {
return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput)
}
cID64, ok := classificationId.(int64)
if !ok {
return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", classificationId)
}
if cID64 == 0 {
return -1, true, nil
}
return cID64 - 1, false, nil
}, true
}
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string {
expects := make([]string, len(n.Data.Inputs.Intents)+1)
expects[0] = schema2.PortDefault
for i := 0; i < len(n.Data.Inputs.Intents); i++ {
expects[i+1] = fmt.Sprintf(schema2.PortBranchFormat, i)
}
return expects
} }
const SystemIntentPrompt = ` const SystemIntentPrompt = `
@@ -95,71 +236,39 @@ Note:
##Limit ##Limit
- Please do not reply in text.` - Please do not reply in text.`
const classificationID = "classificationId"
type IntentDetector struct { type IntentDetector struct {
config *Config isFastMode bool
runner compose.Runnable[map[string]any, *schema.Message] systemPrompt string
} runner compose.Runnable[map[string]any, *schema.Message]
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if !cfg.IsFastMode && cfg.ChatModel == nil {
return nil, errors.New("config chat model is required")
}
if len(cfg.Intents) == 0 {
return nil, errors.New("config intents is required")
}
chain := compose.NewChain[map[string]any, *schema.Message]()
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": toIntentString(cfg.Intents),
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
config: cfg,
runner: r,
}, nil
} }
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) { func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
nodeOutput := make(map[string]any)
nodeOutput["classificationId"] = 0
if content == "" { if content == "" {
return nodeOutput, errors.New("content is empty") return nil, errors.New("intent detector's LLM output content is empty")
} }
if id.config.IsFastMode { if id.isFastMode {
cid, err := strconv.ParseInt(content, 10, 64) cid, err := strconv.ParseInt(content, 10, 64)
if err != nil { if err != nil {
return nodeOutput, err return nil, err
} }
nodeOutput["classificationId"] = cid return map[string]any{
return nodeOutput, nil classificationID: cid,
}, nil
} }
leftIndex := strings.Index(content, "{") leftIndex := strings.Index(content, "{")
rightIndex := strings.Index(content, "}") rightIndex := strings.Index(content, "}")
if leftIndex == -1 || rightIndex == -1 { if leftIndex == -1 || rightIndex == -1 {
return nodeOutput, errors.New("content is invalid") return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content)
} }
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput) var nodeOutput map[string]any
err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput)
if err != nil { if err != nil {
return nodeOutput, err return nil, err
} }
return nodeOutput, nil return nodeOutput, nil
@@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
vars := make(map[string]any) vars := make(map[string]any)
vars["query"] = queryStr vars["query"] = queryStr
if !id.config.IsFastMode { if !id.isFastMode {
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query}) ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
return id.parseToNodeOut(o.Content) return id.parseToNodeOut(o.Content)
} }
func toIntentString(its []string) string { func toIntentString(its []string) (string, error) {
type IntentVariableItem struct { type IntentVariableItem struct {
ClassificationID int64 `json:"classificationId"` ClassificationID int64 `json:"classificationId"`
Content string `json:"content"` Content string `json:"content"`
@@ -207,6 +316,6 @@ func toIntentString(its []string) string {
Content: it, Content: it,
}) })
} }
itsBytes, _ := json.Marshal(vs)
return string(itsBytes) return sonic.MarshalString(vs)
} }

View File

@@ -1,88 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package intentdetector
import (
"context"
"fmt"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
)
type mockChatModel struct {
topSeed bool
}
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
if m.topSeed {
return &schema.Message{
Content: "1",
}, nil
}
return &schema.Message{
Content: `{"classificationId":1,"reason":"高兴"}`,
}, nil
}
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
return nil, nil
}
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
return nil
}
func TestNewIntentDetector(t *testing.T) {
ctx := context.Background()
t.Run("fast mode", func(t *testing.T) {
dt, err := NewIntentDetector(ctx, &Config{
Intents: []string{"高兴", "悲伤"},
IsFastMode: true,
ChatModel: &mockChatModel{topSeed: true},
})
assert.Nil(t, err)
ret, err := dt.Invoke(ctx, map[string]any{
"query": "我考了100分",
})
assert.Nil(t, err)
assert.Equal(t, ret["classificationId"], int64(1))
})
t.Run("full mode", func(t *testing.T) {
dt, err := NewIntentDetector(ctx, &Config{
Intents: []string{"高兴", "悲伤"},
IsFastMode: false,
ChatModel: &mockChatModel{},
})
assert.Nil(t, err)
ret, err := dt.Invoke(ctx, map[string]any{
"query": "我考了100分",
})
fmt.Println(err)
assert.Nil(t, err)
fmt.Println(ret)
assert.Equal(t, ret["classificationId"], float64(1))
assert.Equal(t, ret["reason"], "高兴")
})
}

View File

@@ -20,8 +20,11 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -34,32 +37,42 @@ const (
warningsKey = "deserialization_warnings" warningsKey = "deserialization_warnings"
) )
type DeserializationConfig struct { type DeserializationConfig struct{}
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (
*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeJsonDeserialization,
Name: n.Data.Meta.Title,
Configs: d,
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
type Deserializer struct { func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
config *DeserializationConfig typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization]
typeInfo *vo.TypeInfo if !ok {
}
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.OutputFields == nil {
return nil, fmt.Errorf("OutputFields is required for deserialization")
}
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
if typeInfo == nil {
return nil, fmt.Errorf("no output field specified in deserialization config") return nil, fmt.Errorf("no output field specified in deserialization config")
} }
return &Deserializer{ return &Deserializer{
config: cfg,
typeInfo: typeInfo, typeInfo: typeInfo,
}, nil }, nil
} }
type Deserializer struct {
typeInfo *vo.TypeInfo
}
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
jsonStrValue := input[InputKeyDeserialization] jsonStrValue := input[InputKeyDeserialization]

View File

@@ -24,6 +24,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -31,19 +32,9 @@ import (
func TestNewJsonDeserializer(t *testing.T) { func TestNewJsonDeserializer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Test with nil config
_, err := NewJsonDeserializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing OutputFields config
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "OutputFields is required")
// Test with missing output key in OutputFields // Test with missing output key in OutputFields
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{ _, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: vo.DataTypeString}, "testKey": {Type: vo.DataTypeString},
}, },
}) })
@@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) {
assert.Contains(t, err.Error(), "no output field specified in deserialization config") assert.Contains(t, err.Error(), "no output field specified in deserialization config")
// Test with valid config // Test with valid config
validConfig := &DeserializationConfig{ validConfig := &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString}, OutputKeyDeserialization: {Type: vo.DataTypeString},
}, },
} }
processor, err := NewJsonDeserializer(ctx, validConfig) processor, err := (&DeserializationConfig{}).Build(ctx, validConfig)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, processor) assert.NotNil(t, processor)
} }
@@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Base type test config // Base type test config
baseTypeConfig := &DeserializationConfig{ baseTypeConfig := &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeString}, OutputKeyDeserialization: {Type: vo.DataTypeString},
}, },
} }
// Object type test config // Object type test config
objectTypeConfig := &DeserializationConfig{ objectTypeConfig := &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeObject, Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{ Properties: map[string]*vo.TypeInfo{
"name": {Type: vo.DataTypeString, Required: true}, "name": {Type: vo.DataTypeString, Required: true},
@@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
} }
// Array type test config // Array type test config
arrayTypeConfig := &DeserializationConfig{ arrayTypeConfig := &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeArray, Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
}, },
@@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
} }
// Nested array object test config // Nested array object test config
nestedArrayConfig := &DeserializationConfig{ nestedArrayConfig := &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeArray, Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{ ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject, Type: vo.DataTypeObject,
@@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
// Test cases // Test cases
tests := []struct { tests := []struct {
name string name string
config *DeserializationConfig config *schema.NodeSchema
inputJSON string inputJSON string
expectedOutput any expectedOutput any
expectErr bool expectErr bool
@@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test integer deserialization", name: "Test integer deserialization",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `123`, inputJSON: `123`,
@@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test boolean deserialization", name: "Test boolean deserialization",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean}, OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
}, },
}, },
inputJSON: `true`, inputJSON: `true`,
@@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test type mismatch warning", name: "Test type mismatch warning",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `"not a number"`, inputJSON: `"not a number"`,
@@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test string to integer conversion", name: "Test string to integer conversion",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `"123"`, inputJSON: `"123"`,
@@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test float to integer conversion (integer part)", name: "Test float to integer conversion (integer part)",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `123.0`, inputJSON: `123.0`,
@@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test float to integer conversion (non-integer part)", name: "Test float to integer conversion (non-integer part)",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `123.5`, inputJSON: `123.5`,
@@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test boolean to integer conversion", name: "Test boolean to integer conversion",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `true`, inputJSON: `true`,
@@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1, expectWarnings: 1,
}, { }, {
name: "Test string to boolean conversion", name: "Test string to boolean conversion",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean}, OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
}, },
}, },
inputJSON: `"true"`, inputJSON: `"true"`,
@@ -252,10 +243,11 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectErr: false, expectErr: false,
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test string to integer conversion in nested object", name: "Test string to integer conversion in nested object",
config: &DeserializationConfig{ inputJSON: `{"age":"456"}`,
OutputFields: map[string]*vo.TypeInfo{ config: &schema.NodeSchema{
"output": { OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject, Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{ Properties: map[string]*vo.TypeInfo{
"age": {Type: vo.DataTypeInteger}, "age": {Type: vo.DataTypeInteger},
@@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
}, },
}, },
}, },
inputJSON: `{"age":"456"}`,
expectedOutput: map[string]any{"age": 456}, expectedOutput: map[string]any{"age": 456},
expectErr: false, expectErr: false,
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test string to integer conversion for array elements", name: "Test string to integer conversion for array elements",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeArray, Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
}, },
@@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0, expectWarnings: 0,
}, { }, {
name: "Test string with non-numeric characters to integer conversion", name: "Test string with non-numeric characters to integer conversion",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger}, OutputKeyDeserialization: {Type: vo.DataTypeInteger},
}, },
}, },
inputJSON: `"123abc"`, inputJSON: `"123abc"`,
@@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1, expectWarnings: 1,
}, { }, {
name: "Test type mismatch in nested object field", name: "Test type mismatch in nested object field",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeObject, Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{ Properties: map[string]*vo.TypeInfo{
"score": {Type: vo.DataTypeInteger}, "score": {Type: vo.DataTypeInteger},
@@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1, expectWarnings: 1,
}, { }, {
name: "Test partial conversion failure in array elements", name: "Test partial conversion failure in array elements",
config: &DeserializationConfig{ config: &schema.NodeSchema{
OutputFields: map[string]*vo.TypeInfo{ OutputTypes: map[string]*vo.TypeInfo{
"output": { OutputKeyDeserialization: {
Type: vo.DataTypeArray, Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
}, },
@@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
processor, err := NewJsonDeserializer(ctx, tt.config) processor, err := (&DeserializationConfig{}).Build(ctx, tt.config)
assert.NoError(t, err) assert.NoError(t, err)
ctxWithCache := ctxcache.Init(ctx) ctxWithCache := ctxcache.Init(ctx)
input := map[string]any{"input": tt.inputJSON} input := map[string]any{"input": tt.inputJSON}
result, err := processor.Invoke(ctxWithCache, input) result, err := processor.(*Deserializer).Invoke(ctxWithCache, input)
if tt.expectErr { if tt.expectErr {
assert.Error(t, err) assert.Error(t, err)

View File

@@ -20,7 +20,11 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -29,28 +33,57 @@ const (
OutputKeySerialization = "output" OutputKeySerialization = "output"
) )
// SerializationConfig is the Config type for NodeTypeJsonSerialization.
// Each Node Type should have its own designated Config type,
// which should implement NodeAdaptor and NodeBuilder.
// NOTE: we didn't define any fields for this type,
// because this node is simple, we doesn't need to extract any SPECIFIC piece of info
// from frontend Node. In other cases we would need to do it, such as LLM's model configs.
type SerializationConfig struct { type SerializationConfig struct {
InputTypes map[string]*vo.TypeInfo // you can define ANY number of fields here,
// as long as these fields are SERIALIZABLE and EXPORTED.
// to store specific info extracted from frontend node.
// e.g.
// - LLM model configs
// - conditional expressions
// - fixed input fields such as MaxBatchSize
} }
type JsonSerializer struct { // Adapt provides conversion from Node to NodeSchema.
config *SerializationConfig // NOTE: in this specific case, we don't need AdaptOption.
} func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) { Key: vo.NodeKey(n.ID),
if cfg == nil { Type: entity.NodeTypeJsonSerialization,
return nil, fmt.Errorf("config required") Name: n.Data.Meta.Title,
} Configs: s, // remember to set the Node's Config Type to NodeSchema as well
if cfg.InputTypes == nil {
return nil, fmt.Errorf("InputTypes is required for serialization")
} }
return &JsonSerializer{ // this sets input fields' type and mapping info
config: cfg, if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
}, nil return nil, err
}
// this set output fields' type info
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) { func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (
any, error) {
return &Serializer{}, nil
}
// Serializer is the actual node implementation.
type Serializer struct {
// here can holds ANY data required for node execution
}
// Invoke implements the InvokableNode interface.
func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Directly use the input map for serialization // Directly use the input map for serialization
if input == nil { if input == nil {
return nil, fmt.Errorf("input data for serialization cannot be nil") return nil, fmt.Errorf("input data for serialization cannot be nil")

View File

@@ -23,44 +23,34 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func TestNewJsonSerialize(t *testing.T) { func TestNewJsonSerialize(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// Test with nil config
_, err := NewJsonSerializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing InputTypes config // Test with missing InputTypes config
_, err = NewJsonSerializer(ctx, &SerializationConfig{}) s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
assert.Error(t, err)
assert.Contains(t, err.Error(), "InputTypes is required")
// Test with valid config
validConfig := &SerializationConfig{
InputTypes: map[string]*vo.TypeInfo{ InputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: "string"}, "testKey": {Type: "string"},
}, },
} })
processor, err := NewJsonSerializer(ctx, validConfig)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, processor) assert.NotNil(t, s)
} }
func TestJsonSerialize_Invoke(t *testing.T) { func TestJsonSerialize_Invoke(t *testing.T) {
ctx := context.Background() ctx := context.Background()
config := &SerializationConfig{
processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
InputTypes: map[string]*vo.TypeInfo{ InputTypes: map[string]*vo.TypeInfo{
"stringKey": {Type: "string"}, "stringKey": {Type: "string"},
"intKey": {Type: "integer"}, "intKey": {Type: "integer"},
"boolKey": {Type: "boolean"}, "boolKey": {Type: "boolean"},
"objKey": {Type: "object"}, "objKey": {Type: "object"},
}, },
} })
processor, err := NewJsonSerializer(ctx, config)
assert.NoError(t, err) assert.NoError(t, err)
// Test cases // Test cases
@@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := processor.Invoke(ctx, tt.input) result, err := processor.(*Serializer).Invoke(ctx, tt.input)
if tt.expectErr { if tt.expectErr {
assert.Error(t, err) assert.Error(t, err)

View File

@@ -0,0 +1,57 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
)
func convertParsingType(p string) (knowledge.ParseMode, error) {
switch p {
case "fast":
return knowledge.FastParseMode, nil
case "accurate":
return knowledge.AccurateParseMode, nil
default:
return "", fmt.Errorf("invalid parsingType: %s", p)
}
}
func convertChunkType(p string) (knowledge.ChunkType, error) {
switch p {
case "custom":
return knowledge.ChunkTypeCustom, nil
case "default":
return knowledge.ChunkTypeDefault, nil
default:
return "", fmt.Errorf("invalid ChunkType: %s", p)
}
}
func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
switch s {
case 0:
return knowledge.SearchTypeSemantic, nil
case 1:
return knowledge.SearchTypeHybrid, nil
case 20:
return knowledge.SearchTypeFullText, nil
default:
return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
}
}

View File

@@ -21,27 +21,45 @@ import (
"errors" "errors"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type DeleterConfig struct { type DeleterConfig struct{}
KnowledgeID int64
KnowledgeDeleter knowledge.KnowledgeOperator
}
type KnowledgeDeleter struct { func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *DeleterConfig ns := &schema.NodeSchema{
} Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeKnowledgeDeleter,
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) { Name: n.Data.Meta.Title,
if cfg.KnowledgeDeleter == nil { Configs: d,
return nil, errors.New("knowledge deleter is required")
} }
return &KnowledgeDeleter{
config: cfg, if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Deleter{
knowledgeDeleter: knowledge.GetKnowledgeOperator(),
}, nil }, nil
} }
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) { type Deleter struct {
knowledgeDeleter knowledge.KnowledgeOperator
}
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
documentID, ok := input["documentID"].(string) documentID, ok := input["documentID"].(string)
if !ok { if !ok {
return nil, errors.New("documentID is required and must be a string") return nil, errors.New("documentID is required and must be a string")
@@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma
DocumentID: documentID, DocumentID: documentID,
} }
response, err := k.config.KnowledgeDeleter.Delete(ctx, req) response, err := k.knowledgeDeleter.Delete(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -24,7 +24,14 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser" "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
) )
@@ -32,30 +39,88 @@ type IndexerConfig struct {
KnowledgeID int64 KnowledgeID int64
ParsingStrategy *knowledge.ParsingStrategy ParsingStrategy *knowledge.ParsingStrategy
ChunkingStrategy *knowledge.ChunkingStrategy ChunkingStrategy *knowledge.ChunkingStrategy
KnowledgeIndexer knowledge.KnowledgeOperator
} }
type KnowledgeIndexer struct { func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *IndexerConfig ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeKnowledgeIndexer,
Name: n.Data.Meta.Title,
Configs: i,
}
inputs := n.Data.Inputs
datasetListInfoParam := inputs.DatasetParam[0]
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
if len(datasetIDs) == 0 {
return nil, fmt.Errorf("dataset ids is required")
}
knowledgeID, err := cast.ToInt64E(datasetIDs[0])
if err != nil {
return nil, err
}
i.KnowledgeID = knowledgeID
ps := inputs.StrategyParam.ParsingStrategy
parseMode, err := convertParsingType(ps.ParsingType)
if err != nil {
return nil, err
}
parsingStrategy := &knowledge.ParsingStrategy{
ParseMode: parseMode,
ImageOCR: ps.ImageOcr,
ExtractImage: ps.ImageExtraction,
ExtractTable: ps.TableExtraction,
}
i.ParsingStrategy = parsingStrategy
cs := inputs.StrategyParam.ChunkStrategy
chunkType, err := convertChunkType(cs.ChunkType)
if err != nil {
return nil, err
}
chunkingStrategy := &knowledge.ChunkingStrategy{
ChunkType: chunkType,
Separator: cs.Separator,
ChunkSize: cs.MaxToken,
Overlap: int64(cs.Overlap * float64(cs.MaxToken)),
}
i.ChunkingStrategy = chunkingStrategy
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) { func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if cfg.ParsingStrategy == nil { if i.ParsingStrategy == nil {
return nil, errors.New("parsing strategy is required") return nil, errors.New("parsing strategy is required")
} }
if cfg.ChunkingStrategy == nil { if i.ChunkingStrategy == nil {
return nil, errors.New("chunking strategy is required") return nil, errors.New("chunking strategy is required")
} }
if cfg.KnowledgeIndexer == nil { return &Indexer{
return nil, errors.New("knowledge indexer is required") knowledgeID: i.KnowledgeID,
} parsingStrategy: i.ParsingStrategy,
return &KnowledgeIndexer{ chunkingStrategy: i.ChunkingStrategy,
config: cfg, knowledgeIndexer: knowledge.GetKnowledgeOperator(),
}, nil }, nil
} }
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) { type Indexer struct {
knowledgeID int64
parsingStrategy *knowledge.ParsingStrategy
chunkingStrategy *knowledge.ChunkingStrategy
knowledgeIndexer knowledge.KnowledgeOperator
}
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fileURL, ok := input["knowledge"].(string) fileURL, ok := input["knowledge"].(string)
if !ok { if !ok {
return nil, errors.New("knowledge is required") return nil, errors.New("knowledge is required")
@@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map
} }
req := &knowledge.CreateDocumentRequest{ req := &knowledge.CreateDocumentRequest{
KnowledgeID: k.config.KnowledgeID, KnowledgeID: k.knowledgeID,
ParsingStrategy: k.config.ParsingStrategy, ParsingStrategy: k.parsingStrategy,
ChunkingStrategy: k.config.ChunkingStrategy, ChunkingStrategy: k.chunkingStrategy,
FileURL: fileURL, FileURL: fileURL,
FileName: fileName, FileName: fileName,
FileExtension: ext, FileExtension: ext,
} }
response, err := k.config.KnowledgeIndexer.Store(ctx, req) response, err := k.knowledgeIndexer.Store(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,7 +20,14 @@ import (
"context" "context"
"errors" "errors"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
) )
@@ -29,37 +36,136 @@ const outputList = "outputList"
type RetrieveConfig struct { type RetrieveConfig struct {
KnowledgeIDs []int64 KnowledgeIDs []int64
RetrievalStrategy *knowledge.RetrievalStrategy RetrievalStrategy *knowledge.RetrievalStrategy
Retriever knowledge.KnowledgeOperator
} }
type KnowledgeRetrieve struct { func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *RetrieveConfig ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeKnowledgeRetriever,
Name: n.Data.Meta.Title,
Configs: r,
}
inputs := n.Data.Inputs
datasetListInfoParam := inputs.DatasetParam[0]
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
knowledgeIDs := make([]int64, 0, len(datasetIDs))
for _, id := range datasetIDs {
k, err := cast.ToInt64E(id)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, k)
}
r.KnowledgeIDs = knowledgeIDs
retrievalStrategy := &knowledge.RetrievalStrategy{}
var getDesignatedParamContent = func(name string) (any, bool) {
for _, param := range inputs.DatasetParam {
if param.Name == name {
return param.Input.Value.Content, true
}
}
return nil, false
}
if content, ok := getDesignatedParamContent("topK"); ok {
topK, err := cast.ToInt64E(content)
if err != nil {
return nil, err
}
retrievalStrategy.TopK = &topK
}
if content, ok := getDesignatedParamContent("useRerank"); ok {
useRerank, err := cast.ToBoolE(content)
if err != nil {
return nil, err
}
retrievalStrategy.EnableRerank = useRerank
}
if content, ok := getDesignatedParamContent("useRewrite"); ok {
useRewrite, err := cast.ToBoolE(content)
if err != nil {
return nil, err
}
retrievalStrategy.EnableQueryRewrite = useRewrite
}
if content, ok := getDesignatedParamContent("isPersonalOnly"); ok {
isPersonalOnly, err := cast.ToBoolE(content)
if err != nil {
return nil, err
}
retrievalStrategy.IsPersonalOnly = isPersonalOnly
}
if content, ok := getDesignatedParamContent("useNl2sql"); ok {
useNl2sql, err := cast.ToBoolE(content)
if err != nil {
return nil, err
}
retrievalStrategy.EnableNL2SQL = useNl2sql
}
if content, ok := getDesignatedParamContent("minScore"); ok {
minScore, err := cast.ToFloat64E(content)
if err != nil {
return nil, err
}
retrievalStrategy.MinScore = &minScore
}
if content, ok := getDesignatedParamContent("strategy"); ok {
strategy, err := cast.ToInt64E(content)
if err != nil {
return nil, err
}
searchType, err := convertRetrievalSearchType(strategy)
if err != nil {
return nil, err
}
retrievalStrategy.SearchType = searchType
}
r.RetrievalStrategy = retrievalStrategy
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) { func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if cfg == nil { if len(r.KnowledgeIDs) == 0 {
return nil, errors.New("cfg is required") return nil, errors.New("knowledge ids are required")
} }
if cfg.Retriever == nil { if r.RetrievalStrategy == nil {
return nil, errors.New("retriever is required")
}
if len(cfg.KnowledgeIDs) == 0 {
return nil, errors.New("knowledgeI ids is required")
}
if cfg.RetrievalStrategy == nil {
return nil, errors.New("retrieval strategy is required") return nil, errors.New("retrieval strategy is required")
} }
return &KnowledgeRetrieve{ return &Retrieve{
config: cfg, knowledgeIDs: r.KnowledgeIDs,
retrievalStrategy: r.RetrievalStrategy,
retriever: knowledge.GetKnowledgeOperator(),
}, nil }, nil
} }
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) { type Retrieve struct {
knowledgeIDs []int64
retrievalStrategy *knowledge.RetrievalStrategy
retriever knowledge.KnowledgeOperator
}
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
query, ok := input["Query"].(string) query, ok := input["Query"].(string)
if !ok { if !ok {
return nil, errors.New("capital query key is required") return nil, errors.New("capital query key is required")
@@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any)
req := &knowledge.RetrieveRequest{ req := &knowledge.RetrieveRequest{
Query: query, Query: query,
KnowledgeIDs: kr.config.KnowledgeIDs, KnowledgeIDs: kr.knowledgeIDs,
RetrievalStrategy: kr.config.RetrievalStrategy, RetrievalStrategy: kr.retrievalStrategy,
} }
response, err := kr.config.Retriever.Retrieve(ctx, req) response, err := kr.retriever.Retrieve(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -34,13 +34,20 @@ import (
callbacks2 "github.com/cloudwego/eino/utils/callbacks" callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -143,126 +150,408 @@ const (
) )
type RetrievalStrategy struct { type RetrievalStrategy struct {
RetrievalStrategy *crossknowledge.RetrievalStrategy RetrievalStrategy *knowledge.RetrievalStrategy
NoReCallReplyMode NoReCallReplyMode NoReCallReplyMode NoReCallReplyMode
NoReCallReplyCustomizePrompt string NoReCallReplyCustomizePrompt string
} }
type KnowledgeRecallConfig struct { type KnowledgeRecallConfig struct {
ChatModel model.BaseChatModel ChatModel model.BaseChatModel
Retriever crossknowledge.KnowledgeOperator Retriever knowledge.KnowledgeOperator
RetrievalStrategy *RetrievalStrategy RetrievalStrategy *RetrievalStrategy
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
} }
type Config struct { type Config struct {
ChatModel ModelWithInfo SystemPrompt string
Tools []tool.BaseTool UserPrompt string
SystemPrompt string OutputFormat Format
UserPrompt string LLMParams *crossmodel.LLMParams
OutputFormat Format FCParam *vo.FCParam
InputFields map[string]*vo.TypeInfo BackupLLMParams *crossmodel.LLMParams
OutputFields map[string]*vo.TypeInfo
ToolsReturnDirectly map[string]bool
KnowledgeRecallConfig *KnowledgeRecallConfig
FullSources map[string]*nodes.SourceInfo
} }
type LLM struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
r compose.Runnable[map[string]any, map[string]any] ns := &schema2.NodeSchema{
outputFormat Format Key: vo.NodeKey(n.ID),
outputFields map[string]*vo.TypeInfo Type: entity.NodeTypeLLM,
canStream bool Name: n.Data.Meta.Title,
requireCheckpoint bool Configs: c,
fullSources map[string]*nodes.SourceInfo }
}
const ( param := n.Data.Inputs.LLMParam
rawOutputKey = "llm_raw_output_%s" if param == nil {
warningKey = "llm_warning_%s" return nil, fmt.Errorf("llm node's llmParam is nil")
) }
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) { bs, _ := sonic.Marshal(param)
data = nodes.ExtractJSONString(data) llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
var result map[string]any return nil, err
}
err := sonic.UnmarshalString(data, &result) convertedLLMParam, err := llmParamsToLLMParam(llmParam)
if err != nil { if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err return nil, err
} }
r, ws, err := nodes.ConvertInputs(ctx, result, schema_) c.LLMParams = convertedLLMParam
if err != nil { c.SystemPrompt = convertedLLMParam.SystemPrompt
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err) c.UserPrompt = convertedLLMParam.Prompt
var resFormat Format
switch convertedLLMParam.ResponseFormat {
case crossmodel.ResponseFormatText:
resFormat = FormatText
case crossmodel.ResponseFormatMarkdown:
resFormat = FormatMarkdown
case crossmodel.ResponseFormatJSON:
resFormat = FormatJSON
default:
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
} }
if ws != nil { c.OutputFormat = resFormat
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
} }
return r, nil if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
if resFormat == FormatJSON {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
}
}
}
if resFormat == FormatJSON {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: false,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: true,
}
}
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
c.FCParam = n.Data.Inputs.FCParam
}
if se := n.Data.Inputs.SettingOnError; se != nil {
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
var backupLLMParam vo.SimpleLLMParam
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
return nil, err
}
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
if err != nil {
return nil, err
}
c.BackupLLMParams = backupModel
}
}
return ns, nil
}
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
for _, param := range params {
switch param.Name {
case "temperature":
strVal := param.Input.Value.Content.(string)
floatVal, err := strconv.ParseFloat(strVal, 64)
if err != nil {
return nil, err
}
p.Temperature = &floatVal
case "maxTokens":
strVal := param.Input.Value.Content.(string)
intVal, err := strconv.Atoi(strVal)
if err != nil {
return nil, err
}
p.MaxTokens = intVal
case "responseFormat":
strVal := param.Input.Value.Content.(string)
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return nil, err
}
p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
case "modleName":
strVal := param.Input.Value.Content.(string)
p.ModelName = strVal
case "modelType":
strVal := param.Input.Value.Content.(string)
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return nil, err
}
p.ModelType = int64Val
case "prompt":
strVal := param.Input.Value.Content.(string)
p.Prompt = strVal
case "enableChatHistory":
boolVar := param.Input.Value.Content.(bool)
p.EnableChatHistory = boolVar
case "systemPrompt":
strVal := param.Input.Value.Content.(string)
p.SystemPrompt = strVal
case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
// do nothing
case "topP":
strVal := param.Input.Value.Content.(string)
floatVar, err := strconv.ParseFloat(strVal, 64)
if err != nil {
return nil, err
}
p.TopP = &floatVar
default:
return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
}
}
return p, nil
}
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
p.ModelName = params.ModelName
p.ModelType = params.ModelType
p.Temperature = &params.Temperature
p.MaxTokens = params.MaxTokens
p.TopP = &params.TopP
p.ResponseFormat = params.ResponseFormat
p.SystemPrompt = params.SystemPrompt
return p, nil
} }
func getReasoningContent(message *schema.Message) string { func getReasoningContent(message *schema.Message) string {
return message.ReasoningContent return message.ReasoningContent
} }
type Options struct { func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
nested []nodes.NestedWorkflowOption var (
toolWorkflowSW *schema.StreamWriter[*entity.Message] err error
} chatModel, fallbackM model.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo ModelWithInfo
tools []tool.BaseTool
toolsReturnDirectly map[string]bool
knowledgeRecallConfig *KnowledgeRecallConfig
)
type Option func(o *Options) chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option { return nil, err
return func(o *Options) {
o.nested = append(o.nested, nested...)
} }
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option { exceptionConf := ns.ExceptionConfigs
return func(o *Options) { if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
o.toolWorkflowSW = sw backupModelParams := c.BackupLLMParams
if backupModelParams != nil {
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
} }
}
type llmState = map[string]any if fallbackM == nil {
modelWithInfo = NewModel(chatModel, info)
} else {
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
const agentModelName = "agent_model" fcParams := c.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
tools = append(tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if toolsReturnDirectly == nil {
toolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
toolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
PluginEntity: plugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
tools = append(tools, inInvokableTools...)
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := knowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
knowledgeRecallConfig = &KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := toRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
RetrievalStrategy: &knowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
&knowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
}
}
func New(ctx context.Context, cfg *Config) (*LLM, error) {
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) { g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
return llmState{} return llmState{}
})) }))
var ( var hasReasoning bool
hasReasoning bool
canStream = true
)
format := cfg.OutputFormat format := c.OutputFormat
if format == FormatJSON { if format == FormatJSON {
if len(cfg.OutputFields) == 1 { if len(ns.OutputTypes) == 1 {
for _, v := range cfg.OutputFields { for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString { if v.Type == vo.DataTypeString {
format = FormatText format = FormatText
break break
} }
} }
} else if len(cfg.OutputFields) == 2 { } else if len(ns.OutputTypes) == 2 {
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok { if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range cfg.OutputFields { for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString { if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
format = FormatText format = FormatText
break break
@@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
} }
} }
userPrompt := cfg.UserPrompt userPrompt := c.UserPrompt
switch format { switch format {
case FormatJSON: case FormatJSON:
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil) jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
case FormatText: case FormatText:
} }
if cfg.KnowledgeRecallConfig != nil { if knowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig) err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt) userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
inputs := maps.Clone(cfg.InputFields) inputs := maps.Clone(ns.InputTypes)
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{ inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString, Type: vo.DataTypeString,
} }
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil) sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey}) up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
template := newPrompts(sp, up, cfg.ChatModel) template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template, _ = g.AddChatTemplateNode(templateNodeKey, template,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) { compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
@@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey) _ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else { } else {
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil) sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil) up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
template := newPrompts(sp, up, cfg.ChatModel) template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template) _ = g.AddChatTemplateNode(templateNodeKey, template)
_ = g.AddEdge(compose.START, templateNodeKey) _ = g.AddEdge(compose.START, templateNodeKey)
} }
if len(cfg.Tools) > 0 { if len(tools) > 0 {
m, ok := cfg.ChatModel.(model.ToolCallingChatModel) m, ok := modelWithInfo.(model.ToolCallingChatModel)
if !ok { if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools") return nil, errors.New("requires a ToolCallingChatModel to use with tools")
} }
reactConfig := react.AgentConfig{ reactConfig := react.AgentConfig{
ToolCallingModel: m, ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools}, ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName, ModelNodeName: agentModelName,
} }
if len(cfg.ToolsReturnDirectly) > 0 { if len(toolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly)) reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
for k := range cfg.ToolsReturnDirectly { for k := range toolsReturnDirectly {
reactConfig.ToolReturnDirectly[k] = struct{}{} reactConfig.ToolReturnDirectly[k] = struct{}{}
} }
} }
@@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent")) opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...) _ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else { } else {
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel) _ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
} }
_ = g.AddEdge(templateNodeKey, llmNodeKey) _ = g.AddEdge(templateNodeKey, llmNodeKey)
if format == FormatJSON { if format == FormatJSON {
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) { iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
return jsonParse(ctx, msg.Content, cfg.OutputFields) return jsonParse(ctx, msg.Content, ns.OutputTypes)
} }
convertNode := compose.InvokableLambda(iConvert) convertNode := compose.InvokableLambda(iConvert)
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode) _ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
canStream = false
} else { } else {
var outputKey string var outputKey string
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 { if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
panic("impossible") panic("impossible")
} }
for k, v := range cfg.OutputFields { for k, v := range ns.OutputTypes {
if v.Type != vo.DataTypeString { if v.Type != vo.DataTypeString {
panic("impossible") panic("impossible")
} }
@@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(outputConvertNodeKey, compose.END) _ = g.AddEdge(outputConvertNodeKey, compose.END)
requireCheckpoint := false requireCheckpoint := false
if len(cfg.Tools) > 0 { if len(tools) > 0 {
requireCheckpoint = true requireCheckpoint = true
} }
var opts []compose.GraphCompileOption var compileOpts []compose.GraphCompileOption
if requireCheckpoint { if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository())) compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
} }
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph")) compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
r, err := g.Compile(ctx, opts...) r, err := g.Compile(ctx, compileOpts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
llm := &LLM{ llm := &LLM{
r: r, r: r,
outputFormat: format, outputFormat: format,
canStream: canStream,
requireCheckpoint: requireCheckpoint, requireCheckpoint: requireCheckpoint,
fullSources: cfg.FullSources, fullSources: ns.FullSources,
} }
return llm, nil return llm, nil
} }
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) { func (c *Config) RequireCheckpoint() bool {
if c.FCParam != nil {
if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil {
return true
}
}
return false
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
}
if len(path) != 1 {
return schema2.FieldNotStream, nil
}
outputs := ns.OutputTypes
if len(outputs) != 1 && len(outputs) != 2 {
return schema2.FieldNotStream, nil
}
var outputKey string
for key, output := range outputs {
if output.Type != vo.DataTypeString {
return schema2.FieldNotStream, nil
}
if key != ReasoningOutputKey {
if len(outputKey) > 0 {
return schema2.FieldNotStream, nil
}
outputKey = key
}
}
field := path[0]
if field == ReasoningOutputKey || field == outputKey {
return schema2.FieldIsStream, nil
}
return schema2.FieldNotStream, nil
}
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
switch s {
case 0:
return knowledge.SearchTypeSemantic, nil
case 1:
return knowledge.SearchTypeHybrid, nil
case 20:
return knowledge.SearchTypeFullText, nil
default:
return "", fmt.Errorf("invalid retrieval search type %v", s)
}
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
requireCheckpoint bool
fullSources map[string]*schema2.SourceInfo
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
return r, nil
}
type llmOptions struct {
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
o.toolWorkflowSW = sw
})
}
type llmState = map[string]any
const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx) c := execute.GetExeCtx(ctx)
if c != nil { if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent resumingEvent = c.NodeCtx.ResumingEvent
@@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID)) composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
} }
llmOpts := &Options{} options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
for _, opt := range opts {
opt(llmOpts)
}
nestedOpts := &nodes.NestedWorkflowOptions{} composeOpts = append(composeOpts, options.GetOptsForNested()...)
for _, opt := range llmOpts.nested {
opt(nestedOpts)
}
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
if resumingEvent != nil { if resumingEvent != nil {
var ( var (
@@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg)))) composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
} }
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil { if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe() toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt) composeOpts = append(composeOpts, toolMsgOpt)
@@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
return compose.NewInterruptAndRerunErr(ie) return compose.NewInterruptAndRerunErr(ie)
} }
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) { func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...) composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
return out, nil return out, nil
} }
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) { func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...) composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) { _ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
modelPredictionIDs := strings.Split(input.Content, ",") modelPredictionIDs := strings.Split(input.Content, ",")
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) { selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
return strconv.Itoa(int(e.ID)), e.ID return strconv.Itoa(int(e.ID)), e.ID
}) })
recallKnowledgeIDs := make([]int64, 0) recallKnowledgeIDs := make([]int64, 0)
@@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
return make(map[string]any), nil return make(map[string]any), nil
} }
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{ docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
Query: userPrompt, Query: userPrompt,
KnowledgeIDs: recallKnowledgeIDs, KnowledgeIDs: recallKnowledgeIDs,
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy, RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,

View File

@@ -26,6 +26,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
} }
func (pl *promptTpl) render(ctx context.Context, vs map[string]any, func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*nodes.SourceInfo, sources map[string]*schema2.SourceInfo,
supportedModals map[modelmgr.Modal]bool, supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) { ) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 { if !pl.hasMultiModal || len(supportedModals) == 0 {
@@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
} }
sk := fmt.Sprintf(sourceKey, nodeKey) sk := fmt.Sprintf(sourceKey, nodeKey)
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk) sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
if !ok { if !ok {
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk) return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
} }

View File

@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package loop package _break
import ( import (
"context" "context"
@@ -22,21 +22,36 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type Break struct { type Break struct {
parentIntermediateStore variable.Store parentIntermediateStore variable.Store
} }
func NewBreak(_ context.Context, store variable.Store) (*Break, error) { type Config struct{}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
return &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeBreak,
Name: n.Data.Meta.Title,
Configs: c,
}, nil
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Break{ return &Break{
parentIntermediateStore: store, parentIntermediateStore: &nodes.ParentIntermediateStore{},
}, nil }, nil
} }
const BreakKey = "$break" const BreakKey = "$break"
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) { func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) {
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true) err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package _continue
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Continue struct{}
type Config struct{}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
return &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeContinue,
Name: n.Data.Meta.Title,
Configs: c,
}, nil
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Continue{}, nil
}
func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
return in, nil
}

View File

@@ -27,53 +27,150 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
) )
type Loop struct { type Loop struct {
config *Config
outputs map[string]*vo.FieldSource outputs map[string]*vo.FieldSource
outputVars map[string]string outputVars map[string]string
inner compose.Runnable[map[string]any, map[string]any]
nodeKey vo.NodeKey
loopType Type
inputArrays []string
intermediateVars map[string]*vo.TypeInfo
} }
type Config struct { type Config struct {
LoopNodeKey vo.NodeKey
LoopType Type LoopType Type
InputArrays []string InputArrays []string
IntermediateVars map[string]*vo.TypeInfo IntermediateVars map[string]*vo.TypeInfo
Outputs []*vo.FieldInfo
Inner compose.Runnable[map[string]any, map[string]any]
} }
type Type string func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
const ( return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
ByArray Type = "by_array"
ByIteration Type = "by_iteration"
Infinite Type = "infinite"
)
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
if conf == nil {
return nil, errors.New("config is nil")
} }
if conf.LoopType == ByArray { ns := &schema.NodeSchema{
if len(conf.InputArrays) == 0 { Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLoop,
Name: n.Data.Meta.Title,
Configs: c,
}
loopType, err := toLoopType(n.Data.Inputs.LoopType)
if err != nil {
return nil, err
}
c.LoopType = loopType
intermediateVars := make(map[string]*vo.TypeInfo)
for _, param := range n.Data.Inputs.VariableParameters {
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return nil, err
}
intermediateVars[param.Name] = tInfo
ns.SetInputType(param.Name, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
c.IntermediateVars = intermediateVars
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
for _, fieldInfo := range ns.OutputSources {
if fieldInfo.Source.Ref != nil {
if len(fieldInfo.Source.Ref.FromPath) == 1 {
if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
}
}
}
}
loopCount := n.Data.Inputs.LoopCount
if loopCount != nil {
typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
if err != nil {
return nil, err
}
ns.SetInputType(Count, typeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
continue
}
c.InputArrays = append(c.InputArrays, key)
}
return ns, nil
}
func toLoopType(l vo.LoopType) (Type, error) {
switch l {
case vo.LoopTypeArray:
return ByArray, nil
case vo.LoopTypeCount:
return ByIteration, nil
case vo.LoopTypeInfinite:
return Infinite, nil
default:
return "", fmt.Errorf("unsupported loop type: %s", l)
}
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
if c.LoopType == ByArray {
if len(c.InputArrays) == 0 {
return nil, errors.New("input arrays is empty when loop type is ByArray") return nil, errors.New("input arrays is empty when loop type is ByArray")
} }
} }
loop := &Loop{ options := schema.GetBuildOptions(opts...)
config: conf, if options.Inner == nil {
outputs: make(map[string]*vo.FieldSource), return nil, errors.New("inner workflow is required for Loop Node")
outputVars: make(map[string]string),
} }
for _, info := range conf.Outputs { loop := &Loop{
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
inputArrays: c.InputArrays,
nodeKey: ns.Key,
intermediateVars: c.IntermediateVars,
inner: options.Inner,
loopType: c.LoopType,
}
for _, info := range ns.OutputSources {
if len(info.Path) != 1 { if len(info.Path) != 1 {
return nil, fmt.Errorf("invalid output path: %s", info.Path) return nil, fmt.Errorf("invalid output path: %s", info.Path)
} }
@@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath) return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
} }
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok { if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath) return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
} }
@@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return loop, nil return loop, nil
} }
type Type string
const (
ByArray Type = "by_array"
ByIteration Type = "by_iteration"
Infinite Type = "infinite"
)
const ( const (
Count = "loopCount" Count = "loopCount"
) )
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) { func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) {
maxIter, err := l.getMaxIter(in) maxIter, err := l.getMaxIter(in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
arrays := make(map[string][]any, len(l.config.InputArrays)) arrays := make(map[string][]any, len(l.inputArrays))
for _, arrayKey := range l.config.InputArrays { for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok { if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey) return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
arrays[arrayKey] = a.([]any) arrays[arrayKey] = a.([]any)
} }
options := &nodes.NestedWorkflowOptions{} options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
for _, opt := range opts {
opt(options)
}
var ( var (
existingCState *nodes.NestedWorkflowState existingCState *nodes.NestedWorkflowState
@@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
) )
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error { err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error var e error
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey) existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
if e != nil { if e != nil {
return e return e
} }
@@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
for k := range existingCState.IntermediateVars { for k := range existingCState.IntermediateVars {
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k]) intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
} }
intermediateVars[BreakKey] = &hasBreak intermediateVars[_break.BreakKey] = &hasBreak
} else { } else {
output = make(map[string]any, len(l.outputs)) output = make(map[string]any, len(l.outputs))
for k := range l.outputs { for k := range l.outputs {
output[k] = make([]any, 0) output[k] = make([]any, 0)
} }
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars)) intermediateVars = make(map[string]*any, len(l.intermediateVars))
for varKey := range l.config.IntermediateVars { for varKey := range l.intermediateVars {
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey}) v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
if !ok { if !ok {
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey) return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
@@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
intermediateVars[varKey] = &v intermediateVars[varKey] = &v
} }
intermediateVars[BreakKey] = &hasBreak intermediateVars[_break.BreakKey] = &hasBreak
} }
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars) ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)
getIthInput := func(i int) (map[string]any, map[string]any, error) { getIthInput := func(i int) (map[string]any, map[string]any, error) {
input := make(map[string]any) input := make(map[string]any)
@@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
input[k] = v input[k] = v
} }
input[string(l.config.LoopNodeKey)+"#index"] = int64(i) input[string(l.nodeKey)+"#index"] = int64(i)
items := make(map[string]any) items := make(map[string]any)
for arrayKey := range arrays { for arrayKey := range arrays {
ele := arrays[arrayKey][i] ele := arrays[arrayKey][i]
items[arrayKey] = ele items[arrayKey] = ele
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey currentKey := string(l.nodeKey) + "#" + arrayKey
// Recursively expand map[string]any elements // Recursively expand map[string]any elements
var expand func(prefix string, val interface{}) var expand func(prefix string, val interface{})
@@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
} }
} }
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...) taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
if err != nil { if err != nil {
info, ok := compose.ExtractInterruptInfo(err) info, ok := compose.ExtractInterruptInfo(err)
if !ok { if !ok {
@@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{ iEvent := &entity.InterruptEvent{
NodeKey: l.config.LoopNodeKey, NodeKey: l.nodeKey,
NodeType: entity.NodeTypeLoop, NodeType: entity.NodeTypeLoop,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
} }
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil { if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
return e return e
} }
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent) return setter.SetInterruptEvent(l.nodeKey, iEvent)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Println("save interruptEvent in state within loop: ", iEvent)
fmt.Println("save composite info in state within loop: ", compState)
return nil, compose.InterruptAndRerun return nil, compose.InterruptAndRerun
} else { } else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error { err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState) return setter.SaveNestedWorkflowState(l.nodeKey, compState)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
} }
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 { if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo) panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
panic("impossible")
} }
for outputVarKey, intermediateVarKey := range l.outputVars { for outputVarKey, intermediateVarKey := range l.outputVars {
@@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
func (l *Loop) getMaxIter(in map[string]any) (int, error) { func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter := math.MaxInt maxIter := math.MaxInt
switch l.config.LoopType { switch l.loopType {
case ByArray: case ByArray:
for _, arrayKey := range l.config.InputArrays { for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey}) a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok { if !ok {
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey) return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter = int(iter.(int64)) maxIter = int(iter.(int64))
case Infinite: case Infinite:
default: default:
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType) return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
} }
return maxIter, nil return maxIter, nil
@@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any {
} }
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) { func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(l.config.InputArrays)) trimmed := make(map[string]any, len(l.inputArrays))
for _, arrayKey := range l.config.InputArrays { for _, arrayKey := range l.inputArrays {
if v, ok := in[arrayKey]; ok { if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v trimmed[arrayKey] = v
} }

View File

@@ -1,90 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type NestedWorkflowOptions struct {
optsForNested []compose.Option
toResumeIndexes map[int]compose.StateModifier
optsForIndexed map[int][]compose.Option
}
type NestedWorkflowOption func(*NestedWorkflowOptions)
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
o.optsForNested = append(o.optsForNested, opts...)
}
}
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
return c.optsForNested
}
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.toResumeIndexes == nil {
o.toResumeIndexes = map[int]compose.StateModifier{}
}
o.toResumeIndexes[i] = m
}
}
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
return c.toResumeIndexes
}
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
return func(o *NestedWorkflowOptions) {
if o.optsForIndexed == nil {
o.optsForIndexed = map[int][]compose.Option{}
}
o.optsForIndexed[index] = opts
}
}
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
if c.optsForIndexed == nil {
return nil
}
return c.optsForIndexed[index]
}
type NestedWorkflowState struct {
Index2Done map[int]bool `json:"index_2_done,omitempty"`
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
FullOutput map[string]any `json:"full_output,omitempty"`
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
func (c *NestedWorkflowState) String() string {
s, _ := sonic.MarshalIndent(c, "", " ")
return string(s)
}
type NestedWorkflowAware interface {
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
InterruptEventStore
}

View File

@@ -0,0 +1,194 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
einoschema "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// InvokableNode is a basic workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It does not accept any options.
// Most nodes implement this, such as NodeTypePlugin.
type InvokableNode interface {
Invoke(ctx context.Context, input map[string]any) (
output map[string]any, err error)
}
// InvokableNodeWOpt is a workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this.
type InvokableNodeWOpt interface {
Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) (
map[string]any, error)
}
// StreamableNode is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It does not accept and options
// Currently NO Node implement this.
// A potential example would be streamable plugin for NodeTypePlugin.
type StreamableNode interface {
Stream(ctx context.Context, in map[string]any) (
*einoschema.StreamReader[map[string]any], error)
}
// StreamableNodeWOpt is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM implement this.
type StreamableNodeWOpt interface {
Stream(ctx context.Context, in map[string]any, opts ...NodeOption) (
*einoschema.StreamReader[map[string]any], error)
}
// CollectableNode is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It does not accept and options
// Currently NO Node implement this.
// A potential example would be a new condition node that makes decisions
// based on streaming input.
type CollectableNode interface {
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
map[string]any, error)
}
// CollectableNodeWOpt is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It accepts NodeOption.
// Currently NO Node implement this.
// A potential example would be a new batch node that accepts streaming input,
// process them, and finally returns non-stream aggregation of results.
type CollectableNodeWOpt interface {
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
map[string]any, error)
}
// TransformableNode is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It does not accept and options
// e.g.
// NodeTypeVariableAggregator implements TransformableNode.
type TransformableNode interface {
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
*einoschema.StreamReader[map[string]any], error)
}
// TransformableNodeWOpt is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It accepts NodeOption.
// Currently NO Node implement this.
// A potential example would be an audio processing node that
// transforms input audio clips, but within the node is a graph
// composed by Eino, and the audio processing node needs to carry
// options for this inner graph.
type TransformableNodeWOpt interface {
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
*einoschema.StreamReader[map[string]any], error)
}
// CallbackInputConverted converts node input to a form better suited for UI.
// The converted input will be displayed on canvas when test run,
// and will be returned when querying the node's input through OpenAPI.
type CallbackInputConverted interface {
ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error)
}
// CallbackOutputConverted converts node input to a form better suited for UI.
// The converted output will be displayed on canvas when test run,
// and will be returned when querying the node's output through OpenAPI.
type CallbackOutputConverted interface {
ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error)
}
type Initializer interface {
Init(ctx context.Context) (context.Context, error)
}
type AdaptOptions struct {
Canvas *vo.Canvas
}
type AdaptOption func(*AdaptOptions)
func WithCanvas(canvas *vo.Canvas) AdaptOption {
return func(opts *AdaptOptions) {
opts.Canvas = canvas
}
}
func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions {
options := &AdaptOptions{}
for _, opt := range opts {
opt(options)
}
return options
}
// NodeAdaptor provides conversion from frontend Node to backend NodeSchema.
type NodeAdaptor interface {
Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) (
*schema.NodeSchema, error)
}
// BranchAdaptor provides validation and conversion from frontend port to backend port.
type BranchAdaptor interface {
ExpectPorts(ctx context.Context, n *vo.Node) []string
}
var (
nodeAdaptors = map[entity.NodeType]func() NodeAdaptor{}
branchAdaptors = map[entity.NodeType]func() BranchAdaptor{}
)
func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) {
nodeAdaptors[et] = f
}
func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) {
na, ok := nodeAdaptors[et]
if !ok {
panic(fmt.Sprintf("node type %s not registered", et))
}
return na(), ok
}
func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) {
branchAdaptors[et] = f
}
func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) {
na, ok := branchAdaptors[et]
if !ok {
return nil, false
}
return na(), ok
}
type StreamGenerator interface {
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
}

View File

@@ -0,0 +1,170 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type NodeOptions struct {
Nested *NestedWorkflowOptions
}
type NestedWorkflowOptions struct {
optsForNested []compose.Option
toResumeIndexes map[int]compose.StateModifier
optsForIndexed map[int][]compose.Option
}
type NodeOption struct {
apply func(opts *NodeOptions)
implSpecificOptFn any
}
type NestedWorkflowOption func(*NestedWorkflowOptions)
func WithOptsForNested(opts ...compose.Option) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
options.Nested.optsForNested = append(options.Nested.optsForNested, opts...)
},
}
}
func (c *NodeOptions) GetOptsForNested() []compose.Option {
if c.Nested == nil {
return nil
}
return c.Nested.optsForNested
}
func WithResumeIndex(i int, m compose.StateModifier) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
if options.Nested.toResumeIndexes == nil {
options.Nested.toResumeIndexes = map[int]compose.StateModifier{}
}
options.Nested.toResumeIndexes[i] = m
},
}
}
func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier {
if c.Nested == nil {
return nil
}
return c.Nested.toResumeIndexes
}
func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption {
return NodeOption{
apply: func(options *NodeOptions) {
if options.Nested == nil {
options.Nested = &NestedWorkflowOptions{}
}
if options.Nested.optsForIndexed == nil {
options.Nested.optsForIndexed = map[int][]compose.Option{}
}
options.Nested.optsForIndexed[index] = opts
},
}
}
func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option {
if c.Nested == nil {
return nil
}
return c.Nested.optsForIndexed[index]
}
// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption {
return NodeOption{
implSpecificOptFn: optFn,
}
}
// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values.
func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions {
if base == nil {
base = &NodeOptions{}
}
for i := range opts {
opt := opts[i]
if opt.apply != nil {
opt.apply(base)
}
}
return base
}
// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values.
// e.g.
//
// myOption := &MyOption{
// Field1: "default_value",
// }
//
// myOption := model.GetImplSpecificOptions(myOption, opts...)
func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T {
if base == nil {
base = new(T)
}
for i := range opts {
opt := opts[i]
if opt.implSpecificOptFn != nil {
optFn, ok := opt.implSpecificOptFn.(func(*T))
if ok {
optFn(base)
}
}
}
return base
}
type NestedWorkflowState struct {
Index2Done map[int]bool `json:"index_2_done,omitempty"`
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
FullOutput map[string]any `json:"full_output,omitempty"`
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
}
func (c *NestedWorkflowState) String() string {
s, _ := sonic.MarshalIndent(c, "", " ")
return string(s)
}
type NestedWorkflowAware interface {
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
InterruptEventStore
}

View File

@@ -18,16 +18,21 @@ package plugin
import ( import (
"context" "context"
"errors" "fmt"
"strconv"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
@@ -35,29 +40,76 @@ type Config struct {
PluginID int64 PluginID int64
ToolID int64 ToolID int64
PluginVersion string PluginVersion string
}
PluginService plugin.Service func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypePlugin,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e
})
ps, ok := apiParams["pluginID"]
if !ok {
return nil, fmt.Errorf("plugin id param is not found")
}
pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
c.PluginID = pID
ps, ok = apiParams["apiID"]
if !ok {
return nil, fmt.Errorf("plugin id param is not found")
}
tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
if err != nil {
return nil, err
}
c.ToolID = tID
ps, ok = apiParams["pluginVersion"]
if !ok {
return nil, fmt.Errorf("plugin version param is not found")
}
version := ps.Input.Value.Content.(string)
c.PluginVersion = version
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Plugin{
pluginID: c.PluginID,
toolID: c.ToolID,
pluginVersion: c.PluginVersion,
pluginService: plugin.GetPluginService(),
}, nil
} }
type Plugin struct { type Plugin struct {
config *Config pluginID int64
} toolID int64
pluginVersion string
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) { pluginService plugin.Service
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.PluginID == 0 {
return nil, errors.New("plugin id is required")
}
if cfg.ToolID == 0 {
return nil, errors.New("tool id is required")
}
if cfg.PluginService == nil {
return nil, errors.New("tool service is required")
}
return &Plugin{config: cfg}, nil
} }
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) { func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
@@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil { if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
exeCfg = ctxExeCfg.ExeCfg exeCfg = ctxExeCfg.ExeCfg
} }
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{ result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.config.PluginID, PluginID: p.pluginID,
PluginVersion: ptr.Of(p.config.PluginVersion), PluginVersion: ptr.Of(p.pluginVersion),
}, p.config.ToolID, exeCfg) }, p.toolID, exeCfg)
if err != nil { if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok { if extra, ok := compose.IsInterruptRerunError(err); ok {
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now // TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now

View File

@@ -29,9 +29,12 @@ import (
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary" "github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -39,8 +42,21 @@ import (
) )
type QuestionAnswer struct { type QuestionAnswer struct {
config *Config model model.BaseChatModel
nodeMeta entity.NodeTypeMeta nodeMeta entity.NodeTypeMeta
questionTpl string
answerType AnswerType
choiceType ChoiceType
fixedChoices []string
needExtractFromAnswer bool
additionalSystemPromptTpl string
maxAnswerCount int
nodeKey vo.NodeKey
outputFields map[string]*vo.TypeInfo
} }
type Config struct { type Config struct {
@@ -51,15 +67,249 @@ type Config struct {
FixedChoices []string FixedChoices []string
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response // used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
Model model.BaseChatModel LLMParams *crossmodel.LLMParams
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer // the following are required if AnswerType is AnswerDirectly and needs to extract from answer
ExtractFromAnswer bool ExtractFromAnswer bool
AdditionalSystemPromptTpl string AdditionalSystemPromptTpl string
MaxAnswerCount int MaxAnswerCount int
OutputFields map[string]*vo.TypeInfo }
NodeKey vo.NodeKey func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeQuestionAnswer,
Name: n.Data.Meta.Title,
Configs: c,
}
qaConf := n.Data.Inputs.QA
if qaConf == nil {
return nil, fmt.Errorf("qa config is nil")
}
c.QuestionTpl = qaConf.Question
var llmParams *crossmodel.LLMParams
if n.Data.Inputs.LLMParam != nil {
llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
if err != nil {
return nil, err
}
var qaLLMParams vo.SimpleLLMParam
err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
if err != nil {
return nil, err
}
llmParams, err = convertLLMParams(qaLLMParams)
if err != nil {
return nil, err
}
c.LLMParams = llmParams
}
answerType, err := convertAnswerType(qaConf.AnswerType)
if err != nil {
return nil, err
}
c.AnswerType = answerType
var choiceType ChoiceType
if len(qaConf.OptionType) > 0 {
choiceType, err = convertChoiceType(qaConf.OptionType)
if err != nil {
return nil, err
}
c.ChoiceType = choiceType
}
if answerType == AnswerByChoices {
switch choiceType {
case FixedChoices:
var options []string
for _, option := range qaConf.Options {
options = append(options, option.Name)
}
c.FixedChoices = options
case DynamicChoices:
inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(inputSources...)
inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
if err != nil {
return nil, err
}
ns.SetInputType(DynamicChoicesKey, inputTypes)
default:
return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
}
} else if answerType == AnswerDirectly {
c.ExtractFromAnswer = qaConf.ExtractOutput
if qaConf.ExtractOutput {
if llmParams == nil {
return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
}
c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
c.MaxAnswerCount = qaConf.Limit
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
}
}
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
p.ModelName = params.ModelName
p.ModelType = params.ModelType
p.Temperature = &params.Temperature
p.MaxTokens = params.MaxTokens
p.TopP = &params.TopP
p.ResponseFormat = params.ResponseFormat
p.SystemPrompt = params.SystemPrompt
return p, nil
}
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
switch t {
case vo.QAAnswerTypeOption:
return AnswerByChoices, nil
case vo.QAAnswerTypeText:
return AnswerDirectly, nil
default:
return "", fmt.Errorf("invalid QAAnswerType: %s", t)
}
}
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
switch t {
case vo.QAOptionTypeStatic:
return FixedChoices, nil
case vo.QAOptionTypeDynamic:
return DynamicChoices, nil
default:
return "", fmt.Errorf("invalid QAOptionType: %s", t)
}
}
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if c.AnswerType == AnswerDirectly {
if c.ExtractFromAnswer {
if c.LLMParams == nil {
return nil, errors.New("model is required when extract from answer")
}
if len(ns.OutputTypes) == 0 {
return nil, errors.New("output fields is required when extract from answer")
}
}
} else if c.AnswerType == AnswerByChoices {
if c.ChoiceType == FixedChoices {
if len(c.FixedChoices) == 0 {
return nil, errors.New("fixed choices is required when extract from answer")
}
}
} else {
return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
}
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
if nodeMeta == nil {
return nil, errors.New("node meta not found for question answer")
}
var (
m model.BaseChatModel
err error
)
if c.LLMParams != nil {
m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
}
return &QuestionAnswer{
model: m,
nodeMeta: *nodeMeta,
questionTpl: c.QuestionTpl,
answerType: c.AnswerType,
choiceType: c.ChoiceType,
fixedChoices: c.FixedChoices,
needExtractFromAnswer: c.ExtractFromAnswer,
additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
maxAnswerCount: c.MaxAnswerCount,
nodeKey: ns.Key,
outputFields: ns.OutputTypes,
}, nil
}
func (c *Config) BuildBranch(_ context.Context) (
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
if c.AnswerType != AnswerByChoices {
return nil, false
}
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
optionID, ok := nodeOutput[OptionIDKey]
if !ok {
return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
}
if c.ChoiceType == DynamicChoices {
if optionID.(string) == "other" {
return -1, true, nil
} else {
return 0, false, nil
}
}
if optionID.(string) == "other" {
return -1, true, nil
}
optionIDInt, ok := AlphabetToInt(optionID.(string))
if !ok {
return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID)
}
return optionIDInt, false, nil
}, true
}
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
if n.Data.Inputs.QA.AnswerType != vo.QAAnswerTypeOption {
return expects
}
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
for index := range n.Data.Inputs.QA.Options {
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
}
expects = append(expects, schema2.PortDefault)
return expects
}
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
expects = append(expects, schema2.PortDefault)
}
return expects
}
func (c *Config) RequireCheckpoint() bool {
return true
} }
type AnswerType string type AnswerType string
@@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!` Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
) )
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
if conf == nil {
return nil, errors.New("config is nil")
}
if conf.AnswerType == AnswerDirectly {
if conf.ExtractFromAnswer {
if conf.Model == nil {
return nil, errors.New("model is required when extract from answer")
}
if len(conf.OutputFields) == 0 {
return nil, errors.New("output fields is required when extract from answer")
}
}
} else if conf.AnswerType == AnswerByChoices {
if conf.ChoiceType == FixedChoices {
if len(conf.FixedChoices) == 0 {
return nil, errors.New("fixed choices is required when extract from answer")
}
}
} else {
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
}
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
if nodeMeta == nil {
return nil, errors.New("node meta not found for question answer")
}
return &QuestionAnswer{
config: conf,
nodeMeta: *nodeMeta,
}, nil
}
type Question struct { type Question struct {
Question string Question string
Choices []string Choices []string
@@ -182,10 +397,10 @@ type message struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
} }
// Execute formats the question (optionally with choices), interrupts, then extracts the answer. // Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed. // input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices. // output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) { func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
var ( var (
questions []*Question questions []*Question
answers []string answers []string
@@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
out[QuestionsKey] = questions out[QuestionsKey] = questions
out[AnswersKey] = answers out[AnswersKey] = answers
switch q.config.AnswerType { switch q.answerType {
case AnswerDirectly: case AnswerDirectly:
if isFirst { // first execution, ask the question if isFirst { // first execution, ask the question
// format the question. Which is common to all use cases // format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in) firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil) return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
} }
if q.config.ExtractFromAnswer { if q.needExtractFromAnswer {
return q.extractFromAnswer(ctx, in, questions, answers) return q.extractFromAnswer(ctx, in, questions, answers)
} }
@@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
} }
// format the question. Which is common to all use cases // format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in) firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var formattedChoices []string var formattedChoices []string
switch q.config.ChoiceType { switch q.choiceType {
case FixedChoices: case FixedChoices:
for _, choice := range q.config.FixedChoices { for _, choice := range q.fixedChoices {
formattedChoice, err := nodes.TemplateRender(choice, in) formattedChoice, err := nodes.TemplateRender(choice, in)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
formattedChoices = append(formattedChoices, c) formattedChoices = append(formattedChoices, c)
} }
default: default:
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType) return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
} }
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil) return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
default: default:
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType) return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
} }
} }
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) { func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
fieldInfo := "FieldInfo" fieldInfo := "FieldInfo"
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo) s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
sysPrompt := fmt.Sprintf(extractSystemPrompt, s) sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
var requiredFields []string var requiredFields []string
for fName, tInfo := range q.config.OutputFields { for fName, tInfo := range q.outputFields {
if tInfo.Required { if tInfo.Required {
requiredFields = append(requiredFields, fName) requiredFields = append(requiredFields, fName)
} }
} }
var formattedAdditionalPrompt string var formattedAdditionalPrompt string
if len(q.config.AdditionalSystemPromptTpl) > 0 { if len(q.additionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in) additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
messages = append(messages, schema.UserMessage(answer)) messages = append(messages, schema.UserMessage(answer))
} }
out, err := q.config.Model.Generate(ctx, messages) out, err := q.model.Generate(ctx, messages)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
if ok { if ok {
nextQuestionStr, ok := nextQuestion.(string) nextQuestionStr, ok := nextQuestion.(string)
if ok && len(nextQuestionStr) > 0 { if ok && len(nextQuestionStr) > 0 {
if len(answers) >= q.config.MaxAnswerCount { if len(answers) >= q.maxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount) return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
} }
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers) return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
@@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
return nil, fmt.Errorf("field %s not found", fieldInfo) return nil, fmt.Errorf("field %s not found", fieldInfo)
} }
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck()) realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
schema.UserMessage(answer), schema.UserMessage(answer),
} }
out, err := q.config.Model.Generate(ctx, messages) out, err := q.model.Generate(ctx, messages)
if err != nil { if err != nil {
return -1, err return -1, err
} }
@@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
event := &entity.InterruptEvent{ event := &entity.InterruptEvent{
ID: eventID, ID: eventID,
NodeKey: q.config.NodeKey, NodeKey: q.nodeKey,
NodeType: entity.NodeTypeQuestionAnswer, NodeType: entity.NodeTypeQuestionAnswer,
NodeTitle: q.nodeMeta.Name, NodeTitle: q.nodeMeta.Name,
NodeIcon: q.nodeMeta.IconURL, NodeIcon: q.nodeMeta.IconURL,
@@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
} }
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error { _ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
setter.AddQuestion(q.config.NodeKey, &Question{ setter.AddQuestion(q.nodeKey, &Question{
Question: newQuestion, Question: newQuestion,
Choices: choices, Choices: choices,
}) })
@@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
return "" return ""
} }
func AlphabetToInt(str string) (int, bool) { func AlphabetToInt(str string) (int64, bool) {
if len(str) != 1 { if len(str) != 1 {
return 0, false return 0, false
} }
char := rune(str[0]) char := rune(str[0])
char = unicode.ToUpper(char) char = unicode.ToUpper(char)
if char >= 'A' && char <= 'Z' { if char >= 'A' && char <= 'Z' {
return int(char - 'A'), true return int64(char - 'A'), true
} }
return 0, false return 0, false
} }
@@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
for i := 0; i < len(oldQuestions); i++ { for i := 0; i < len(oldQuestions); i++ {
oldQuestion := oldQuestions[i] oldQuestion := oldQuestions[i]
oldAnswer := oldAnswers[i] oldAnswer := oldAnswers[i]
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text") contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
questionMsg := &message{ questionMsg := &message{
Type: "question", Type: "question",
ContentType: contentType, ContentType: contentType,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2), ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
} }
if q.config.AnswerType == AnswerByChoices { if q.answerType == AnswerByChoices {
questionMsg.Content = optionContent{ questionMsg.Content = optionContent{
Options: conv(oldQuestion.Choices), Options: conv(oldQuestion.Choices),
Question: oldQuestion.Question, Question: oldQuestion.Question,
@@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Type: "answer", Type: "answer",
ContentType: contentType, ContentType: contentType,
Content: oldAnswer, Content: oldAnswer,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1), ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
} }
history = append(history, questionMsg, answerMsg) history = append(history, questionMsg, answerMsg)
} }
if newQuestion != nil { if newQuestion != nil {
if q.config.AnswerType == AnswerByChoices { if q.answerType == AnswerByChoices {
history = append(history, &message{ history = append(history, &message{
Type: "question", Type: "question",
ContentType: "option", ContentType: "option",
@@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Options: conv(choices), Options: conv(choices),
Question: *newQuestion, Question: *newQuestion,
}, },
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2), ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
}) })
} else { } else {
history = append(history, &message{ history = append(history, &message{
Type: "question", Type: "question",
ContentType: "text", ContentType: "text",
Content: *newQuestion, Content: *newQuestion,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2), ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
}) })
} }
} }

View File

@@ -27,8 +27,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -37,19 +39,27 @@ import (
) )
type Config struct { type Config struct {
OutputTypes map[string]*vo.TypeInfo
NodeKey vo.NodeKey
OutputSchema string OutputSchema string
} }
type InputReceiver struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
outputTypes map[string]*vo.TypeInfo c.OutputSchema = n.Data.Inputs.OutputSchema
interruptData string
nodeKey vo.NodeKey ns := &schema.NodeSchema{
nodeMeta entity.NodeTypeMeta Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeInputReceiver,
Name: n.Data.Meta.Title,
Configs: c,
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func New(_ context.Context, cfg *Config) (*InputReceiver, error) { func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver) nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
if nodeMeta == nil { if nodeMeta == nil {
return nil, errors.New("node meta not found for input receiver") return nil, errors.New("node meta not found for input receiver")
@@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
interruptData := map[string]string{ interruptData := map[string]string{
"content_type": "form_schema", "content_type": "form_schema",
"content": cfg.OutputSchema, "content": c.OutputSchema,
} }
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
@@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
} }
return &InputReceiver{ return &InputReceiver{
outputTypes: cfg.OutputTypes, outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution
nodeMeta: *nodeMeta, nodeMeta: *nodeMeta,
nodeKey: cfg.NodeKey, nodeKey: ns.Key,
interruptData: interruptDataStr, interruptData: interruptDataStr,
}, nil }, nil
} }
func (c *Config) RequireCheckpoint() bool {
return true
}
type InputReceiver struct {
outputTypes map[string]*vo.TypeInfo
interruptData string
nodeKey vo.NodeKey
nodeMeta entity.NodeTypeMeta
}
const ( const (
ReceivedDataKey = "$received_data" ReceivedDataKey = "$received_data"
receiverWarningKey = "receiver_warning_%d_%s" receiverWarningKey = "receiver_warning_%d_%s"

View File

@@ -0,0 +1,190 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
type selectorCallbackField struct {
Key string `json:"key"`
Type vo.DataType `json:"type"`
Value any `json:"value"`
}
type selectorCondition struct {
Left selectorCallbackField `json:"left"`
Operator vo.OperatorType `json:"operator"`
Right *selectorCallbackField `json:"right,omitempty"`
}
type selectorBranch struct {
Conditions []*selectorCondition `json:"conditions"`
Logic vo.LogicType `json:"logic"`
Name string `json:"name"`
}
func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
count := len(s.clauses)
output := make([]*selectorBranch, count)
for _, source := range s.ns.InputSources {
targetPath := source.Path
if len(targetPath) == 2 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: []*selectorCondition{
{
Operator: s.clauses[index].Single.ToCanvasOperatorType(),
},
},
Logic: ClauseRelationAND.ToVOLogicType(),
}
}
if targetPath[1] == LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
}
parentNode := s.ws.GetNode(parentNodeKey)
output[index].Conditions[0].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: "",
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[0].Left = selectorCallbackField{
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: leftV,
}
}
} else if targetPath[1] == RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[0].Right = &selectorCallbackField{
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
Value: rightV,
}
}
} else if len(targetPath) == 3 {
indexStr := targetPath[0]
index, err := strconv.Atoi(indexStr)
if err != nil {
return nil, err
}
multi := s.clauses[index].Multi
branch := output[index]
if branch == nil {
output[index] = &selectorBranch{
Conditions: make([]*selectorCondition, len(multi.Clauses)),
Logic: multi.Relation.ToVOLogicType(),
}
}
clauseIndexStr := targetPath[1]
clauseIndex, err := strconv.Atoi(clauseIndexStr)
if err != nil {
return nil, err
}
clause := multi.Clauses[clauseIndex]
if output[index].Conditions[clauseIndex] == nil {
output[index].Conditions[clauseIndex] = &selectorCondition{
Operator: clause.ToCanvasOperatorType(),
}
}
if targetPath[2] == LeftKey {
leftV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
}
if source.Source.Ref.VariableType != nil {
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
if !ok {
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
}
parentNode := s.ws.GetNode(parentNodeKey)
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: "",
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else {
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: leftV,
}
}
} else if targetPath[2] == RightKey {
rightV, ok := nodes.TakeMapValue(in, targetPath)
if !ok {
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
}
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
Value: rightV,
}
}
}
}
return map[string]any{"branches": output}, nil
}

View File

@@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
panic(fmt.Sprintf("unknown operator: %+v", o)) panic(fmt.Sprintf("unknown operator: %+v", o))
} }
} }
func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) {
switch o {
case vo.Equal:
return OperatorEqual, nil
case vo.NotEqual:
return OperatorNotEqual, nil
case vo.LengthGreaterThan:
return OperatorLengthGreater, nil
case vo.LengthGreaterThanEqual:
return OperatorLengthGreaterOrEqual, nil
case vo.LengthLessThan:
return OperatorLengthLesser, nil
case vo.LengthLessThanEqual:
return OperatorLengthLesserOrEqual, nil
case vo.Contain:
if leftType.Type == vo.DataTypeObject {
return OperatorContainKey, nil
}
return OperatorContain, nil
case vo.NotContain:
if leftType.Type == vo.DataTypeObject {
return OperatorNotContainKey, nil
}
return OperatorNotContain, nil
case vo.Empty:
return OperatorEmpty, nil
case vo.NotEmpty:
return OperatorNotEmpty, nil
case vo.True:
return OperatorIsTrue, nil
case vo.False:
return OperatorIsFalse, nil
case vo.GreaterThan:
return OperatorGreater, nil
case vo.GreaterThanEqual:
return OperatorGreaterOrEqual, nil
case vo.LessThan:
return OperatorLesser, nil
case vo.LessThanEqual:
return OperatorLesserOrEqual, nil
default:
return "", fmt.Errorf("unsupported operator type: %d", o)
}
}

View File

@@ -17,9 +17,16 @@
package selector package selector
import ( import (
"context"
"fmt" "fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type ClauseRelation string type ClauseRelation string
@@ -29,10 +36,6 @@ const (
ClauseRelationOR ClauseRelation = "or" ClauseRelationOR ClauseRelation = "or"
) )
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
type OneClauseSchema struct { type OneClauseSchema struct {
Single *Operator `json:"single,omitempty"` Single *Operator `json:"single,omitempty"`
Multi *MultiClauseSchema `json:"multi,omitempty"` Multi *MultiClauseSchema `json:"multi,omitempty"`
@@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType {
panic(fmt.Sprintf("unknown clause relation: %s", c)) panic(fmt.Sprintf("unknown clause relation: %s", c))
} }
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
clauses := make([]*OneClauseSchema, 0)
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Name: n.Data.Meta.Title,
Type: entity.NodeTypeSelector,
Configs: c,
}
for i, branchCond := range n.Data.Inputs.Branches {
inputType := &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{},
}
if len(branchCond.Condition.Conditions) == 1 { // single condition
cond := branchCond.Condition.Conditions[0]
left := cond.Left
if left == nil {
return nil, fmt.Errorf("operator left is nil")
}
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
if err != nil {
return nil, err
}
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent())
if err != nil {
return nil, err
}
inputType.Properties[LeftKey] = leftType
ns.AddInputSource(leftSources...)
op, err := ToSelectorOperator(cond.Operator, leftType)
if err != nil {
return nil, err
}
if cond.Right != nil {
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
if err != nil {
return nil, err
}
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent())
if err != nil {
return nil, err
}
inputType.Properties[RightKey] = rightType
ns.AddInputSource(rightSources...)
}
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
clauses = append(clauses, &OneClauseSchema{
Single: &op,
})
continue
}
var relation ClauseRelation
logic := branchCond.Condition.Logic
if logic == vo.OR {
relation = ClauseRelationOR
} else if logic == vo.AND {
relation = ClauseRelationAND
}
var ops []*Operator
for j, cond := range branchCond.Condition.Conditions {
left := cond.Left
if left == nil {
return nil, fmt.Errorf("operator left is nil")
}
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
if err != nil {
return nil, err
}
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent())
if err != nil {
return nil, err
}
inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
LeftKey: leftType,
},
}
ns.AddInputSource(leftSources...)
op, err := ToSelectorOperator(cond.Operator, leftType)
if err != nil {
return nil, err
}
ops = append(ops, &op)
if cond.Right != nil {
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
if err != nil {
return nil, err
}
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent())
if err != nil {
return nil, err
}
inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType
ns.AddInputSource(rightSources...)
}
}
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
clauses = append(clauses, &OneClauseSchema{
Multi: &MultiClauseSchema{
Clauses: ops,
Relation: relation,
},
})
}
c.Clauses = clauses
return ns, nil
}

View File

@@ -23,23 +23,32 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type Selector struct { type Selector struct {
config *Config clauses []*OneClauseSchema
ns *schema.NodeSchema
ws *schema.WorkflowSchema
} }
func NewSelector(_ context.Context, config *Config) (*Selector, error) { type Config struct {
if config == nil { Clauses []*OneClauseSchema `json:"clauses"`
return nil, fmt.Errorf("config is nil") }
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
ws := schema.GetBuildOptions(opts...).WS
if ws == nil {
return nil, fmt.Errorf("workflow schema is required")
} }
if len(config.Clauses) == 0 { if len(c.Clauses) == 0 {
return nil, fmt.Errorf("config clauses are empty") return nil, fmt.Errorf("config clauses are empty")
} }
for _, clause := range config.Clauses { for _, clause := range c.Clauses {
if clause.Single == nil && clause.Multi == nil { if clause.Single == nil && clause.Multi == nil {
return nil, fmt.Errorf("single clause and multi clause are both nil") return nil, fmt.Errorf("single clause and multi clause are both nil")
} }
@@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) {
} }
return &Selector{ return &Selector{
config: config, clauses: c.Clauses,
ns: ns,
ws: ws,
}, nil }, nil
} }
func (c *Config) BuildBranch(_ context.Context) (
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
choice := nodeOutput[SelectKey].(int64)
if choice < 0 || choice > int64(len(c.Clauses)+1) {
return -1, false, fmt.Errorf("selector choice out of range: %d", choice)
}
if choice == int64(len(c.Clauses)) { // default
return -1, true, nil
}
return choice, false, nil
}, true
}
func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string {
expects := make([]string, len(n.Data.Inputs.Branches)+1)
expects[0] = "false" // default branch
if len(n.Data.Inputs.Branches) > 0 {
expects[1] = "true" // first condition
}
for i := 1; i < len(n.Data.Inputs.Branches); i++ { // other conditions
expects[i+1] = "true_" + strconv.Itoa(i)
}
return expects
}
type Operants struct { type Operants struct {
Left any Left any
Right any Right any
@@ -76,14 +117,14 @@ const (
SelectKey = "selected" SelectKey = "selected"
) )
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) { func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.SelectorInputConverter(input) in, err := s.selectorInputConverter(input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
predicates := make([]Predicate, 0, len(s.config.Clauses)) predicates := make([]Predicate, 0, len(s.clauses))
for i, oneConf := range s.config.Clauses { for i, oneConf := range s.clauses {
if oneConf.Single != nil { if oneConf.Single != nil {
left := in[i].Left left := in[i].Left
right := in[i].Right right := in[i].Right
@@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri
} }
if isTrue { if isTrue {
return map[string]any{SelectKey: i}, nil return map[string]any{SelectKey: int64(i)}, nil
} }
} }
return map[string]any{SelectKey: len(in)}, nil // default choice return map[string]any{SelectKey: int64(len(in))}, nil // default choice
} }
func (s *Selector) GetType() string { func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) {
return "Selector" conf := s.clauses
}
func (s *Selector) ConditionCount() int {
return len(s.config.Clauses)
}
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.config.Clauses
for i, oneConf := range conf { for i, oneConf := range conf {
if oneConf.Single != nil { if oneConf.Single != nil {
@@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er
} }
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
count := len(s.config.Clauses) count := int64(len(s.clauses))
out := output[SelectKey].(int) out := output[SelectKey].(int64)
if out == count { if out == count {
cOutput := map[string]any{"result": "pass to else branch"} cOutput := map[string]any{"result": "pass to else branch"}
return &nodes.StructuredCallbackOutput{ return &nodes.StructuredCallbackOutput{

View File

@@ -22,57 +22,27 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
var KeyIsFinished = "\x1FKey is finished\x1F" var KeyIsFinished = "\x1FKey is finished\x1F"
type Mode string
const (
Streaming Mode = "streaming"
NonStreaming Mode = "non-streaming"
)
type FieldStreamType string
const (
FieldIsStream FieldStreamType = "yes" // absolutely a stream
FieldNotStream FieldStreamType = "no" // absolutely not a stream
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
)
// SourceInfo contains stream type for a input field source of a node.
type SourceInfo struct {
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
IsIntermediate bool
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
FieldType FieldStreamType
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
FromNodeKey vo.NodeKey
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
FromPath compose.FieldPath
TypeInfo *vo.TypeInfo
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
SubSources map[string]*SourceInfo
}
type DynamicStreamContainer interface { type DynamicStreamContainer interface {
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int) SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error)
} }
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type. // ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) { func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) {
resolved := make(map[string]*SourceInfo, len(sources)) resolved := make(map[string]*schema.SourceInfo, len(sources))
nodeKey2Skipped := make(map[vo.NodeKey]bool) nodeKey2Skipped := make(map[vo.NodeKey]bool)
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error) var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error)
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) { resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) {
resolvedNode := &SourceInfo{ resolvedNode := &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate, IsIntermediate: sInfo.IsIntermediate,
FieldType: sInfo.FieldType, FieldType: sInfo.FieldType,
FromNodeKey: sInfo.FromNodeKey, FromNodeKey: sInfo.FromNodeKey,
@@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
} }
if len(sInfo.SubSources) > 0 { if len(sInfo.SubSources) > 0 {
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources)) resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources))
for k, subInfo := range sInfo.SubSources { for k, subInfo := range sInfo.SubSources {
resolvedSub, err := resolver(k, subInfo) resolvedSub, err := resolver(k, subInfo)
@@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
} }
if skipped { if skipped {
resolvedNode.FieldType = FieldSkipped resolvedNode.FieldType = schema.FieldSkipped
return resolvedNode, nil return resolvedNode, nil
} }
if sInfo.FieldType == FieldMaybeStream { if sInfo.FieldType == schema.FieldMaybeStream {
if len(sInfo.SubSources) > 0 { if len(sInfo.SubSources) > 0 {
panic("a maybe stream field should not have sub sources") panic("a maybe stream field should not have sub sources")
} }
var streamType FieldStreamType var streamType schema.FieldStreamType
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error { err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
var e error var e error
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0]) streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
@@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
return nil, err return nil, err
} }
return &SourceInfo{ return &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate, IsIntermediate: sInfo.IsIntermediate,
FieldType: streamType, FieldType: streamType,
FromNodeKey: sInfo.FromNodeKey, FromNodeKey: sInfo.FromNodeKey,
@@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface {
NodeExecuted(key vo.NodeKey) bool NodeExecuted(key vo.NodeKey) bool
} }
func (s *SourceInfo) Skipped() bool { func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath,
if !s.IsIntermediate { sc *schema.WorkflowSchema) (schema.FieldStreamType, error) {
return s.FieldType == FieldSkipped sg, ok := s.Configs.(StreamGenerator)
if !ok {
return schema.FieldNotStream, nil
} }
for _, sub := range s.SubSources { return sg.FieldStreamType(path, s, sc)
if !sub.Skipped() {
return false
}
}
return true
}
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
if !s.IsIntermediate {
return s.FromNodeKey == nodeKey
}
for _, sub := range s.SubSources {
if sub.FromNode(nodeKey) {
return true
}
}
return false
} }

View File

@@ -18,7 +18,6 @@ package subworkflow
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
@@ -29,35 +28,56 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type Config struct { type Config struct {
Runner compose.Runnable[map[string]any, map[string]any] WorkflowID int64
WorkflowVersion string
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
}
innerWF := ns.SubWorkflowSchema
if !innerWF.RequireStreaming() {
return schema2.FieldNotStream, nil
}
innerExit := innerWF.GetNode(entity.ExitNodeKey)
if innerExit.Configs.(*exit.Config).TerminatePlan == vo.ReturnVariables {
return schema2.FieldNotStream, nil
}
if !innerExit.StreamConfigs.RequireStreamingInput {
return schema2.FieldNotStream, nil
}
if len(path) > 1 || path[0] != "output" {
return schema2.FieldNotStream, fmt.Errorf(
"streaming answering sub-workflow node can only have out field 'output'")
}
return schema2.FieldIsStream, nil
} }
type SubWorkflow struct { type SubWorkflow struct {
cfg *Config Runner compose.Runnable[map[string]any, map[string]any]
} }
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) { func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.Runner == nil {
return nil, errors.New("runnable is nil")
}
return &SubWorkflow{cfg: cfg}, nil
}
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...) nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...) out, err := s.Runner.Invoke(ctx, in, nestedOpts...)
if err != nil { if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err) interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok { if !ok {
@@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod
return out, nil return out, nil
} }
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) { func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...) nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...) out, err := s.Runner.Stream(ctx, in, nestedOpts...)
if err != nil { if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err) interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok { if !ok {
@@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod
return out, nil return out, nil
} }
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) { func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) {
options := &nodes.NestedWorkflowOptions{} options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
for _, opt := range opts {
opt(options)
}
nestedOpts := options.GetOptsForNested() nestedOpts := options.GetOptsForNested()

View File

@@ -30,6 +30,7 @@ import (
"github.com/bytedance/sonic/ast" "github.com/bytedance/sonic/ast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
@@ -156,7 +157,7 @@ func removeSlice(s string) string {
type renderOptions struct { type renderOptions struct {
type2CustomRenderer map[reflect.Type]func(any) (string, error) type2CustomRenderer map[reflect.Type]func(any) (string, error)
reservedKey map[string]struct{} reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping
nilRenderer func() (string, error) nilRenderer func() (string, error)
} }
@@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
} }
} }
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) { func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
return false, false return false, false
} }
@@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
} }
if !matchingSource.IsIntermediate { if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false return matchingSource.FieldType == schema.FieldSkipped, false
} }
for _, subPath := range tp.SubPathsBeforeSlice { for _, subPath := range tp.SubPathsBeforeSlice {
@@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
return false, true return false, true
} }
return matchingSource.FieldType == FieldSkipped, false return matchingSource.FieldType == schema.FieldSkipped, false
} }
matchingSource = subSource matchingSource = subSource
} }
if !matchingSource.IsIntermediate { if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false return matchingSource.FieldType == schema.FieldSkipped, false
} }
var checkSourceSkipped func(sInfo *SourceInfo) bool var checkSourceSkipped func(sInfo *schema.SourceInfo) bool
checkSourceSkipped = func(sInfo *SourceInfo) bool { checkSourceSkipped = func(sInfo *schema.SourceInfo) bool {
if !sInfo.IsIntermediate { if !sInfo.IsIntermediate {
return sInfo.FieldType == FieldSkipped return sInfo.FieldType == schema.FieldSkipped
} }
for _, subSource := range sInfo.SubSources { for _, subSource := range sInfo.SubSources {
if !checkSourceSkipped(subSource) { if !checkSourceSkipped(subSource) {
@@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
return currentType return currentType
} }
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) { func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
mi, err := sonic.Marshal(input) mi, err := sonic.Marshal(input)
if err != nil { if err != nil {
return "", err return "", err

View File

@@ -22,7 +22,11 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
) )
@@ -34,42 +38,92 @@ const (
) )
type Config struct { type Config struct {
Type Type `json:"type"` Type Type `json:"type"`
Tpl string `json:"tpl"` Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"` ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"` Separators []string `json:"separator"`
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
} }
type TextProcessor struct { func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
config *Config ns := &schema.NodeSchema{
} Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeTextProcessor,
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) { Name: n.Data.Meta.Title,
if cfg == nil { Configs: c,
return nil, fmt.Errorf("config requried")
} }
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
if n.Data.Inputs.Method == vo.Concat {
c.Type = ConcatText
params := n.Data.Inputs.ConcatParams
for _, param := range params {
if param.Name == "concatResult" {
c.Tpl = param.Input.Value.Content.(string)
} else if param.Name == "arrayItemConcatChar" {
c.ConcatChar = param.Input.Value.Content.(string)
}
}
} else if n.Data.Inputs.Method == vo.Split {
c.Type = SplitText
params := n.Data.Inputs.SplitParams
separators := make([]string, 0, len(params))
for _, param := range params {
if param.Name == "delimiters" {
delimiters := param.Input.Value.Content.([]any)
for _, d := range delimiters {
separators = append(separators, d.(string))
}
}
}
c.Separators = separators
} else {
return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method)
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.Type == ConcatText && len(c.Tpl) == 0 {
return nil, fmt.Errorf("config tpl requried") return nil, fmt.Errorf("config tpl requried")
} }
return &TextProcessor{ return &TextProcessor{
config: cfg, typ: c.Type,
tpl: c.Tpl,
concatChar: c.ConcatChar,
separators: c.Separators,
fullSources: ns.FullSources,
}, nil }, nil
}
type TextProcessor struct {
typ Type
tpl string
concatChar string
separators []string
fullSources map[string]*schema.SourceInfo
} }
const OutputKey = "output" const OutputKey = "output"
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) { func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
switch t.config.Type { switch t.typ {
case ConcatText: case ConcatText:
arrayRenderer := func(i any) (string, error) { arrayRenderer := func(i any) (string, error) {
vs := i.([]any) vs := i.([]any)
return join(vs, t.config.ConcatChar) return join(vs, t.concatChar)
} }
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources, result, err := nodes.Render(ctx, t.tpl, input, t.fullSources,
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer)) nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
if !ok { if !ok {
return nil, fmt.Errorf("input string field must string type but got %T", valueString) return nil, fmt.Errorf("input string field must string type but got %T", valueString)
} }
values := strings.Split(valueString, t.config.Separators[0]) values := strings.Split(valueString, t.separators[0])
// Iterate over each delimiter // Iterate over each delimiter
for _, sep := range t.config.Separators[1:] { for _, sep := range t.separators[1:] {
var tempParts []string var tempParts []string
for _, part := range values { for _, part := range values {
tempParts = append(tempParts, strings.Split(part, sep)...) tempParts = append(tempParts, strings.Split(part, sep)...)
@@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
return map[string]any{OutputKey: anyValues}, nil return map[string]any{OutputKey: anyValues}, nil
default: default:
return nil, fmt.Errorf("not support type %s", t.config.Type) return nil, fmt.Errorf("not support type %s", t.typ)
} }
} }

View File

@@ -21,6 +21,8 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
func TestNewTextProcessorNodeGenerator(t *testing.T) { func TestNewTextProcessorNodeGenerator(t *testing.T) {
@@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
Type: SplitText, Type: SplitText,
Separators: []string{",", "|", "."}, Separators: []string{",", "|", "."},
} }
p, err := NewTextProcessor(ctx, cfg) p, err := cfg.Build(ctx, &schema2.NodeSchema{})
assert.NoError(t, err) assert.NoError(t, err)
result, err := p.Invoke(ctx, map[string]any{ result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{
"String": "a,b|c.d,e|f|g", "String": "a,b|c.d,e|f|g",
}) })
@@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
ConcatChar: `\t`, ConcatChar: `\t`,
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}", Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
} }
p, err := NewTextProcessor(context.Background(), cfg) p, err := cfg.Build(context.Background(), &schema2.NodeSchema{})
result, err := p.Invoke(ctx, in) result, err := p.(*TextProcessor).Invoke(ctx, in)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`) assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
}) })

View File

@@ -32,8 +32,11 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego" "github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -48,24 +51,147 @@ const (
type Config struct { type Config struct {
MergeStrategy MergeStrategy MergeStrategy MergeStrategy
GroupLen map[string]int GroupLen map[string]int
FullSources map[string]*nodes.SourceInfo
NodeKey vo.NodeKey
InputSources []*vo.FieldInfo
GroupOrder []string // the order the groups are declared in frontend canvas GroupOrder []string // the order the groups are declared in frontend canvas
} }
type VariableAggregator struct { func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
config *Config ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeVariableAggregator,
Name: n.Data.Meta.Title,
Configs: c,
}
c.MergeStrategy = FirstNotNullValue
inputs := n.Data.Inputs
groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
for i := range inputs.VariableAggregator.MergeGroups {
group := inputs.VariableAggregator.MergeGroups[i]
tInfo := &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: make(map[string]*vo.TypeInfo),
}
ns.SetInputType(group.Name, tInfo)
for ii, v := range group.Variables {
name := strconv.Itoa(ii)
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
if err != nil {
return nil, err
}
tInfo.Properties[name] = valueTypeInfo
sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
length := len(group.Variables)
groupToLen[group.Name] = length
}
groupOrder := make([]string, 0, len(groupToLen))
for i := range inputs.VariableAggregator.MergeGroups {
group := inputs.VariableAggregator.MergeGroups[i]
groupOrder = append(groupOrder, group.Name)
}
c.GroupLen = groupToLen
c.GroupOrder = groupOrder
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
} }
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) { func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if cfg == nil { if c.MergeStrategy != FirstNotNullValue {
return nil, errors.New("config is required") return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
} }
if cfg.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy) return &VariableAggregator{
groupLen: c.GroupLen,
fullSources: ns.FullSources,
nodeKey: ns.Key,
}, nil
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
} }
return &VariableAggregator{config: cfg}, nil
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range ns.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return schema2.FieldNotStream, nil // variables or static values
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode := sc.GetNode(fromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range ns.InputSources {
if fInfo.Path[0] == path[0] { // belong to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil {
return schema2.FieldNotStream, err
}
if subStreamType == schema2.FieldMaybeStream {
return schema2.FieldMaybeStream, nil
} else if subStreamType == schema2.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return schema2.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return schema2.FieldNotStream, nil
}
return schema2.FieldMaybeStream, nil
}
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
}
type VariableAggregator struct {
groupLen map[string]int
fullSources map[string]*schema2.SourceInfo
nodeKey vo.NodeKey
groupOrder []string // the order the groups are declared in frontend canvas
} }
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) { func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
@@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
result := make(map[string]any) result := make(map[string]any)
groupToChoice := make(map[string]int) groupToChoice := make(map[string]int)
for group, length := range v.config.GroupLen { for group, length := range v.groupLen {
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
if value, ok := in[group][i]; ok { if value, ok := in[group][i]; ok {
if value != nil { if value != nil {
@@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
} }
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil return nil
}) })
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
groupChoices := make([]any, 0, len(v.config.GroupOrder)) groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.config.GroupOrder { for _, group := range v.groupOrder {
choice := groupToChoice[group] choice := groupToChoice[group]
if choice == -1 { if choice == -1 {
groupChoices = append(groupChoices, nil) groupChoices = append(groupChoices, nil)
@@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
_ *schema.StreamReader[map[string]any], err error) { _ *schema.StreamReader[map[string]any], err error) {
inStream := streamInputConverter(input) inStream := streamInputConverter(input)
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey) resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok { if !ok {
panic("unable to get resolvesSources from ctx cache.") panic("unable to get resolvesSources from ctx cache.")
} }
@@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
defer func() { defer func() {
if err == nil { if err == nil {
groupChoiceToStreamType := map[string]nodes.FieldStreamType{} groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
for group, choice := range groupToChoice { for group, choice := range groupToChoice {
if choice != -1 { if choice != -1 {
item := groupToItems[group][choice] item := groupToItems[group][choice]
if _, ok := item.(stream); ok { if _, ok := item.(stream); ok {
groupChoiceToStreamType[group] = nodes.FieldIsStream groupChoiceToStreamType[group] = schema2.FieldIsStream
} }
} }
} }
groupChoices := make([]any, 0, len(v.config.GroupOrder)) groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.config.GroupOrder { for _, group := range v.groupOrder {
choice := groupToChoice[group] choice := groupToChoice[group]
if choice == -1 { if choice == -1 {
groupChoices = append(groupChoices, nil) groupChoices = append(groupChoices, nil)
@@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
// - if an element is not stream, actually receive from the stream to check if it's non-nil // - if an element is not stream, actually receive from the stream to check if it's non-nil
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
for group, length := range v.config.GroupLen { for group, length := range v.groupLen {
groupToItems[group] = make([]any, length) groupToItems[group] = make([]any, length)
groupToCurrentIndex[group] = math.MaxInt groupToCurrentIndex[group] = math.MaxInt
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
if fType == nodes.FieldSkipped { if fType == schema2.FieldSkipped {
groupToItems[group][i] = skipped{} groupToItems[group][i] = skipped{}
continue continue
} }
if fType == nodes.FieldIsStream { if fType == schema2.FieldIsStream {
groupToItems[group][i] = stream{} groupToItems[group][i] = stream{}
if ci, _ := groupToCurrentIndex[group]; i < ci { if ci, _ := groupToCurrentIndex[group]; i < ci {
groupToCurrentIndex[group] = i groupToCurrentIndex[group] = i
@@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
} }
allDone := func() bool { allDone := func() bool {
for group := range v.config.GroupLen { for group := range v.groupLen {
_, ok := groupToChoice[group] _, ok := groupToChoice[group]
if !ok { if !ok {
return false return false
@@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
alreadyDone := allDone() alreadyDone := allDone()
if alreadyDone { // all groups have made their choices, no need to actually read input streams if alreadyDone { // all groups have made their choices, no need to actually read input streams
result := make(map[string]any, len(v.config.GroupLen)) result := make(map[string]any, len(v.groupLen))
allSkip := true allSkip := true
for group := range groupToChoice { for group := range groupToChoice {
choice := groupToChoice[group] choice := groupToChoice[group]
@@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
if allSkip { // no need to convert input streams for the output, because all groups are skipped if allSkip { // no need to convert input streams for the output, because all groups are skipped
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil return nil
}) })
return schema.StreamReaderFromArray([]map[string]any{result}), nil return schema.StreamReaderFromArray([]map[string]any{result}), nil
@@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
} }
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error { _ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice) state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil return nil
}) })
@@ -416,26 +542,12 @@ type vaCallbackInput struct {
Variables []any `json:"variables"` Variables []any `json:"variables"`
} }
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
ctx = ctxcache.Init(ctx)
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}
type streamMarkerType string type streamMarkerType string
const streamMarker streamMarkerType = "<Stream Data...>" const streamMarker streamMarkerType = "<Stream Data...>"
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) { func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey) resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok { if !ok {
panic("unable to get resolved_sources from ctx cache") panic("unable to get resolved_sources from ctx cache")
} }
@@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
merged := make([]vaCallbackInput, 0, len(in)) merged := make([]vaCallbackInput, 0, len(in))
groupLen := v.config.GroupLen groupLen := v.groupLen
for groupName, vars := range in { for groupName, vars := range in {
orderedVars := make([]any, groupLen[groupName]) orderedVars := make([]any, groupLen[groupName])
for index := range vars { for index := range vars {
orderedVars[index] = vars[index] orderedVars[index] = vars[index]
if len(resolvedSources) > 0 { if len(resolvedSources) > 0 {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream { if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
// replace the streams with streamMarker, // replace the streams with streamMarker,
// because we won't read, save to execution history, or display these streams to user // because we won't read, save to execution history, or display these streams to user
orderedVars[index] = streamMarker orderedVars[index] = streamMarker
@@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
} }
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) { func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey) dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
if !ok { if !ok {
panic("unable to get dynamic stream types from ctx cache") panic("unable to get dynamic stream types from ctx cache")
} }
@@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
newOut := maps.Clone(output) newOut := maps.Clone(output)
for k := range output { for k := range output {
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream { if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
newOut[k] = streamMarker newOut[k] = streamMarker
} }
} }
@@ -594,3 +706,15 @@ func init() {
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs) nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers) nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
} }
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}

View File

@@ -25,29 +25,75 @@ import (
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx" "github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
type VariableAssigner struct { type VariableAssigner struct {
config *Config pairs []*Pair
handler *variable.Handler
} }
type Config struct { type Config struct {
Pairs []*Pair Pairs []*Pair
Handler *variable.Handler
} }
type Pair struct { func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
Left vo.Reference ns := &schema.NodeSchema{
Right compose.FieldPath Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeVariableAssigner,
Name: n.Data.Meta.Title,
Configs: c,
}
var pairs = make([]*Pair, 0, len(n.Data.Inputs.InputParameters))
for i, param := range n.Data.Inputs.InputParameters {
if param.Left == nil || param.Input == nil {
return nil, fmt.Errorf("variable assigner node's param left or input is nil")
}
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
if err != nil {
return nil, err
}
if leftSources[0].Source.Ref == nil {
return nil, fmt.Errorf("variable assigner node's param left source ref is nil")
}
if leftSources[0].Source.Ref.VariableType == nil {
return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil")
}
if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem {
return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem")
}
inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(inputSource...)
pair := &Pair{
Left: *leftSources[0].Source.Ref,
Right: inputSource[0].Path,
}
pairs = append(pairs, pair)
}
c.Pairs = pairs
return ns, nil
} }
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) { func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
for _, pair := range conf.Pairs { for _, pair := range c.Pairs {
if pair.Left.VariableType == nil { if pair.Left.VariableType == nil {
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left) return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
} }
@@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er
} }
return &VariableAssigner{ return &VariableAssigner{
config: conf, pairs: c.Pairs,
handler: variable.GetVariableHandler(),
}, nil }, nil
} }
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) { type Pair struct {
for _, pair := range v.config.Pairs { Left vo.Reference
Right compose.FieldPath
}
func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.pairs {
right, ok := nodes.TakeMapValue(in, pair.Right) right, ok := nodes.TakeMapValue(in, pair.Right)
if !ok { if !ok {
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, "."))) return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
@@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
ConnectorUID: exeCfg.ConnectorUID, ConnectorUID: exeCfg.ConnectorUID,
})) }))
} }
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...) err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
if err != nil { if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err) return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
} }

View File

@@ -20,25 +20,93 @@ import (
"context" "context"
"fmt" "fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
) )
type InLoop struct { type InLoopConfig struct {
config *Config Pairs []*Pair
intermediateVarStore variable.Store
} }
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) { func (i *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() == nil {
return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID)
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeVariableAssignerWithinLoop,
Name: n.Data.Meta.Title,
Configs: i,
}
var pairs []*Pair
for i, param := range n.Data.Inputs.InputParameters {
if param.Left == nil || param.Right == nil {
return nil, fmt.Errorf("loop set variable node's param left or right is nil")
}
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
if err != nil {
return nil, err
}
if len(leftSources) != 1 {
return nil, fmt.Errorf("loop set variable node's param left is not a single source")
}
if leftSources[0].Source.Ref == nil {
return nil, fmt.Errorf("loop set variable node's param left's ref is nil")
}
if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate {
return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate")
}
rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(rightSources...)
if len(rightSources) != 1 {
return nil, fmt.Errorf("loop set variable node's param right is not a single source")
}
pair := &Pair{
Left: *leftSources[0].Source.Ref,
Right: rightSources[0].Path,
}
pairs = append(pairs, pair)
}
i.Pairs = pairs
return ns, nil
}
func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &InLoop{ return &InLoop{
config: conf, pairs: i.Pairs,
intermediateVarStore: &nodes.ParentIntermediateStore{}, intermediateVarStore: &nodes.ParentIntermediateStore{},
}, nil }, nil
} }
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) { type InLoop struct {
for _, pair := range v.config.Pairs { pairs []*Pair
intermediateVarStore variable.Store
}
func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.pairs {
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate { if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left)) panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
} }

View File

@@ -37,36 +37,34 @@ func TestVariableAssigner(t *testing.T) {
arrVar := any([]any{1, "2"}) arrVar := any([]any{1, "2"})
va := &InLoop{ va := &InLoop{
config: &Config{ pairs: []*Pair{
Pairs: []*Pair{ {
{ Left: vo.Reference{
Left: vo.Reference{ FromPath: compose.FieldPath{"int_var_s"},
FromPath: compose.FieldPath{"int_var_s"}, VariableType: ptr.Of(vo.ParentIntermediate),
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"int_var_t"},
}, },
{ Right: compose.FieldPath{"int_var_t"},
Left: vo.Reference{ },
FromPath: compose.FieldPath{"str_var_s"}, {
VariableType: ptr.Of(vo.ParentIntermediate), Left: vo.Reference{
}, FromPath: compose.FieldPath{"str_var_s"},
Right: compose.FieldPath{"str_var_t"}, VariableType: ptr.Of(vo.ParentIntermediate),
}, },
{ Right: compose.FieldPath{"str_var_t"},
Left: vo.Reference{ },
FromPath: compose.FieldPath{"obj_var_s"}, {
VariableType: ptr.Of(vo.ParentIntermediate), Left: vo.Reference{
}, FromPath: compose.FieldPath{"obj_var_s"},
Right: compose.FieldPath{"obj_var_t"}, VariableType: ptr.Of(vo.ParentIntermediate),
}, },
{ Right: compose.FieldPath{"obj_var_t"},
Left: vo.Reference{ },
FromPath: compose.FieldPath{"arr_var_s"}, {
VariableType: ptr.Of(vo.ParentIntermediate), Left: vo.Reference{
}, FromPath: compose.FieldPath{"arr_var_s"},
Right: compose.FieldPath{"arr_var_t"}, VariableType: ptr.Of(vo.ParentIntermediate),
}, },
Right: compose.FieldPath{"arr_var_t"},
}, },
}, },
intermediateVarStore: &nodes.ParentIntermediateStore{}, intermediateVarStore: &nodes.ParentIntermediateStore{},
@@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) {
"arr_var_s": &arrVar, "arr_var_s": &arrVar,
}, nil) }, nil)
_, err := va.Assign(ctx, map[string]any{ _, err := va.Invoke(ctx, map[string]any{
"int_var_t": 2, "int_var_t": 2,
"str_var_t": "str2", "str_var_t": "str2",
"obj_var_t": map[string]any{ "obj_var_t": map[string]any{

View File

@@ -0,0 +1,196 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// Port type constants
const (
PortDefault = "default"
PortBranchError = "branch_error"
PortBranchFormat = "branch_%d"
)
// BranchSchema defines the schema for workflow branches.
type BranchSchema struct {
From vo.NodeKey `json:"from_node"`
DefaultMapping map[string]bool `json:"default_mapping,omitempty"`
ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"`
Mappings map[int64]map[string]bool `json:"mappings,omitempty"`
}
// BuildBranches builds branch schemas from connections.
func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
var branchMap map[vo.NodeKey]*BranchSchema
for _, conn := range connections {
if conn.FromPort == nil || len(*conn.FromPort) == 0 {
continue
}
port := *conn.FromPort
sourceNodeKey := conn.FromNode
if branchMap == nil {
branchMap = map[vo.NodeKey]*BranchSchema{}
}
// Get or create branch schema for source node
branch, exists := branchMap[sourceNodeKey]
if !exists {
branch = &BranchSchema{
From: sourceNodeKey,
}
branchMap[sourceNodeKey] = branch
}
// Classify port type and add to appropriate mapping
switch {
case port == PortDefault:
if branch.DefaultMapping == nil {
branch.DefaultMapping = map[string]bool{}
}
branch.DefaultMapping[string(conn.ToNode)] = true
case port == PortBranchError:
if branch.ExceptionMapping == nil {
branch.ExceptionMapping = map[string]bool{}
}
branch.ExceptionMapping[string(conn.ToNode)] = true
default:
var branchNum int64
_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
if err != nil || branchNum < 0 {
return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
}
if branch.Mappings == nil {
branch.Mappings = map[int64]map[string]bool{}
}
if _, exists := branch.Mappings[branchNum]; !exists {
branch.Mappings[branchNum] = make(map[string]bool)
}
branch.Mappings[branchNum][string(conn.ToNode)] = true
}
}
return branchMap, nil
}
func (bs *BranchSchema) OnlyException() bool {
return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0
}
func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bs.ExceptionMapping, nil
}
return bs.DefaultMapping, nil
}
// Combine ExceptionMapping and DefaultMapping into a new map
endNodes := make(map[string]bool)
for node := range bs.ExceptionMapping {
endNodes[node] = true
}
for node := range bs.DefaultMapping {
endNodes[node] = true
}
return compose.NewGraphMultiBranch(condition, endNodes)
}
func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
extractor, hasBranch := bb.BuildBranch(ctx)
if !hasBranch {
return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
}
if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
index, isDefault, err := extractor(ctx, in)
if err != nil {
return nil, err
}
if isDefault {
return bs.DefaultMapping, nil
}
if _, ok := bs.Mappings[index]; !ok {
return nil, fmt.Errorf("chosen index= %d, out of range", index)
}
return bs.Mappings[index], nil
}
// Combine DefaultMapping and normal mappings into a new map
endNodes := make(map[string]bool)
for node := range bs.DefaultMapping {
endNodes[node] = true
}
for _, ms := range bs.Mappings {
for node := range ms {
endNodes[node] = true
}
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
isSuccess, ok := in["isSuccess"]
if ok && isSuccess != nil && !isSuccess.(bool) {
return bs.ExceptionMapping, nil
}
index, isDefault, err := extractor(ctx, in)
if err != nil {
return nil, err
}
if isDefault {
return bs.DefaultMapping, nil
}
return bs.Mappings[index], nil
}
// Combine ALL mappings into a new map
endNodes := make(map[string]bool)
for node := range bs.ExceptionMapping {
endNodes[node] = true
}
for node := range bs.DefaultMapping {
endNodes[node] = true
}
for _, ms := range bs.Mappings {
for node := range ms {
endNodes[node] = true
}
}
return compose.NewGraphMultiBranch(condition, endNodes), nil
}

View File

@@ -0,0 +1,73 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"context"
"github.com/cloudwego/eino/compose"
)
type BuildOptions struct {
WS *WorkflowSchema
Inner compose.Runnable[map[string]any, map[string]any]
}
func GetBuildOptions(opts ...BuildOption) *BuildOptions {
bo := &BuildOptions{}
for _, o := range opts {
o(bo)
}
return bo
}
type BuildOption func(options *BuildOptions)
func WithWorkflowSchema(ws *WorkflowSchema) BuildOption {
return func(options *BuildOptions) {
options.WS = ws
}
}
func WithInnerWorkflow(inner compose.Runnable[map[string]any, map[string]any]) BuildOption {
return func(options *BuildOptions) {
options.Inner = inner
}
}
// NodeBuilder takes a NodeSchema and several BuildOption to build an executable node instance.
// The result 'executable' MUST implement at least one of the execute interfaces:
// - nodes.InvokableNode
// - nodes.StreamableNode
// - nodes.CollectableNode
// - nodes.TransformableNode
// - nodes.InvokableNodeWOpt
// - nodes.StreamableNodeWOpt
// - nodes.CollectableNodeWOpt
// - nodes.TransformableNodeWOpt
// NOTE: the 'normal' version does not take NodeOption, while the 'WOpt' versions take NodeOption.
// NOTE: a node should either implement the 'normal' versions, or the 'WOpt' versions, not mix them up.
type NodeBuilder interface {
Build(ctx context.Context, ns *NodeSchema, opts ...BuildOption) (
executable any, err error)
}
// BranchBuilder builds the extractor function that maps node output to port index.
type BranchBuilder interface {
BuildBranch(ctx context.Context) (extractor func(ctx context.Context,
nodeOutput map[string]any) (int64, bool /*if is default branch*/, error), hasBranch bool)
}

View File

@@ -0,0 +1,131 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NodeSchema is the universal description and configuration for a workflow Node.
// It should contain EVERYTHING a node needs to instantiate.
type NodeSchema struct {
// Key is the node key within the Eino graph.
// A node may need this information during execution,
// e.g.
// - using this Key to query workflow State for data belonging to current node.
Key vo.NodeKey `json:"key"`
// Name is the name for this node as specified on Canvas.
// A node may show this name on Canvas as part of this node's input/output.
Name string `json:"name"`
// Type is the NodeType for the node.
Type entity.NodeType `json:"type"`
// Configs are node specific configurations, with actual struct type defined by each Node Type.
// Will not hold information relating to field mappings, nor as node's static values.
// In a word, these Configs are INTERNAL to node's implementation, NOT related to workflow orchestration.
// Actual type of these Configs should implement two interfaces:
// - NodeAdaptor: to provide conversion from vo.Node to NodeSchema
// - NodeBuilder: to provide instantiation from NodeSchema to actual node instance.
Configs any `json:"configs,omitempty"`
// InputTypes are type information about the node's input fields.
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
// InputSources are field mapping information about the node's input fields.
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
// OutputTypes are type information about the node's output fields.
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
// OutputSources are field mapping information about the node's output fields.
// NOTE: only applicable to composite nodes such as NodeTypeBatch or NodeTypeLoop.
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"`
// ExceptionConfigs are about exception handling strategy of the node.
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"`
// StreamConfigs are streaming characteristics of the node.
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
// SubWorkflowBasic is basic information of the sub workflow if this node is NodeTypeSubWorkflow.
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
// SubWorkflowSchema is WorkflowSchema of the sub workflow if this node is NodeTypeSubWorkflow.
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
// FullSources contains more complete information about a node's input fields' mapping sources,
// such as whether a field's source is a 'streaming field',
// or whether the field is an object that contains sub-fields with real mappings.
// Used for those nodes that need to process streaming input.
// Set InputSourceAware = true in NodeMeta to enable.
FullSources map[string]*SourceInfo
// Lambda directly sets the node to be an Eino Lambda.
// NOTE: not serializable, used ONLY for internal test.
Lambda *compose.Lambda
}
type RequireCheckpoint interface {
RequireCheckpoint() bool
}
type ExceptionConfig struct {
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
}
type StreamConfig struct {
// whether this node has the ability to produce genuine streaming output.
// not include nodes that only passes stream down as they receives them
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
// whether this node prioritize streaming input over none-streaming input.
// not include nodes that can accept both and does not have preference.
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
func (s *NodeSchema) SetConfigKV(key string, value any) {
if s.Configs == nil {
s.Configs = make(map[string]any)
}
s.Configs.(map[string]any)[key] = value
}
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
if s.InputTypes == nil {
s.InputTypes = make(map[string]*vo.TypeInfo)
}
s.InputTypes[key] = t
}
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
s.InputSources = append(s.InputSources, info...)
}
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
if s.OutputTypes == nil {
s.OutputTypes = make(map[string]*vo.TypeInfo)
}
s.OutputTypes[key] = t
}
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
s.OutputSources = append(s.OutputSources, info...)
}

View File

@@ -0,0 +1,77 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
type FieldStreamType string
const (
FieldIsStream FieldStreamType = "yes" // absolutely a stream
FieldNotStream FieldStreamType = "no" // absolutely not a stream
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
)
type FieldSkipStatus string
// SourceInfo contains stream type for a input field source of a node.
type SourceInfo struct {
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
IsIntermediate bool
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
FieldType FieldStreamType
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
FromNodeKey vo.NodeKey
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
FromPath compose.FieldPath
TypeInfo *vo.TypeInfo
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
SubSources map[string]*SourceInfo
}
func (s *SourceInfo) Skipped() bool {
if !s.IsIntermediate {
return s.FieldType == FieldSkipped
}
for _, sub := range s.SubSources {
if !sub.Skipped() {
return false
}
}
return true
}
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
if !s.IsIntermediate {
return s.FromNodeKey == nodeKey
}
for _, sub := range s.SubSources {
if sub.FromNode(nodeKey) {
return true
}
}
return false
}

View File

@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package compose package schema
import ( import (
"fmt" "fmt"
@@ -29,9 +29,10 @@ import (
) )
type WorkflowSchema struct { type WorkflowSchema struct {
Nodes []*NodeSchema `json:"nodes"` Nodes []*NodeSchema `json:"nodes"`
Connections []*Connection `json:"connections"` Connections []*Connection `json:"connections"`
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
Branches map[vo.NodeKey]*BranchSchema `json:"branches,omitempty"`
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
@@ -71,9 +72,19 @@ func (w *WorkflowSchema) Init() {
w.doGetCompositeNodes() w.doGetCompositeNodes()
for _, node := range w.Nodes { for _, node := range w.Nodes {
if node.requireCheckpoint() { if node.Type == entity.NodeTypeSubWorkflow {
w.requireCheckPoint = true node.SubWorkflowSchema.Init()
break if node.SubWorkflowSchema.requireCheckPoint {
w.requireCheckPoint = true
break
}
}
if rc, ok := node.Configs.(RequireCheckpoint); ok {
if rc.RequireCheckpoint() {
w.requireCheckPoint = true
break
}
} }
} }
@@ -97,6 +108,22 @@ func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
return w.compositeNodes return w.compositeNodes
} }
func (w *WorkflowSchema) GetBranch(key vo.NodeKey) *BranchSchema {
if w.Branches == nil {
return nil
}
return w.Branches[key]
}
func (w *WorkflowSchema) RequireCheckpoint() bool {
return w.requireCheckPoint
}
func (w *WorkflowSchema) RequireStreaming() bool {
return w.requireStreaming
}
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) { func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
if w.Hierarchy == nil { if w.Hierarchy == nil {
return nil return nil
@@ -125,7 +152,7 @@ func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
return cNodes return cNodes
} }
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { func IsInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil { if n == nil {
return true return true
} }
@@ -144,7 +171,7 @@ func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.Node
return myParents == theirParents return myParents == theirParents
} }
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { func IsBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil { if n == nil {
return false return false
} }
@@ -154,7 +181,7 @@ func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeK
return myParentExists && !theirParentExists return myParentExists && !theirParentExists
} }
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool { func IsParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil { if n == nil {
return false return false
} }
@@ -230,21 +257,14 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
consumers := make(map[vo.NodeKey]bool) consumers := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes { for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type) if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
if meta != nil { producers[node.Key] = true
sps := meta.ExecutableMeta.StreamingParadigms
if _, ok := sps[entity.Stream]; ok {
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
producers[node.Key] = true
}
}
if sps[entity.Transform] || sps[entity.Collect] {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
consumers[node.Key] = true
}
}
} }
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
consumers[node.Key] = true
}
} }
if len(producers) == 0 || len(consumers) == 0 { if len(producers) == 0 || len(consumers) == 0 {
@@ -290,7 +310,7 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
return false return false
} }
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig { func (w *WorkflowSchema) FanInMergeConfigs() map[string]compose.FanInMergeConfig {
// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed // what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
// then we find those nodes that have 'transform' or 'collect' as streaming paradigm, // then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node. // and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
@@ -301,21 +321,15 @@ func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig
fanInNodes := make(map[vo.NodeKey]bool) fanInNodes := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes { for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type) if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
if meta != nil { var predecessor *vo.NodeKey
sps := meta.ExecutableMeta.StreamingParadigms for _, source := range node.InputSources {
if sps[entity.Transform] || sps[entity.Collect] { if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput { if predecessor != nil {
var predecessor *vo.NodeKey fanInNodes[node.Key] = true
for _, source := range node.InputSources { break
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if predecessor != nil {
fanInNodes[node.Key] = true
break
}
predecessor = &source.Source.Ref.FromNodeKey
}
} }
predecessor = &source.Source.Ref.FromNodeKey
} }
} }
} }

View File

@@ -37,8 +37,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache" "github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel" "github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen" "github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
@@ -73,7 +76,7 @@ func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmd
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel) return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel)
} }
func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) { func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
// Initialize result maps // Initialize result maps
nodeMetaMap := make(map[string][]*entity.NodeTypeMeta) nodeMetaMap := make(map[string][]*entity.NodeTypeMeta)
@@ -82,7 +85,7 @@ func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]b
if meta.Disabled { if meta.Disabled {
return false return false
} }
nodeType := meta.Type nodeType := meta.Key
if nodeTypes == nil || len(nodeTypes) == 0 { if nodeTypes == nil || len(nodeTypes) == 0 {
return true // No filter, include all return true // No filter, include all
} }
@@ -192,10 +195,10 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if startNode != nil && endNode != nil { if startNode != nil && endNode != nil {
break break
} }
if node.Type == vo.BlockTypeBotStart { if node.Type == entity.NodeTypeEntry.IDStr() {
startNode = node startNode = node
} }
if node.Type == vo.BlockTypeBotEnd { if node.Type == entity.NodeTypeExit.IDStr() {
endNode = node endNode = node
} }
} }
@@ -207,7 +210,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if err != nil { if err != nil {
return nil, err return nil, err
} }
nInfo, err := adaptor.VariableToNamedTypeInfo(v) nInfo, err := convert.VariableToNamedTypeInfo(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -220,7 +223,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if endNode != nil { if endNode != nil {
outputs, err = slices.TransformWithErrorCheck(endNode.Data.Inputs.InputParameters, func(a *vo.Param) (*vo.NamedTypeInfo, error) { outputs, err = slices.TransformWithErrorCheck(endNode.Data.Inputs.InputParameters, func(a *vo.Param) (*vo.NamedTypeInfo, error) {
return adaptor.BlockInputToNamedTypeInfo(a.Name, a.Input) return convert.BlockInputToNamedTypeInfo(a.Name, a.Input)
}) })
if err != nil { if err != nil {
logs.Warn(fmt.Sprintf("transform end node inputs to named info failed, err=%v", err)) logs.Warn(fmt.Sprintf("transform end node inputs to named info failed, err=%v", err))
@@ -316,6 +319,34 @@ func (i *impl) GetWorkflowReference(ctx context.Context, id int64) (map[int64]*v
return ret, nil return ret, nil
} }
type workflowIdentity struct {
ID string `json:"id"`
Version string `json:"version"`
}
func getAllSubWorkflowIdentities(c *vo.Canvas) []*workflowIdentity {
workflowEntities := make([]*workflowIdentity, 0)
var collectSubWorkFlowEntities func(nodes []*vo.Node)
collectSubWorkFlowEntities = func(nodes []*vo.Node) {
for _, n := range nodes {
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
workflowEntities = append(workflowEntities, &workflowIdentity{
ID: n.Data.Inputs.WorkflowID,
Version: n.Data.Inputs.WorkflowVersion,
})
}
if len(n.Blocks) > 0 {
collectSubWorkFlowEntities(n.Blocks)
}
}
}
collectSubWorkFlowEntities(c.Nodes)
return workflowEntities
}
func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.ValidateTreeConfig) ([]*cloudworkflow.ValidateTreeInfo, error) { func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.ValidateTreeConfig) ([]*cloudworkflow.ValidateTreeInfo, error) {
wfValidateInfos := make([]*cloudworkflow.ValidateTreeInfo, 0) wfValidateInfos := make([]*cloudworkflow.ValidateTreeInfo, 0)
issues, err := validateWorkflowTree(ctx, validateConfig) issues, err := validateWorkflowTree(ctx, validateConfig)
@@ -337,7 +368,7 @@ func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.Val
fmt.Errorf("failed to unmarshal canvas schema: %w", err)) fmt.Errorf("failed to unmarshal canvas schema: %w", err))
} }
subWorkflowIdentities := c.GetAllSubWorkflowIdentities() subWorkflowIdentities := getAllSubWorkflowIdentities(c)
if len(subWorkflowIdentities) > 0 { if len(subWorkflowIdentities) > 0 {
var ids []int64 var ids []int64
@@ -421,25 +452,21 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
} }
for _, n := range canvas.Nodes { for _, n := range canvas.Nodes {
if n.Type == vo.BlockTypeBotSubWorkflow { if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
nodeSchema := &compose.NodeSchema{ nodeSchema := &schema.NodeSchema{
Key: vo.NodeKey(n.ID), Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeSubWorkflow, Type: entity.NodeTypeSubWorkflow,
Name: n.Data.Meta.Title, Name: n.Data.Meta.Title,
} }
err := adaptor.SetInputsForNodeSchema(n, nodeSchema) err := convert.SetInputsForNodeSchema(n, nodeSchema)
if err != nil {
return nil, err
}
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
if err != nil { if err != nil {
return nil, err return nil, err
} }
prop := &vo.NodeProperty{ prop := &vo.NodeProperty{
Type: string(blockType), Type: nodeSchema.Type.IDStr(),
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(), IsEnableUserQuery: isEnableUserQuery(nodeSchema),
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(), IsEnableChatHistory: isEnableChatHistory(nodeSchema),
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(), IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
} }
nodePropertyMap[string(nodeSchema.Key)] = prop nodePropertyMap[string(nodeSchema.Key)] = prop
wid, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64) wid, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
@@ -478,20 +505,16 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
prop.SubWorkflow = ret prop.SubWorkflow = ret
} else { } else {
nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n) nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n, canvas)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, nodeSchema := range nodeSchemas { for _, nodeSchema := range nodeSchemas {
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
if err != nil {
return nil, err
}
nodePropertyMap[string(nodeSchema.Key)] = &vo.NodeProperty{ nodePropertyMap[string(nodeSchema.Key)] = &vo.NodeProperty{
Type: string(blockType), Type: nodeSchema.Type.IDStr(),
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(), IsEnableUserQuery: isEnableUserQuery(nodeSchema),
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(), IsEnableChatHistory: isEnableChatHistory(nodeSchema),
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(), IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
} }
} }
@@ -500,6 +523,60 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
return nodePropertyMap, nil return nodePropertyMap, nil
} }
func isEnableUserQuery(s *schema.NodeSchema) bool {
if s == nil {
return false
}
if s.Type != entity.NodeTypeEntry {
return false
}
if len(s.OutputSources) == 0 {
return false
}
for _, source := range s.OutputSources {
fieldPath := source.Path
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
return true
}
}
return false
}
func isEnableChatHistory(s *schema.NodeSchema) bool {
if s == nil {
return false
}
switch s.Type {
case entity.NodeTypeLLM:
llmParam := s.Configs.(*llm.Config).LLMParams
return llmParam.EnableChatHistory
case entity.NodeTypeIntentDetector:
llmParam := s.Configs.(*intentdetector.Config).LLMParams
return llmParam.EnableChatHistory
default:
return false
}
}
func isRefGlobalVariable(s *schema.NodeSchema) bool {
for _, source := range s.InputSources {
if source.IsRefGlobalVariable() {
return true
}
}
for _, source := range s.OutputSources {
if source.IsRefGlobalVariable() {
return true
}
}
return false
}
func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) { func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) {
var canvas vo.Canvas var canvas vo.Canvas
if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil { if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil {
@@ -510,7 +587,7 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
var getRefFn func([]*vo.Node) error var getRefFn func([]*vo.Node) error
getRefFn = func(nodes []*vo.Node) error { getRefFn = func(nodes []*vo.Node) error {
for _, node := range nodes { for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
referredID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) referredID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
if err != nil { if err != nil {
return vo.WrapError(errno.ErrSchemaConversionFail, err) return vo.WrapError(errno.ErrSchemaConversionFail, err)
@@ -521,19 +598,21 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
ReferType: vo.ReferTypeSubWorkflow, ReferType: vo.ReferTypeSubWorkflow,
ReferringBizType: vo.ReferringBizTypeWorkflow, ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{} }] = struct{}{}
} else if node.Type == vo.BlockTypeBotLLM { } else if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.LLM != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64) for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
if err != nil { referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64)
return vo.WrapError(errno.ErrSchemaConversionFail, err) if err != nil {
return vo.WrapError(errno.ErrSchemaConversionFail, err)
}
wfRefs[entity.WorkflowReferenceKey{
ReferredID: referredID,
ReferringID: referringID,
ReferType: vo.ReferTypeTool,
ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{}
} }
wfRefs[entity.WorkflowReferenceKey{
ReferredID: referredID,
ReferringID: referringID,
ReferType: vo.ReferTypeTool,
ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{}
} }
} }
} else if len(node.Blocks) > 0 { } else if len(node.Blocks) > 0 {
@@ -832,7 +911,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
validateAndBuildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error { validateAndBuildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
for _, node := range nodes { for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
var ( var (
v *vo.DraftInfo v *vo.DraftInfo
wfID int64 wfID int64
@@ -883,7 +962,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
} }
if node.Type == vo.BlockTypeBotLLM { if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var ( var (
@@ -1086,7 +1165,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
var buildWorkflowReference func(nodes []*vo.Node, wf *copiedWorkflow) error var buildWorkflowReference func(nodes []*vo.Node, wf *copiedWorkflow) error
buildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error { buildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
for _, node := range nodes { for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow { if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
var ( var (
v *vo.DraftInfo v *vo.DraftInfo
wfID int64 wfID int64
@@ -1121,7 +1200,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
} }
} }
if node.Type == vo.BlockTypeBotLLM { if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var ( var (
@@ -1323,8 +1402,8 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
var collectDependence func(nodes []*vo.Node) error var collectDependence func(nodes []*vo.Node) error
collectDependence = func(nodes []*vo.Node) error { collectDependence = func(nodes []*vo.Node) error {
for _, node := range nodes { for _, node := range nodes {
switch node.Type { switch entity.IDStrToNodeType(node.Type) {
case vo.BlockTypeBotAPI: case entity.NodeTypePlugin:
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) { apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e return e.Name, e
}) })
@@ -1347,7 +1426,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
ds.PluginIDs = append(ds.PluginIDs, pID) ds.PluginIDs = append(ds.PluginIDs, pID)
} }
case vo.BlockTypeBotDatasetWrite, vo.BlockTypeBotDataset: case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0] datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any) datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
for _, id := range datasetIDs { for _, id := range datasetIDs {
@@ -1357,7 +1436,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
} }
ds.KnowledgeIDs = append(ds.KnowledgeIDs, k) ds.KnowledgeIDs = append(ds.KnowledgeIDs, k)
} }
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
dsList := node.Data.Inputs.DatabaseInfoList dsList := node.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 { if len(dsList) == 0 {
return fmt.Errorf("database info is requird") return fmt.Errorf("database info is requird")
@@ -1369,7 +1448,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
} }
ds.DatabaseIDs = append(ds.DatabaseIDs, dsID) ds.DatabaseIDs = append(ds.DatabaseIDs, dsID)
} }
case vo.BlockTypeBotLLM: case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil { if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList { for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx] pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
@@ -1396,7 +1475,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
} }
} }
case vo.BlockTypeBotSubWorkflow: case entity.NodeTypeSubWorkflow:
wfID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64) wfID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
if err != nil { if err != nil {
return err return err
@@ -1567,8 +1646,8 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
) )
for _, node := range nodes { for _, node := range nodes {
switch node.Type { switch entity.IDStrToNodeType(node.Type) {
case vo.BlockTypeBotSubWorkflow: case entity.NodeTypeSubWorkflow:
if !hasWorkflowRelated { if !hasWorkflowRelated {
continue continue
} }
@@ -1580,7 +1659,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10) node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
node.Data.Inputs.WorkflowVersion = wf.Version node.Data.Inputs.WorkflowVersion = wf.Version
} }
case vo.BlockTypeBotAPI: case entity.NodeTypePlugin:
if !hasPluginRelated { if !hasPluginRelated {
continue continue
} }
@@ -1623,7 +1702,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
apiIDParam.Input.Value.Content = strconv.FormatInt(refApiID, 10) apiIDParam.Input.Value.Content = strconv.FormatInt(refApiID, 10)
} }
case vo.BlockTypeBotLLM: case entity.NodeTypeLLM:
if hasWorkflowRelated && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil { if hasWorkflowRelated && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList { for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx] wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
@@ -1669,7 +1748,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
} }
} }
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite: case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
if !hasKnowledgeRelated { if !hasKnowledgeRelated {
continue continue
} }
@@ -1685,7 +1764,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
} }
} }
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate: case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
if !hasDatabaseRelated { if !hasDatabaseRelated {
continue continue
} }
@@ -1713,3 +1792,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
} }
return nil return nil
} }
func RegisterAllNodeAdaptors() {
adaptor.RegisterAllNodeAdaptors()
}

View File

@@ -24,11 +24,9 @@ import (
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow" cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable" "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic" "github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno" "github.com/coze-dev/coze-studio/backend/types/errno"
) )
@@ -199,158 +197,3 @@ func isIncremental(prev version, next version) bool {
return next.Patch > prev.Patch return next.Patch > prev.Patch
} }
func replaceRelatedWorkflowOrPluginInWorkflowNodes(nodes []*vo.Node, relatedWorkflows map[int64]entity.IDVersionPair, relatedPlugins map[int64]vo.PluginEntity) error {
for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
workflowID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
if err != nil {
return err
}
if wf, ok := relatedWorkflows[workflowID]; ok {
node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
node.Data.Inputs.WorkflowVersion = wf.Version
}
}
if node.Type == vo.BlockTypeBotAPI {
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e
})
pluginIDParam, ok := apiParams["pluginID"]
if !ok {
return fmt.Errorf("plugin id param is not found")
}
pID, err := strconv.ParseInt(pluginIDParam.Input.Value.Content.(string), 10, 64)
if err != nil {
return err
}
pluginVersionParam, ok := apiParams["pluginVersion"]
if !ok {
return fmt.Errorf("plugin version param is not found")
}
if refPlugin, ok := relatedPlugins[pID]; ok {
pluginIDParam.Input.Value.Content = refPlugin.PluginID
if refPlugin.PluginVersion != nil {
pluginVersionParam.Input.Value.Content = *refPlugin.PluginVersion
}
}
}
if node.Type == vo.BlockTypeBotLLM {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
workflowID, err := strconv.ParseInt(wf.WorkflowID, 10, 64)
if err != nil {
return err
}
if refWf, ok := relatedWorkflows[workflowID]; ok {
wf.WorkflowID = strconv.FormatInt(refWf.ID, 10)
wf.WorkflowVersion = refWf.Version
}
}
}
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
pluginID, err := strconv.ParseInt(pl.PluginID, 10, 64)
if err != nil {
return err
}
if refPlugin, ok := relatedPlugins[pluginID]; ok {
pl.PluginID = strconv.FormatInt(refPlugin.PluginID, 10)
if refPlugin.PluginVersion != nil {
pl.PluginVersion = *refPlugin.PluginVersion
}
}
}
}
}
if len(node.Blocks) > 0 {
err := replaceRelatedWorkflowOrPluginInWorkflowNodes(node.Blocks, relatedWorkflows, relatedPlugins)
if err != nil {
return err
}
}
}
return nil
}
// entityNodeTypeToBlockType converts an entity.NodeType to the corresponding vo.BlockType.
func entityNodeTypeToBlockType(nodeType entity.NodeType) (vo.BlockType, error) {
switch nodeType {
case entity.NodeTypeEntry:
return vo.BlockTypeBotStart, nil
case entity.NodeTypeExit:
return vo.BlockTypeBotEnd, nil
case entity.NodeTypeLLM:
return vo.BlockTypeBotLLM, nil
case entity.NodeTypePlugin:
return vo.BlockTypeBotAPI, nil
case entity.NodeTypeCodeRunner:
return vo.BlockTypeBotCode, nil
case entity.NodeTypeKnowledgeRetriever:
return vo.BlockTypeBotDataset, nil
case entity.NodeTypeSelector:
return vo.BlockTypeCondition, nil
case entity.NodeTypeSubWorkflow:
return vo.BlockTypeBotSubWorkflow, nil
case entity.NodeTypeDatabaseCustomSQL:
return vo.BlockTypeDatabase, nil
case entity.NodeTypeOutputEmitter:
return vo.BlockTypeBotMessage, nil
case entity.NodeTypeTextProcessor:
return vo.BlockTypeBotText, nil
case entity.NodeTypeQuestionAnswer:
return vo.BlockTypeQuestion, nil
case entity.NodeTypeBreak:
return vo.BlockTypeBotBreak, nil
case entity.NodeTypeVariableAssigner:
return vo.BlockTypeBotAssignVariable, nil
case entity.NodeTypeVariableAssignerWithinLoop:
return vo.BlockTypeBotLoopSetVariable, nil
case entity.NodeTypeLoop:
return vo.BlockTypeBotLoop, nil
case entity.NodeTypeIntentDetector:
return vo.BlockTypeBotIntent, nil
case entity.NodeTypeKnowledgeIndexer:
return vo.BlockTypeBotDatasetWrite, nil
case entity.NodeTypeBatch:
return vo.BlockTypeBotBatch, nil
case entity.NodeTypeContinue:
return vo.BlockTypeBotContinue, nil
case entity.NodeTypeInputReceiver:
return vo.BlockTypeBotInput, nil
case entity.NodeTypeDatabaseUpdate:
return vo.BlockTypeDatabaseUpdate, nil
case entity.NodeTypeDatabaseQuery:
return vo.BlockTypeDatabaseSelect, nil
case entity.NodeTypeDatabaseDelete:
return vo.BlockTypeDatabaseDelete, nil
case entity.NodeTypeHTTPRequester:
return vo.BlockTypeBotHttp, nil
case entity.NodeTypeDatabaseInsert:
return vo.BlockTypeDatabaseInsert, nil
case entity.NodeTypeVariableAggregator:
return vo.BlockTypeBotVariableMerge, nil
case entity.NodeTypeJsonSerialization:
return vo.BlockTypeJsonSerialization, nil
case entity.NodeTypeJsonDeserialization:
return vo.BlockTypeJsonDeserialization, nil
case entity.NodeTypeKnowledgeDeleter:
return vo.BlockTypeBotDatasetDelete, nil
default:
return "", vo.WrapError(errno.ErrSchemaConversionFail,
fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType))
}
}

View File

@@ -82,7 +82,7 @@ func init() {
code.Register( code.Register(
ErrMissingRequiredParam, ErrMissingRequiredParam,
"Missing required parameters {param}. Please review the API documentation and ensure all mandatory fields are included in your request.", "Missing required parameters{param}. Please review the API documentation and ensure all mandatory fields are included in your request.",
code.WithAffectStability(false), code.WithAffectStability(false),
) )