refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin 2025-08-05 14:02:33 +08:00 committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
96 changed files with 8305 additions and 8717 deletions

View File

@ -105,26 +105,28 @@ import (
// TestMain performs process-wide setup shared by all workflow tests:
// it registers the global token-usage callback handler and all node
// adaptors before delegating to the standard test runner.
func TestMain(m *testing.M) {
	callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
	service.RegisterAllNodeAdaptors()
	os.Exit(m.Run())
}
type wfTestRunner struct {
t *testing.T
h *server.Hertz
ctrl *gomock.Controller
idGen *mock.MockIDGenerator
search *searchmock.MockNotifier
appVarS *mockvar.MockStore
userVarS *mockvar.MockStore
varGetter *mockvar.MockVariablesMetaGetter
modelManage *mockmodel.MockManager
plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledgeOperator
database *databasemock.MockDatabaseOperator
pluginSrv *pluginmock.MockService
ctx context.Context
closeFn func()
t *testing.T
h *server.Hertz
ctrl *gomock.Controller
idGen *mock.MockIDGenerator
search *searchmock.MockNotifier
appVarS *mockvar.MockStore
userVarS *mockvar.MockStore
varGetter *mockvar.MockVariablesMetaGetter
modelManage *mockmodel.MockManager
plugin *mockPlugin.MockPluginService
tos *storageMock.MockStorage
knowledge *knowledgemock.MockKnowledgeOperator
database *databasemock.MockDatabaseOperator
pluginSrv *pluginmock.MockService
internalModel *testutil.UTChatModel
ctx context.Context
closeFn func()
}
var req2URL = map[reflect.Type]string{
@ -243,9 +245,11 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
cpStore := checkpoint.NewRedisStore(redisClient)
utChatModel := &testutil.UTChatModel{}
mockTos := storageMock.NewMockStorage(ctrl)
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, nil)
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
@ -312,22 +316,23 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
}
return &wfTestRunner{
t: t,
h: h,
ctrl: ctrl,
idGen: mockIDGen,
search: mockSearchNotify,
appVarS: mockGlobalAppVarStore,
userVarS: mockGlobalUserVarStore,
varGetter: mockVarGetter,
modelManage: mockModelManage,
plugin: mPlugin,
tos: mockTos,
knowledge: mockKwOperator,
database: mockDatabaseOperator,
ctx: context.Background(),
closeFn: f,
pluginSrv: mockPluginSrv,
t: t,
h: h,
ctrl: ctrl,
idGen: mockIDGen,
search: mockSearchNotify,
appVarS: mockGlobalAppVarStore,
userVarS: mockGlobalUserVarStore,
varGetter: mockVarGetter,
modelManage: mockModelManage,
plugin: mPlugin,
tos: mockTos,
knowledge: mockKwOperator,
database: mockDatabaseOperator,
internalModel: utChatModel,
ctx: context.Background(),
closeFn: f,
pluginSrv: mockPluginSrv,
}
}
@ -1110,7 +1115,8 @@ func TestValidateTree(t *testing.T) {
assert.Equal(t, i.Message, `node "代码_1" not connected`)
}
if i.NodeError.NodeID == "160892" {
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`)
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
}
}
@ -1157,7 +1163,8 @@ func TestValidateTree(t *testing.T) {
assert.Equal(t, i.Message, `node "代码_1" not connected`)
}
if i.NodeError.NodeID == "160892" {
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`)
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
}
}
}
@ -2950,41 +2957,41 @@ func TestLLMWithSkills(t *testing.T) {
r := newWfTestRunner(t)
defer r.closeFn()
utChatModel := &testutil.UTChatModel{
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
if index == 0 {
assert.Equal(t, 1, len(in))
assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点")
return &schema.Message{
Role: schema.Assistant,
Content: "7512369185624686592",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 10,
CompletionTokens: 11,
TotalTokens: 21,
},
utChatModel := r.internalModel
utChatModel.InvokeResultProvider = func(index int, in []*schema.Message) (*schema.Message, error) {
if index == 0 {
assert.Equal(t, 1, len(in))
assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点")
return &schema.Message{
Role: schema.Assistant,
Content: "7512369185624686592",
ResponseMeta: &schema.ResponseMeta{
Usage: &schema.TokenUsage{
PromptTokens: 10,
CompletionTokens: 11,
TotalTokens: 21,
},
}, nil
},
}, nil
} else if index == 1 {
assert.Equal(t, 2, len(in))
for _, message := range in {
if message.Role == schema.System {
assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content)
}
if message.Role == schema.User {
assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉")
}
} else if index == 1 {
assert.Equal(t, 2, len(in))
for _, message := range in {
if message.Role == schema.System {
assert.Equal(t, "你是一个旅游推荐专家,通过用户提出的问题,推荐用户具体城市的旅游景点", message.Content)
}
if message.Role == schema.User {
assert.Contains(t, message.Content, "天安门广场 ‌:中国政治文化中心,见证了近现代重大历史事件‌", "八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉")
}
return &schema.Message{
Role: schema.Assistant,
Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`,
}, nil
}
return nil, fmt.Errorf("unexpected index: %d", index)
},
return &schema.Message{
Role: schema.Assistant,
Content: `八达岭长城 ‌:明代长城的精华段,被誉为“不到长城非好汉‌`,
}, nil
}
return nil, fmt.Errorf("unexpected index: %d", index)
}
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes()
r.knowledge.EXPECT().ListKnowledgeDetail(gomock.Any(), gomock.Any()).Return(&knowledge.ListKnowledgeDetailResponse{
@ -3767,10 +3774,10 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
var validateCv func(ns []*vo.Node)
validateCv = func(ns []*vo.Node) {
for _, n := range ns {
if n.Type == vo.BlockTypeBotSubWorkflow {
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.Equal(t, n.Data.Inputs.WorkflowVersion, version)
}
if n.Type == vo.BlockTypeBotAPI {
if n.Type == entity.NodeTypePlugin.IDStr() {
for _, apiParam := range n.Data.Inputs.APIParams {
// In an application workflow's plugin node, a plugin version equal to 0 means the plugin was created within the application
if apiParam.Name == "pluginVersion" {
@ -3779,7 +3786,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
}
}
if n.Type == vo.BlockTypeBotLLM {
if n.Type == entity.NodeTypeLLM.IDStr() {
if n.Data.Inputs.FCParam != nil && n.Data.Inputs.FCParam.PluginFCParam != nil {
// In an application workflow's llm node, a plugin version equal to 0 means the plugin was created within the application
for _, p := range n.Data.Inputs.FCParam.PluginFCParam.PluginList {
@ -4063,8 +4070,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes {
switch node.Type {
case vo.BlockTypeBotAPI:
switch entity.IDStrToNodeType(node.Type) {
case entity.NodeTypePlugin:
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e
})
@ -4082,7 +4089,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", pID)
}
case vo.BlockTypeBotSubWorkflow:
case entity.NodeTypeSubWorkflow:
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
wfId, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
assert.NoError(t, err)
@ -4096,7 +4103,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
err = sonic.UnmarshalString(subWf.Canvas, subworkflowCanvas)
assert.NoError(t, err)
validateSubWorkflowIDs(subworkflowCanvas.Nodes)
case vo.BlockTypeBotLLM:
case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID])
@ -4116,13 +4123,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", k.ID)
}
}
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
for idx := range knowledgeIDs {
assert.Equal(t, "100100", knowledgeIDs[idx].(string))
}
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
for _, d := range node.Data.Inputs.DatabaseInfoList {
assert.Equal(t, "100100", d.DatabaseInfoID)
}
@ -4208,10 +4215,10 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes {
switch node.Type {
case vo.BlockTypeBotSubWorkflow:
switch entity.IDStrToNodeType(node.Type) {
case entity.NodeTypeSubWorkflow:
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
case vo.BlockTypeBotLLM:
case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID])
@ -4229,13 +4236,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
assert.Equal(t, "100100", k.ID)
}
}
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
for idx := range knowledgeIDs {
assert.Equal(t, "100100", knowledgeIDs[idx].(string))
}
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
for _, d := range node.Data.Inputs.DatabaseInfoList {
assert.Equal(t, "100100", d.DatabaseInfoID)
}
@ -4356,7 +4363,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
err = sonic.Unmarshal(data, mainCanvas)
assert.NoError(t, err)
for _, node := range mainCanvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
if node.Data.Inputs.WorkflowID == "7516826260387921920" {
node.Data.Inputs.WorkflowID = c1IdStr
}
@ -4372,7 +4379,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
err = sonic.Unmarshal(cc1Data, cc1Canvas)
assert.NoError(t, err)
for _, node := range cc1Canvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
if node.Data.Inputs.WorkflowID == "7516826283318181888" {
node.Data.Inputs.WorkflowID = c2IdStr
}
@ -4423,7 +4430,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
assert.NoError(t, err)
for _, node := range newMainCanvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
}
@ -4437,7 +4444,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
assert.NoError(t, err)
for _, node := range cc1Canvas.Nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
}
@ -4508,10 +4515,10 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
var validateSubWorkflowIDs func(nodes []*vo.Node)
validateSubWorkflowIDs = func(nodes []*vo.Node) {
for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
}
if node.Type == vo.BlockTypeBotLLM {
if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
assert.True(t, copiedIDMap[w.WorkflowID])

View File

@ -24,6 +24,7 @@ import (
"github.com/cloudwego/eino/callbacks"
"github.com/coze-dev/coze-studio/backend/application/internal"
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
@ -78,6 +79,9 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
if !ok {
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
}
service.RegisterAllNodeAdaptors()
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
components.Tos, components.CPStore, bcm)
workflow.SetRepository(workflowRepo)

View File

@ -93,8 +93,8 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
toQueryTypes := make(map[entity.NodeType]bool)
for _, t := range req.NodeTypes {
entityType, err := nodeType2EntityNodeType(t)
if err != nil {
entityType := entity.IDStrToNodeType(t)
if len(entityType) == 0 {
logs.Warnf("get node type %v failed, err:=%v", t, err)
continue
}
@ -116,23 +116,19 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
Name: category,
}
for _, nodeMeta := range nodeMetaList {
tplType, err := entityNodeTypeToAPINodeTemplateType(nodeMeta.Type)
if err != nil {
return nil, err
}
tpl := &workflow.NodeTemplate{
ID: fmt.Sprintf("%d", nodeMeta.ID),
Type: tplType,
Type: workflow.NodeTemplateType(nodeMeta.ID),
Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name),
Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc),
IconURL: nodeMeta.IconURL,
SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT),
NodeType: fmt.Sprintf("%d", tplType),
NodeType: fmt.Sprintf("%d", nodeMeta.ID),
Color: nodeMeta.Color,
}
resp.Data.TemplateList = append(resp.Data.TemplateList, tpl)
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", tplType))
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", nodeMeta.ID))
}
}
@ -178,7 +174,7 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C
IconURI: req.IconURI,
AppID: parseInt64(req.ProjectID),
Mode: ternary.IFElse(req.IsSetFlowMode(), req.GetFlowMode(), workflow.WorkflowMode_Workflow),
InitCanvasSchema: entity.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)),
InitCanvasSchema: vo.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)),
}
id, err := GetWorkflowDomainSVC().Create(ctx, wf)
@ -1041,7 +1037,8 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
return wf.ID, nil
}
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/
appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
defer func() {
if panicErr := recover(); panicErr != nil {
err = safego.NewPanicErr(panicErr, debug.Stack())
@ -1127,15 +1124,10 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
}
func convertNodeExecution(nodeExe *entity.NodeExecution) (*workflow.NodeResult, error) {
nType, err := entityNodeTypeToAPINodeTemplateType(nodeExe.NodeType)
if err != nil {
return nil, err
}
nr := &workflow.NodeResult{
NodeId: nodeExe.NodeID,
NodeName: nodeExe.NodeName,
NodeType: nType.String(),
NodeType: entity.NodeMetaByNodeType(nodeExe.NodeType).GetDisplayKey(),
NodeStatus: workflow.NodeExeStatus(nodeExe.Status),
ErrorInfo: ptr.FromOrDefault(nodeExe.ErrorInfo, ""),
Input: ptr.FromOrDefault(nodeExe.Input, ""),
@ -1316,13 +1308,6 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
return nil, schema.ErrNoValue
}
var nodeType workflow.NodeTemplateType
nodeType, err = entityNodeTypeToAPINodeTemplateType(msg.NodeType)
if err != nil {
logs.Errorf("convert node type %v failed, err:=%v", msg.NodeType, err)
nodeType = workflow.NodeTemplateType(0)
}
res = &workflow.OpenAPIStreamRunFlowResponse{
ID: strconv.Itoa(messageID),
Event: string(MessageEvent),
@ -1330,7 +1315,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
Content: ptr.Of(msg.Content),
ContentType: ptr.Of("text"),
NodeIsFinish: ptr.Of(msg.Last),
NodeType: ptr.Of(nodeType.String()),
NodeType: ptr.Of(entity.NodeMetaByNodeType(msg.NodeType).GetDisplayKey()),
NodeID: ptr.Of(msg.NodeID),
}
@ -3344,178 +3329,6 @@ func toWorkflowParameter(nType *vo.NamedTypeInfo) (*workflow.Parameter, error) {
return wp, nil
}
// nodeType2EntityNodeType resolves a numeric node-type string (as used by the
// API layer) to its entity.NodeType. It returns an error when the string is
// not a valid integer or does not correspond to a known node type ID.
func nodeType2EntityNodeType(t string) (entity.NodeType, error) {
	id, err := strconv.Atoi(t)
	if err != nil {
		return "", fmt.Errorf("invalid node type string '%s': %w", t, err)
	}
	// Numeric-ID -> entity node type lookup table.
	idToType := map[int]entity.NodeType{
		1:  entity.NodeTypeEntry,
		2:  entity.NodeTypeExit,
		3:  entity.NodeTypeLLM,
		4:  entity.NodeTypePlugin,
		5:  entity.NodeTypeCodeRunner,
		6:  entity.NodeTypeKnowledgeRetriever,
		8:  entity.NodeTypeSelector,
		9:  entity.NodeTypeSubWorkflow,
		12: entity.NodeTypeDatabaseCustomSQL,
		13: entity.NodeTypeOutputEmitter,
		15: entity.NodeTypeTextProcessor,
		18: entity.NodeTypeQuestionAnswer,
		19: entity.NodeTypeBreak,
		20: entity.NodeTypeVariableAssignerWithinLoop,
		21: entity.NodeTypeLoop,
		22: entity.NodeTypeIntentDetector,
		27: entity.NodeTypeKnowledgeIndexer,
		28: entity.NodeTypeBatch,
		29: entity.NodeTypeContinue,
		30: entity.NodeTypeInputReceiver,
		32: entity.NodeTypeVariableAggregator,
		37: entity.NodeTypeMessageList,
		38: entity.NodeTypeClearMessage,
		39: entity.NodeTypeCreateConversation,
		40: entity.NodeTypeVariableAssigner,
		42: entity.NodeTypeDatabaseUpdate,
		43: entity.NodeTypeDatabaseQuery,
		44: entity.NodeTypeDatabaseDelete,
		45: entity.NodeTypeHTTPRequester,
		46: entity.NodeTypeDatabaseInsert,
		58: entity.NodeTypeJsonSerialization,
		59: entity.NodeTypeJsonDeserialization,
		60: entity.NodeTypeKnowledgeDeleter,
	}
	nt, ok := idToType[id]
	if !ok {
		// Handle all unknown or unsupported types here.
		return "", fmt.Errorf("unsupported or unknown node type ID: %d", id)
	}
	return nt, nil
}
// entityNodeTypeToAPINodeTemplateType converts an entity.NodeType to the corresponding workflow.NodeTemplateType.
// It returns an error for entity node types with no counterpart in the API model.
func entityNodeTypeToAPINodeTemplateType(nodeType entity.NodeType) (workflow.NodeTemplateType, error) {
	switch nodeType {
	case entity.NodeTypeEntry:
		return workflow.NodeTemplateType_Start, nil
	case entity.NodeTypeExit:
		return workflow.NodeTemplateType_End, nil
	case entity.NodeTypeLLM:
		return workflow.NodeTemplateType_LLM, nil
	case entity.NodeTypePlugin:
		// Maps to Api type in the API model.
		return workflow.NodeTemplateType_Api, nil
	case entity.NodeTypeCodeRunner:
		return workflow.NodeTemplateType_Code, nil
	case entity.NodeTypeKnowledgeRetriever:
		// Maps to Dataset type in the API model.
		return workflow.NodeTemplateType_Dataset, nil
	case entity.NodeTypeSelector:
		// Maps to If type in the API model.
		return workflow.NodeTemplateType_If, nil
	case entity.NodeTypeSubWorkflow:
		return workflow.NodeTemplateType_SubWorkflow, nil
	case entity.NodeTypeDatabaseCustomSQL:
		// Maps to the generic Database type in the API model.
		return workflow.NodeTemplateType_Database, nil
	case entity.NodeTypeOutputEmitter:
		// Maps to Message type in the API model.
		return workflow.NodeTemplateType_Message, nil
	case entity.NodeTypeTextProcessor:
		return workflow.NodeTemplateType_Text, nil
	case entity.NodeTypeQuestionAnswer:
		return workflow.NodeTemplateType_Question, nil
	case entity.NodeTypeBreak:
		return workflow.NodeTemplateType_Break, nil
	case entity.NodeTypeVariableAssigner:
		return workflow.NodeTemplateType_AssignVariable, nil
	case entity.NodeTypeVariableAssignerWithinLoop:
		return workflow.NodeTemplateType_LoopSetVariable, nil
	case entity.NodeTypeLoop:
		return workflow.NodeTemplateType_Loop, nil
	case entity.NodeTypeIntentDetector:
		return workflow.NodeTemplateType_Intent, nil
	case entity.NodeTypeKnowledgeIndexer:
		// Maps to DatasetWrite type in the API model.
		return workflow.NodeTemplateType_DatasetWrite, nil
	case entity.NodeTypeBatch:
		return workflow.NodeTemplateType_Batch, nil
	case entity.NodeTypeContinue:
		return workflow.NodeTemplateType_Continue, nil
	case entity.NodeTypeInputReceiver:
		return workflow.NodeTemplateType_Input, nil
	// MessageList (37), VariableAggregator (32), ClearMessage (38) and
	// CreateConversation (39) have no named NodeTemplateType constant in the
	// API model, so their numeric IDs are used directly.
	case entity.NodeTypeMessageList:
		return workflow.NodeTemplateType(37), nil
	case entity.NodeTypeVariableAggregator:
		return workflow.NodeTemplateType(32), nil
	case entity.NodeTypeClearMessage:
		return workflow.NodeTemplateType(38), nil
	case entity.NodeTypeCreateConversation:
		return workflow.NodeTemplateType(39), nil
	case entity.NodeTypeDatabaseUpdate:
		return workflow.NodeTemplateType_DatabaseUpdate, nil
	case entity.NodeTypeDatabaseQuery:
		// Maps to DatabasesELECT (ID 43) in the API model; the constant name
		// carries an upstream typo.
		return workflow.NodeTemplateType_DatabasesELECT, nil
	case entity.NodeTypeDatabaseDelete:
		return workflow.NodeTemplateType_DatabaseDelete, nil
	case entity.NodeTypeHTTPRequester:
		// No named constant for HTTPRequester; its numeric ID (45) is used.
		return workflow.NodeTemplateType(45), nil
	case entity.NodeTypeDatabaseInsert:
		// Deliberately uses the entity's numeric ID (46) rather than the API
		// model's NodeTemplateType_DatabaseInsert constant (ID 41) —
		// presumably to keep entity and API IDs aligned; TODO confirm.
		return workflow.NodeTemplateType(46), nil
	case entity.NodeTypeJsonSerialization:
		// No named constant for JsonSerialization; its numeric ID (58) is used.
		return workflow.NodeTemplateType(58), nil
	case entity.NodeTypeJsonDeserialization:
		return workflow.NodeTemplateType_JsonDeserialization, nil
	case entity.NodeTypeKnowledgeDeleter:
		return workflow.NodeTemplateType_DatasetDelete, nil
	case entity.NodeTypeLambda:
		// Lambda intentionally maps to the zero template type.
		return 0, nil
	default:
		// Handle entity types that don't have a corresponding NodeTemplateType.
		return workflow.NodeTemplateType(0), fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType)
	}
}
func i64PtrToStringPtr(i *int64) *string {
if i == nil {
return nil
@ -3761,7 +3574,7 @@ func mergeWorkflowAPIParameters(latestAPIParameters []*workflow.APIParameter, ex
func parseWorkflowTerminatePlanType(c *vo.Canvas) (int32, error) {
var endNode *vo.Node
for _, n := range c.Nodes {
if n.Type == vo.BlockTypeBotEnd {
if n.Type == entity.NodeTypeExit.IDStr() {
endNode = n
break
}

View File

@ -1,49 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"net/http"
"net/url"
"testing"
. "github.com/bytedance/mockey"
"github.com/stretchr/testify/assert"
)
// TestGenRequestString verifies that genRequestString serializes an
// *http.Request (header, query, path) plus a raw body into the expected
// JSON string, including the nil-header / nil-body case.
func TestGenRequestString(t *testing.T) {
	cases := []struct {
		req  *http.Request
		body []byte
		want string
	}{
		{
			req: &http.Request{
				Header: http.Header{
					"Content-Type": []string{"application/json"},
				},
				Method: http.MethodPost,
				URL:    &url.URL{Path: "/test"},
			},
			body: []byte(`{"a": 1}`),
			want: `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`,
		},
		{
			req:  &http.Request{URL: &url.URL{Path: "/test"}},
			body: nil,
			want: `{"header":null,"query":null,"path":"/test","body":null}`,
		},
	}
	for _, tc := range cases {
		PatchConvey("", t, func() {
			got, err := genRequestString(tc.req, tc.body)
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}

View File

@ -1,4 +1,5 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -16,58 +17,85 @@
package entity
import (
"fmt"
"strconv"
)
type NodeType string
// IDStr renders the node type's numeric ID as a decimal string. It returns
// the empty string when no metadata is registered for this node type.
func (nt NodeType) IDStr() string {
	meta := NodeMetaByNodeType(nt)
	if meta == nil {
		return ""
	}
	return strconv.FormatInt(meta.ID, 10)
}
// IDStrToNodeType resolves a decimal node-type ID string to its NodeType by
// scanning the registered metadata. It returns the empty NodeType when the
// string is not a valid integer or matches no registered node type.
func IDStrToNodeType(s string) NodeType {
	parsed, parseErr := strconv.ParseInt(s, 10, 64)
	if parseErr != nil {
		return ""
	}
	var match NodeType
	for _, meta := range NodeTypeMetas {
		if meta.ID == parsed {
			match = meta.Key
			break
		}
	}
	return match
}
type NodeTypeMeta struct {
ID int64 `json:"id"`
Name string `json:"name"`
Type NodeType `json:"type"`
Category string `json:"category"`
Color string `json:"color"`
Desc string `json:"desc"`
IconURL string `json:"icon_url"`
SupportBatch bool `json:"support_batch"`
Disabled bool `json:"disabled,omitempty"`
EnUSName string `json:"en_us_name,omitempty"`
EnUSDescription string `json:"en_us_description,omitempty"`
ID int64
Key NodeType
DisplayKey string
Name string `json:"name"`
Category string `json:"category"`
Color string `json:"color"`
Desc string `json:"desc"`
IconURL string `json:"icon_url"`
SupportBatch bool `json:"support_batch"`
Disabled bool `json:"disabled,omitempty"`
EnUSName string `json:"en_us_name,omitempty"`
EnUSDescription string `json:"en_us_description,omitempty"`
ExecutableMeta
}
// GetDisplayKey returns the key exposed to API consumers: the explicit
// DisplayKey when one is configured, otherwise the node type key itself.
func (ntm *NodeTypeMeta) GetDisplayKey() string {
	if ntm.DisplayKey != "" {
		return ntm.DisplayKey
	}
	return string(ntm.Key)
}
// Category is a grouping of node templates shown in the workflow editor,
// with localized display names.
type Category struct {
	Key      string `json:"key"`
	Name     string `json:"name"`       // zh-CN display name
	EnUSName string `json:"en_us_name"` // English display name
}
// StreamingParadigm identifies a data-exchange paradigm a node may support
// (mirroring the Eino framework's invoke/stream/collect/transform paradigms).
type StreamingParadigm string

const (
	Invoke    StreamingParadigm = "invoke"
	Stream    StreamingParadigm = "stream"
	Collect   StreamingParadigm = "collect"
	Transform StreamingParadigm = "transform"
)
type ExecutableMeta struct {
IsComposite bool `json:"is_composite,omitempty"`
DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout
PreFillZero bool `json:"pre_fill_zero,omitempty"`
PostFillNil bool `json:"post_fill_nil,omitempty"`
CallbackEnabled bool `json:"callback_enabled,omitempty"` // is false, Eino framework will inject callbacks for this node
MayUseChatModel bool `json:"may_use_chat_model,omitempty"`
InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources
StreamingParadigms map[StreamingParadigm]bool `json:"streaming_paradigms,omitempty"`
StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error
/*
IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming.
This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output')
from those that are merely capable of streaming internally (defined by StreamingParadigms),
whose output is consumed by other nodes.
In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
*/
IsComposite bool `json:"is_composite,omitempty"`
DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout
PreFillZero bool `json:"pre_fill_zero,omitempty"`
PostFillNil bool `json:"post_fill_nil,omitempty"`
MayUseChatModel bool `json:"may_use_chat_model,omitempty"`
InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources
StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error
// IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming.
// This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output')
// from those that are merely capable of streaming internally (defined by StreamingParadigms),
// whose output is consumed by other nodes.
// In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
// When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
IncrementalOutput bool `json:"incremental_output,omitempty"`
// UseCtxCache indicates that the node would require a newly initialized ctx cache for each invocation.
// example use cases:
// - write warnings to the ctx cache during Invoke, and read from the ctx within Callback output converter
UseCtxCache bool `json:"use_ctx_cache"`
}
type PluginNodeMeta struct {
@ -125,9 +153,700 @@ const (
NodeTypeSubWorkflow NodeType = "SubWorkflow"
NodeTypeJsonSerialization NodeType = "JsonSerialization"
NodeTypeJsonDeserialization NodeType = "JsonDeserialization"
NodeTypeComment NodeType = "Comment"
)
// Fixed node keys for every workflow's entry and exit nodes.
const (
	EntryNodeKey = "100001"
	ExitNodeKey  = "900001"
)
// Categories lists the node categories shown in the workflow editor's node
// panel, in display order. Name is the zh-CN label and EnUSName the English
// label; the entry with the empty key is the default category.
var Categories = []Category{
	{
		Key:      "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
		Name:     "",
		EnUSName: "",
	},
	{
		Key:      "logic",
		Name:     "业务逻辑",
		EnUSName: "Logic",
	},
	{
		Key:      "input&output",
		Name:     "输入&输出",
		EnUSName: "Input&Output",
	},
	{
		Key:      "database",
		Name:     "数据库",
		EnUSName: "Database",
	},
	{
		Key:      "data",
		Name:     "知识库&数据",
		EnUSName: "Data",
	},
	{
		Key:      "image",
		Name:     "图像处理",
		EnUSName: "Image",
	},
	{
		Key:      "audio&video",
		Name:     "音视频处理",
		EnUSName: "Audio&Video",
	},
	{
		Key:      "utilities",
		Name:     "组件",
		EnUSName: "Utilities",
	},
	{
		Key:      "conversation_management",
		Name:     "会话管理",
		EnUSName: "Conversation management",
	},
	{
		Key:      "conversation_history",
		Name:     "会话历史",
		EnUSName: "Conversation history",
	},
	{
		Key:      "message",
		Name:     "消息",
		EnUSName: "Message",
	},
}
// NodeTypeMetas holds the metadata for all available node types.
// It is initialized with built-in node types and potentially extended by loading from external sources.
// Invariant: every entry's Key field equals its map key, so a meta obtained by
// NodeType lookup round-trips back to that NodeType.
var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
	NodeTypeEntry: {
		ID:           1,
		Key:          NodeTypeEntry,
		DisplayKey:   "Start",
		Name:         "开始",
		Category:     "input&output",
		Desc:         "工作流的起始节点,用于设定启动工作流需要的信息",
		Color:        "#5C62FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PostFillNil: true,
		},
		EnUSName:        "Start",
		EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
	},
	NodeTypeExit: {
		ID:           2,
		Key:          NodeTypeExit,
		DisplayKey:   "End",
		Name:         "结束",
		Category:     "input&output",
		Desc:         "工作流的最终节点,用于返回工作流运行后的结果信息",
		Color:        "#5C62FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PreFillZero:          true,
			InputSourceAware:     true,
			StreamSourceEOFAware: true,
			IncrementalOutput:    true,
		},
		EnUSName:        "End",
		EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
	},
	NodeTypeLLM: {
		ID:           3,
		Key:          NodeTypeLLM,
		DisplayKey:   "LLM",
		Name:         "大模型",
		Category:     "",
		Desc:         "调用大语言模型,使用变量和提示词生成回复",
		Color:        "#5C62FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
		SupportBatch: true,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
			PreFillZero:      true,
			PostFillNil:      true,
			InputSourceAware: true,
			MayUseChatModel:  true,
		},
		EnUSName:        "LLM",
		EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
	},
	NodeTypePlugin: {
		ID:           4,
		Key:          NodeTypePlugin,
		DisplayKey:   "Api",
		Name:         "插件",
		Category:     "",
		Desc:         "通过添加工具访问实时数据和执行外部操作",
		Color:        "#CA61FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
		SupportBatch: true,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Plugin",
		EnUSDescription: "Used to access external real-time data and perform operations",
	},
	NodeTypeCodeRunner: {
		ID:           5,
		Key:          NodeTypeCodeRunner,
		DisplayKey:   "Code",
		Name:         "代码",
		Category:     "logic",
		Desc:         "编写代码,处理输入变量来生成返回值",
		Color:        "#00B2B2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
			UseCtxCache:      true,
		},
		EnUSName:        "Code",
		EnUSDescription: "Write code to process input variables to generate return values.",
	},
	NodeTypeKnowledgeRetriever: {
		ID:           6,
		Key:          NodeTypeKnowledgeRetriever,
		DisplayKey:   "Dataset",
		Name:         "知识库检索",
		Category:     "data",
		Desc:         "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
		Color:        "#FF811A",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Knowledge retrieval",
		EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
	},
	NodeTypeSelector: {
		ID:             8,
		Key:            NodeTypeSelector,
		DisplayKey:     "If",
		Name:           "选择器",
		Category:       "logic",
		Desc:           "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
		Color:          "#00B2B2",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
		SupportBatch:   false,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Condition",
		EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
	},
	NodeTypeSubWorkflow: {
		ID:             9,
		Key:            NodeTypeSubWorkflow,
		DisplayKey:     "SubWorkflow",
		Name:           "工作流",
		Category:       "",
		Desc:           "集成已发布工作流,可以执行嵌套子任务",
		Color:          "#00B83E",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
		SupportBatch:   true,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Workflow",
		EnUSDescription: "Add published workflows to execute subtasks",
	},
	NodeTypeDatabaseCustomSQL: {
		ID:  12,
		Key: NodeTypeDatabaseCustomSQL,
		// NOTE(review): DisplayKey "End" looks like a copy-paste from the Exit node;
		// confirm the intended frontend display key for the custom-SQL node.
		DisplayKey:   "End",
		Name:         "SQL自定义",
		Category:     "database",
		Desc:         "基于用户自定义的 SQL 完成对数据库的增删改查操作",
		Color:        "#FF811A",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "SQL Customization",
		EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
	},
	NodeTypeOutputEmitter: {
		ID:           13,
		Key:          NodeTypeOutputEmitter,
		DisplayKey:   "Message",
		Name:         "输出",
		Category:     "input&output",
		Desc:         "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
		Color:        "#5C62FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PreFillZero:          true,
			InputSourceAware:     true,
			StreamSourceEOFAware: true,
			IncrementalOutput:    true,
		},
		EnUSName:        "Output",
		EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
	},
	NodeTypeTextProcessor: {
		ID:           15,
		Key:          NodeTypeTextProcessor,
		DisplayKey:   "Text",
		Name:         "文本处理",
		Category:     "utilities",
		Desc:         "用于处理多个字符串类型变量的格式",
		Color:        "#3071F2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PreFillZero:      true,
			InputSourceAware: true,
		},
		EnUSName:        "Text Processing",
		EnUSDescription: "The format used for handling multiple string-type variables.",
	},
	NodeTypeQuestionAnswer: {
		ID:           18,
		Key:          NodeTypeQuestionAnswer,
		DisplayKey:   "Question",
		Name:         "问答",
		Category:     "utilities",
		Desc:         "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
		Color:        "#3071F2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
			MayUseChatModel:  true,
		},
		EnUSName:        "Question",
		EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
	},
	NodeTypeBreak: {
		ID:             19,
		Key:            NodeTypeBreak,
		DisplayKey:     "Break",
		Name:           "终止循环",
		Category:       "logic",
		Desc:           "用于立即终止当前所在的循环,跳出循环体",
		Color:          "#00B2B2",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
		SupportBatch:   false,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Break",
		EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
	},
	NodeTypeVariableAssignerWithinLoop: {
		ID:             20,
		Key:            NodeTypeVariableAssignerWithinLoop,
		DisplayKey:     "LoopSetVariable",
		Name:           "设置变量",
		Category:       "logic",
		Desc:           "用于重置循环变量的值,使其下次循环使用重置后的值",
		Color:          "#00B2B2",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
		SupportBatch:   false,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Set Variable",
		EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
	},
	NodeTypeLoop: {
		ID:           21,
		Key:          NodeTypeLoop,
		DisplayKey:   "Loop",
		Name:         "循环",
		Category:     "logic",
		Desc:         "用于通过设定循环次数和逻辑,重复执行一系列任务",
		Color:        "#00B2B2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			IsComposite:      true,
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Loop",
		EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
	},
	NodeTypeIntentDetector: {
		ID:           22,
		Key:          NodeTypeIntentDetector,
		DisplayKey:   "Intent",
		Name:         "意图识别",
		Category:     "logic",
		Desc:         "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
		Color:        "#00B2B2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
			MayUseChatModel:  true,
		},
		EnUSName:        "Intent recognition",
		EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
	},
	NodeTypeKnowledgeIndexer: {
		ID:           27,
		Key:          NodeTypeKnowledgeIndexer,
		DisplayKey:   "DatasetWrite",
		Name:         "知识库写入",
		Category:     "data",
		Desc:         "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
		Color:        "#FF811A",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Knowledge writing",
		EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
	},
	NodeTypeBatch: {
		ID:           28,
		Key:          NodeTypeBatch,
		DisplayKey:   "Batch",
		Name:         "批处理",
		Category:     "logic",
		Desc:         "通过设定批量运行次数和逻辑,运行批处理体内的任务",
		Color:        "#00B2B2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			IsComposite:      true,
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Batch",
		EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
	},
	NodeTypeContinue: {
		ID:             29,
		Key:            NodeTypeContinue,
		DisplayKey:     "Continue",
		Name:           "继续循环",
		Category:       "logic",
		Desc:           "用于终止当前循环,执行下次循环",
		Color:          "#00B2B2",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
		SupportBatch:   false,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Continue",
		EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
	},
	NodeTypeInputReceiver: {
		ID:           30,
		Key:          NodeTypeInputReceiver,
		DisplayKey:   "Input",
		Name:         "输入",
		Category:     "input&output",
		Desc:         "支持中间过程的信息输入",
		Color:        "#5C62FF",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PostFillNil: true,
		},
		EnUSName:        "Input",
		EnUSDescription: "Support intermediate information input",
	},
	NodeTypeComment: {
		ID: 31,
		// Key was previously left empty here; set it to NodeTypeComment so this
		// entry keeps the Key == map-key invariant shared by every other entry.
		Key:          NodeTypeComment,
		Name:         "注释",
		Category:     "",             // Not found in cate_list
		Desc:         "comment_desc", // Placeholder from JSON
		Color:        "",
		IconURL:      "comment_icon", // Placeholder from JSON
		SupportBatch: false,          // supportBatch: 1
		EnUSName:     "Comment",
	},
	NodeTypeVariableAggregator: {
		ID:           32,
		Key:          NodeTypeVariableAggregator,
		Name:         "变量聚合",
		Category:     "logic",
		Desc:         "对多个分支的输出进行聚合处理",
		Color:        "#00B2B2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			PostFillNil:      true,
			InputSourceAware: true,
			UseCtxCache:      true,
		},
		EnUSName:        "Variable Merge",
		EnUSDescription: "Aggregate the outputs of multiple branches.",
	},
	NodeTypeMessageList: {
		ID:           37,
		Key:          NodeTypeMessageList,
		Name:         "查询消息列表",
		Category:     "message",
		Desc:         "用于查询消息列表",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
		SupportBatch: false,
		Disabled:     true,
		ExecutableMeta: ExecutableMeta{
			PreFillZero: true,
			PostFillNil: true,
		},
		EnUSName:        "Query message list",
		EnUSDescription: "Used to query the message list",
	},
	NodeTypeClearMessage: {
		ID:           38,
		Key:          NodeTypeClearMessage,
		Name:         "清除上下文",
		Category:     "conversation_history",
		Desc:         "用于清空会话历史清空后LLM看到的会话历史为空",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
		SupportBatch: false,
		Disabled:     true,
		ExecutableMeta: ExecutableMeta{
			PreFillZero: true,
			PostFillNil: true,
		},
		EnUSName:        "Clear conversation history",
		EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
	},
	NodeTypeCreateConversation: {
		ID:           39,
		Key:          NodeTypeCreateConversation,
		Name:         "创建会话",
		Category:     "conversation_management",
		Desc:         "用于创建会话",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
		SupportBatch: false,
		Disabled:     true,
		ExecutableMeta: ExecutableMeta{
			PreFillZero: true,
			PostFillNil: true,
		},
		EnUSName:        "Create conversation",
		EnUSDescription: "This node is used to create a conversation.",
	},
	NodeTypeVariableAssigner: {
		ID:             40,
		Key:            NodeTypeVariableAssigner,
		DisplayKey:     "AssignVariable",
		Name:           "变量赋值",
		Category:       "data",
		Desc:           "用于给支持写入的变量赋值,包括应用变量、用户变量",
		Color:          "#FF811A",
		IconURL:        "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
		SupportBatch:   false,
		ExecutableMeta: ExecutableMeta{},
		EnUSName:       "Variable assign",
		EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
	},
	NodeTypeDatabaseUpdate: {
		ID:           42,
		Key:          NodeTypeDatabaseUpdate,
		DisplayKey:   "DatabaseUpdate",
		Name:         "更新数据",
		Category:     "database",
		Desc:         "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
		},
		EnUSName:        "Update Data",
		EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
	},
	NodeTypeDatabaseQuery: {
		ID:         43,
		Key:        NodeTypeDatabaseQuery,
		DisplayKey: "DatabaseSelect",
		Name:       "查询数据",
		Category:   "database",
		Desc:       "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据",
		Color:      "#F2B600",
		// NOTE(review): "icaon" in the URL looks like a typo, but the string must
		// match the actually hosted asset path — verify before renaming.
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
		},
		EnUSName:        "Query Data",
		EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
	},
	NodeTypeDatabaseDelete: {
		ID:           44,
		Key:          NodeTypeDatabaseDelete,
		DisplayKey:   "DatabaseDelete",
		Name:         "删除数据",
		Category:     "database",
		Desc:         "从表中删除数据记录,用户指定删除条件来删除符合条件的记录",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
		},
		EnUSName:        "Delete Data",
		EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
	},
	NodeTypeHTTPRequester: {
		ID:           45,
		Key:          NodeTypeHTTPRequester,
		DisplayKey:   "Http",
		Name:         "HTTP 请求",
		Category:     "utilities",
		Desc:         "用于发送API请求从接口返回数据",
		Color:        "#3071F2",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "HTTP request",
		EnUSDescription: "It is used to send API requests and return data from the interface.",
	},
	NodeTypeDatabaseInsert: {
		ID:           46,
		Key:          NodeTypeDatabaseInsert,
		DisplayKey:   "DatabaseInsert",
		Name:         "新增数据",
		Category:     "database",
		Desc:         "向表添加新数据记录,用户输入数据内容后插入数据库",
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
		},
		EnUSName:        "Add Data",
		EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
	},
	NodeTypeJsonSerialization: {
		// ID is the unique identifier of this node type. Used in various front-end APIs.
		ID: 58,
		// Key is the unique NodeType of this node. Used in backend code as well as saved in DB.
		Key: NodeTypeJsonSerialization,
		// DisplayKey is the string used in frontend to identify this node.
		// Example use cases:
		// - used during querying test-run results for nodes
		// - used in returned messages from streaming openAPI Runs.
		// If empty, will use Key as DisplayKey.
		DisplayKey: "ToJSON",
		// Name is the node in ZH_CN, will be displayed on Canvas.
		Name: "JSON 序列化",
		// Category is the category of this node, determines which category this node will be displayed in.
		Category: "utilities",
		// Desc is the desc in ZH_CN, will be displayed as tooltip on Canvas.
		Desc: "用于把变量转化为JSON字符串",
		// Color is the color of the upper edge of the node displayed on Canvas.
		// Fixed: added the missing "#" prefix used by every other entry.
		Color: "#F2B600",
		// IconURL is the URL of the icon displayed on Canvas.
		IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
		// SupportBatch indicates whether this node can set batch mode.
		// NOTE: ultimately it's frontend that decides which node can enable batch mode.
		SupportBatch: false,
		// ExecutableMeta configures certain common aspects of request-time behaviors for this node.
		ExecutableMeta: ExecutableMeta{
			// DefaultTimeoutMS configures the default timeout for this node, in milliseconds. 0 means no timeout.
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			// PreFillZero decides whether to pre-fill zero value for any missing fields in input.
			PreFillZero: true,
			// PostFillNil decides whether to post-fill nil value for any missing fields in output.
			PostFillNil: true,
		},
		// EnUSName is the name in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
		EnUSName: "JSON serialization",
		// EnUSDescription is the description in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
		EnUSDescription: "Convert variable to JSON string",
	},
	NodeTypeJsonDeserialization: {
		ID:         59,
		Key:        NodeTypeJsonDeserialization,
		DisplayKey: "FromJSON",
		Name:       "JSON 反序列化",
		Category:   "utilities",
		Desc:       "用于将JSON字符串解析为变量",
		// Fixed: added the missing "#" prefix used by every other entry.
		Color:        "#F2B600",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
			UseCtxCache:      true,
		},
		EnUSName:        "JSON deserialization",
		EnUSDescription: "Parse JSON string to variable",
	},
	NodeTypeKnowledgeDeleter: {
		ID:           60,
		Key:          NodeTypeKnowledgeDeleter,
		DisplayKey:   "KnowledgeDelete",
		Name:         "知识库删除",
		Category:     "data",
		Desc:         "用于删除知识库中的文档",
		Color:        "#FF811A",
		IconURL:      "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
		SupportBatch: false,
		ExecutableMeta: ExecutableMeta{
			DefaultTimeoutMS: 60 * 1000, // 1 minute
			PreFillZero:      true,
			PostFillNil:      true,
		},
		EnUSName:        "Knowledge delete",
		EnUSDescription: "The delete node can delete a document in knowledge base.",
	},
	NodeTypeLambda: {
		ID:   1000,
		Key:  NodeTypeLambda,
		Name: "Lambda",
		// Fixed: EnUSName was "Comment" (copy-paste error from the Comment entry).
		EnUSName: "Lambda",
	},
}
// PluginNodeMetas holds metadata for specific plugin API entity.
// Declared nil here; presumably populated when plugin definitions are loaded
// from external sources — the loader is not visible in this file, verify at the call site.
var PluginNodeMetas []*PluginNodeMeta

// PluginCategoryMetas holds metadata for plugin category entity.
// Same lifecycle as PluginNodeMetas: starts nil and is filled in elsewhere.
var PluginCategoryMetas []*PluginCategoryMeta
// NodeMetaByNodeType returns the metadata registered for node type t,
// or nil when no metadata is registered for t.
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
	// Indexing a map with a missing key yields the zero value of the element
	// type, which for *NodeTypeMeta is nil — exactly the desired fallback.
	return NodeTypeMetas[t]
}

View File

@ -1,867 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
import (
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
// Categories enumerates the node categories used to group node types.
// Key is the stable identifier referenced by NodeTypeMeta.Category;
// Name and EnUSName are the ZH_CN and EN_US display labels respectively.
var Categories = []Category{
	{
		Key:      "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
		Name:     "",
		EnUSName: "",
	},
	{
		Key:      "logic",
		Name:     "业务逻辑",
		EnUSName: "Logic",
	},
	{
		Key:      "input&output",
		Name:     "输入&输出",
		EnUSName: "Input&Output",
	},
	{
		Key:      "database",
		Name:     "数据库",
		EnUSName: "Database",
	},
	{
		Key:      "data",
		Name:     "知识库&数据",
		EnUSName: "Data",
	},
	{
		Key:      "image",
		Name:     "图像处理",
		EnUSName: "Image",
	},
	{
		Key:      "audio&video",
		Name:     "音视频处理",
		EnUSName: "Audio&Video",
	},
	{
		Key:      "utilities",
		Name:     "组件",
		EnUSName: "Utilities",
	},
	{
		Key:      "conversation_management",
		Name:     "会话管理",
		EnUSName: "Conversation management",
	},
	{
		Key:      "conversation_history",
		Name:     "会话历史",
		EnUSName: "Conversation history",
	},
	{
		Key:      "message",
		Name:     "消息",
		EnUSName: "Message",
	},
}
// NodeTypeMetas holds the metadata for all available node types.
// It is initialized with built-in types and potentially extended by loading from external sources.
var NodeTypeMetas = []*NodeTypeMeta{
{
ID: 1,
Name: "开始",
Type: NodeTypeEntry,
Category: "input&output", // Mapped from cate_list
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Start",
EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
},
{
ID: 2,
Name: "结束",
Type: NodeTypeExit,
Category: "input&output", // Mapped from cate_list
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "End",
EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
},
{
ID: 3,
Name: "大模型",
Type: NodeTypeLLM,
Category: "", // Mapped from cate_list
Desc: "调用大语言模型,使用变量和提示词生成回复",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
InputSourceAware: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
},
EnUSName: "LLM",
EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
},
{
ID: 4,
Name: "插件",
Type: NodeTypePlugin,
Category: "", // Mapped from cate_list
Desc: "通过添加工具访问实时数据和执行外部操作",
Color: "#CA61FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Plugin",
EnUSDescription: "Used to access external real-time data and perform operations",
},
{
ID: 5,
Name: "代码",
Type: NodeTypeCodeRunner,
Category: "logic", // Mapped from cate_list
Desc: "编写代码,处理输入变量来生成返回值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Code",
EnUSDescription: "Write code to process input variables to generate return values.",
},
{
ID: 6,
Name: "知识库检索",
Type: NodeTypeKnowledgeRetriever,
Category: "data", // Mapped from cate_list
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge retrieval",
EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
},
{
ID: 8,
Name: "选择器",
Type: NodeTypeSelector,
Category: "logic", // Mapped from cate_list
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Condition",
EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
},
{
ID: 9,
Name: "工作流",
Type: NodeTypeSubWorkflow,
Category: "", // Mapped from cate_list
Desc: "集成已发布工作流,可以执行嵌套子任务",
Color: "#00B83E",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
SupportBatch: true, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Workflow",
EnUSDescription: "Add published workflows to execute subtasks",
},
{
ID: 12,
Name: "SQL自定义",
Type: NodeTypeDatabaseCustomSQL,
Category: "database", // Mapped from cate_list
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
SupportBatch: false, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "SQL Customization",
EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
},
{
ID: 13,
Name: "输出",
Type: NodeTypeOutputEmitter,
Category: "input&output", // Mapped from cate_list
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
StreamSourceEOFAware: true,
IncrementalOutput: true,
},
EnUSName: "Output",
EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
},
{
ID: 15,
Name: "文本处理",
Type: NodeTypeTextProcessor,
Category: "utilities", // Mapped from cate_list
Desc: "用于处理多个字符串类型变量的格式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
SupportBatch: false, // supportBatch: 2
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Text Processing",
EnUSDescription: "The format used for handling multiple string-type variables.",
},
{
ID: 18,
Name: "问答",
Type: NodeTypeQuestionAnswer,
Category: "utilities", // Mapped from cate_list
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Question",
EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
},
{
ID: 19,
Name: "终止循环",
Type: NodeTypeBreak,
Category: "logic", // Mapped from cate_list
Desc: "用于立即终止当前所在的循环,跳出循环体",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Break",
EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
},
{
ID: 20,
Name: "设置变量",
Type: NodeTypeVariableAssignerWithinLoop,
Category: "logic", // Mapped from cate_list
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Set Variable",
EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
},
{
ID: 21,
Name: "循环",
Type: NodeTypeLoop,
Category: "logic", // Mapped from cate_list
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Loop",
EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
},
{
ID: 22,
Name: "意图识别",
Type: NodeTypeIntentDetector,
Category: "logic", // Mapped from cate_list
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
MayUseChatModel: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Intent recognition",
EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
},
{
ID: 27,
Name: "知识库写入",
Type: NodeTypeKnowledgeIndexer,
Category: "data", // Mapped from cate_list
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge writing",
EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
},
{
ID: 28,
Name: "批处理",
Type: NodeTypeBatch,
Category: "logic", // Mapped from cate_list
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
SupportBatch: false, // supportBatch: 1 (Corrected from previous assumption)
ExecutableMeta: ExecutableMeta{
IsComposite: true,
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Batch",
EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
},
{
ID: 29,
Name: "继续循环",
Type: NodeTypeContinue,
Category: "logic", // Mapped from cate_list
Desc: "用于终止当前循环,执行下次循环",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Continue",
EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
},
{
ID: 30,
Name: "输入",
Type: NodeTypeInputReceiver,
Category: "input&output", // Mapped from cate_list
Desc: "支持中间过程的信息输入",
Color: "#5C62FF",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Input",
EnUSDescription: "Support intermediate information input",
},
{
ID: 31,
Name: "注释",
Type: "",
Category: "", // Not found in cate_list
Desc: "comment_desc", // Placeholder from JSON
Color: "",
IconURL: "comment_icon", // Placeholder from JSON
SupportBatch: false, // supportBatch: 1
EnUSName: "Comment",
},
{
ID: 32,
Name: "变量聚合",
Type: NodeTypeVariableAggregator,
Category: "logic", // Mapped from cate_list
Desc: "对多个分支的输出进行聚合处理",
Color: "#00B2B2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
PostFillNil: true,
CallbackEnabled: true,
InputSourceAware: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
},
EnUSName: "Variable Merge",
EnUSDescription: "Aggregate the outputs of multiple branches.",
},
{
ID: 37,
Name: "查询消息列表",
Type: NodeTypeMessageList,
Category: "message", // Mapped from cate_list
Desc: "用于查询消息列表",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Query message list",
EnUSDescription: "Used to query the message list",
},
{
ID: 38,
Name: "清除上下文",
Type: NodeTypeClearMessage,
Category: "conversation_history", // Mapped from cate_list
Desc: "用于清空会话历史清空后LLM看到的会话历史为空",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Clear conversation history",
EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
},
{
ID: 39,
Name: "创建会话",
Type: NodeTypeCreateConversation,
Category: "conversation_management", // Mapped from cate_list
Desc: "用于创建会话",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
SupportBatch: false, // supportBatch: 1
Disabled: true,
ExecutableMeta: ExecutableMeta{
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Create conversation",
EnUSDescription: "This node is used to create a conversation.",
},
{
ID: 40,
Name: "变量赋值",
Type: NodeTypeVariableAssigner,
Category: "data", // Mapped from cate_list
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Variable assign",
EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
},
{
ID: 42,
Name: "更新数据",
Type: NodeTypeDatabaseUpdate,
Category: "database", // Mapped from cate_list
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Update Data",
EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
},
{
ID: 43,
Name: "查询数据", // Corrected Name from JSON (was "insert data")
Type: NodeTypeDatabaseQuery,
Category: "database", // Mapped from cate_list
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Query Data",
EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
},
{
ID: 44,
Name: "删除数据",
Type: NodeTypeDatabaseDelete,
Category: "database", // Mapped from cate_list
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Delete Data",
EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
},
{
ID: 45,
Name: "HTTP 请求",
Type: NodeTypeHTTPRequester,
Category: "utilities", // Mapped from cate_list
Desc: "用于发送API请求从接口返回数据", // Corrected Desc from JSON
Color: "#3071F2",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "HTTP request",
EnUSDescription: "It is used to send API requests and return data from the interface.",
},
{
ID: 46,
Name: "新增数据", // Corrected Name from JSON (was "Query Data")
Type: NodeTypeDatabaseInsert,
Category: "database", // Mapped from cate_list
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库", // Corrected Desc from JSON
Color: "#F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg", // Corrected Icon URL from JSON
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Add Data",
EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
},
{
ID: 58,
Name: "JSON 序列化",
Type: NodeTypeJsonSerialization,
Category: "utilities",
Desc: "用于把变量转化为JSON字符串",
Color: "F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "JSON serialization",
EnUSDescription: "Convert variable to JSON string",
},
{
ID: 59,
Name: "JSON 反序列化",
Type: NodeTypeJsonDeserialization,
Category: "utilities",
Desc: "用于将JSON字符串解析为变量",
Color: "F2B600",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
SupportBatch: false,
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
CallbackEnabled: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "JSON deserialization",
EnUSDescription: "Parse JSON string to variable",
},
{
ID: 60,
Name: "知识库删除",
Type: NodeTypeKnowledgeDeleter,
Category: "data", // Mapped from cate_list
Desc: "用于删除知识库中的文档",
Color: "#FF811A",
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
SupportBatch: false, // supportBatch: 1
ExecutableMeta: ExecutableMeta{
DefaultTimeoutMS: 60 * 1000, // 1 minute
PreFillZero: true,
PostFillNil: true,
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
},
EnUSName: "Knowledge delete",
EnUSDescription: "The delete node can delete a document in knowledge base.",
},
// --- End of nodes parsed from template_list ---
}
// PluginNodeMetas holds metadata for specific plugin API entities.
// NOTE(review): package-level mutable state, populated elsewhere at runtime —
// confirm it is initialized before any reader touches it.
var PluginNodeMetas []*PluginNodeMeta
// PluginCategoryMetas holds metadata for plugin category entities.
// NOTE(review): same initialization caveat as PluginNodeMetas above.
var PluginCategoryMetas []*PluginCategoryMeta
// NodeMetaByNodeType returns the entry in NodeTypeMetas whose Type field
// matches t, or nil when no registered node type meta matches.
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
	for i := range NodeTypeMetas {
		if m := NodeTypeMetas[i]; m.Type == t {
			return m
		}
	}
	return nil
}
// defaultZhCNInitCanvasJsonSchema is the Simplified-Chinese canvas template used
// to seed a newly created workflow: a Start node (fixed ID "100001", type "1")
// and an End node (fixed ID "900001", type "2", terminatePlan "returnVariables"),
// with no edges yet and loop schema version "v2". The backquoted string is
// runtime data consumed by the frontend — do not reformat or translate it.
const defaultZhCNInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "结束"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
// defaultEnUSInitCanvasJsonSchema is the English canvas template used to seed a
// newly created workflow: a Start node (fixed ID "100001", type "1") and an End
// node (fixed ID "900001", type "2", terminatePlan "returnVariables"), with no
// edges yet and loop schema version "v2". The backquoted string is runtime data
// consumed by the frontend — do not reformat it.
const defaultEnUSInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "Start"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "End"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
// GetDefaultInitCanvasJsonSchema returns the canvas JSON template used to seed
// a brand-new workflow, localized by locale: the English template for
// i18n.LocaleEN, the Simplified-Chinese template for everything else.
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
	if locale == i18n.LocaleEN {
		return defaultEnUSInitCanvasJsonSchema
	}
	return defaultZhCNInitCanvasJsonSchema
}

View File

@ -19,24 +19,48 @@ package vo
import (
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
)
// Canvas is the definition of FRONTEND schema for a workflow.
type Canvas struct {
Nodes []*Node `json:"nodes"`
Edges []*Edge `json:"edges"`
Versions any `json:"versions"`
}
// Node represents a node within a workflow canvas.
type Node struct {
ID string `json:"id"`
Type BlockType `json:"type"`
Meta any `json:"meta"`
Data *Data `json:"data"`
Blocks []*Node `json:"blocks,omitempty"`
Edges []*Edge `json:"edges,omitempty"`
Version string `json:"version,omitempty"`
// ID is the unique node ID within the workflow.
// In normal use cases, this ID is generated by frontend.
// It does NOT need to be unique between parent workflow and sub workflows.
// The Entry node and Exit node of a workflow always have fixed node IDs: 100001 and 900001.
ID string `json:"id"`
parent *Node
// Type is the Node Type of this node instance.
// It corresponds to the string value of 'ID' field of NodeMeta.
Type string `json:"type"`
// Meta holds meta data used by frontend, such as the node's position within canvas.
Meta any `json:"meta"`
// Data holds the actual configurations of a node, such as inputs, outputs and exception handling.
// It also holds exclusive configurations for different node types, such as LLM configurations.
Data *Data `json:"data"`
// Blocks holds the sub nodes of this node.
// It is only used when the node type is Composite, such as NodeTypeBatch and NodeTypeLoop.
Blocks []*Node `json:"blocks,omitempty"`
// Edges are the connections between nodes.
// Strictly corresponds to connections drawn on canvas.
Edges []*Edge `json:"edges,omitempty"`
// Version is the version of this node type's schema.
Version string `json:"version,omitempty"`
parent *Node // if this node is within a composite node, coze will set this. No need to set manually
}
func (n *Node) SetParent(parent *Node) {
@ -47,7 +71,7 @@ func (n *Node) Parent() *Node {
return n.parent
}
type NodeMeta struct {
type NodeMetaFE struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"`
@ -62,52 +86,88 @@ type Edge struct {
TargetPortID string `json:"targetPortID,omitempty"`
}
// Data holds the actual configuration of a Node.
type Data struct {
Meta *NodeMeta `json:"nodeMeta,omitempty"`
Outputs []any `json:"outputs,omitempty"` // either []*Variable or []*Param
Inputs *Inputs `json:"inputs,omitempty"`
Size any `json:"size,omitempty"`
// Meta is the meta data of this node. Only used by frontend.
Meta *NodeMetaFE `json:"nodeMeta,omitempty"`
// Outputs configures the output fields and their types.
// Outputs can be either []*Variable (most of the cases, just fields and types),
// or []*Param (used by composite nodes as they need to refer to outputs of sub nodes)
Outputs []any `json:"outputs,omitempty"`
// Inputs configures ALL input information of a node,
// including fixed input fields and dynamic input fields defined by user.
Inputs *Inputs `json:"inputs,omitempty"`
// Size configures the size of this node in frontend.
// Only used by NodeTypeComment.
Size any `json:"size,omitempty"`
}
type Inputs struct {
InputParameters []*Param `json:"inputParameters"`
Content *BlockInput `json:"content"`
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
StreamingOutput bool `json:"streamingOutput,omitempty"`
CallTransferVoice bool `json:"callTransferVoice,omitempty"`
ChatHistoryWriting string `json:"chatHistoryWriting,omitempty"`
LLMParam any `json:"llmParam,omitempty"` // The LLMParam type may be one of the LLMParam or IntentDetectorLLMParam type or QALLMParam type
FCParam *FCParam `json:"fcParam,omitempty"`
SettingOnError *SettingOnError `json:"settingOnError,omitempty"`
// InputParameters are the fields defined by user for this particular node.
InputParameters []*Param `json:"inputParameters"`
// SettingOnError configures common error handling strategy for nodes.
// NOTE: enable in frontend node's form first.
SettingOnError *SettingOnError `json:"settingOnError,omitempty"`
// NodeBatchInfo configures batch mode for nodes.
// NOTE: not to be confused with NodeTypeBatch.
NodeBatchInfo *NodeBatch `json:"batch,omitempty"`
// LLMParam may be one of the LLMParam or IntentDetectorLLMParam or SimpleLLMParam.
// Shared between most nodes requiring an ChatModel to function.
LLMParam any `json:"llmParam,omitempty"`
*OutputEmitter // exclusive configurations for NodeTypeEmitter and NodeTypeExit in Answer mode
*Exit // exclusive configurations for NodeTypeExit
*LLM // exclusive configurations for NodeTypeLLM
*Loop // exclusive configurations for NodeTypeLoop
*Selector // exclusive configurations for NodeTypeSelector
*TextProcessor // exclusive configurations for NodeTypeTextProcessor
*SubWorkflow // exclusive configurations for NodeTypeSubWorkflow
*IntentDetector // exclusive configurations for NodeTypeIntentDetector
*DatabaseNode // exclusive configurations for various Database nodes
*HttpRequestNode // exclusive configurations for NodeTypeHTTPRequester
*Knowledge // exclusive configurations for various Knowledge nodes
*CodeRunner // exclusive configurations for NodeTypeCodeRunner
*PluginAPIParam // exclusive configurations for NodeTypePlugin
*VariableAggregator // exclusive configurations for NodeTypeVariableAggregator
*VariableAssigner // exclusive configurations for NodeTypeVariableAssigner
*QA // exclusive configurations for NodeTypeQuestionAnswer
*Batch // exclusive configurations for NodeTypeBatch
*Comment // exclusive configurations for NodeTypeComment
*InputReceiver // exclusive configurations for NodeTypeInputReceiver
}
type OutputEmitter struct {
Content *BlockInput `json:"content"`
StreamingOutput bool `json:"streamingOutput,omitempty"`
}
type Exit struct {
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
}
type LLM struct {
FCParam *FCParam `json:"fcParam,omitempty"`
}
type Loop struct {
LoopType LoopType `json:"loopType,omitempty"`
LoopCount *BlockInput `json:"loopCount,omitempty"`
VariableParameters []*Param `json:"variableParameters,omitempty"`
}
type Selector struct {
Branches []*struct {
Condition struct {
Logic LogicType `json:"logic"`
Conditions []*Condition `json:"conditions"`
} `json:"condition"`
} `json:"branches,omitempty"`
NodeBatchInfo *NodeBatch `json:"batch,omitempty"` // node in batch mode
*TextProcessor
*SubWorkflow
*IntentDetector
*DatabaseNode
*HttpRequestNode
*KnowledgeIndexer
*CodeRunner
*PluginAPIParam
*VariableAggregator
*VariableAssigner
*QA
*Batch
*Comment
OutputSchema string `json:"outputSchema,omitempty"`
}
type Comment struct {
@ -127,7 +187,7 @@ type VariableAssigner struct {
type LLMParam = []*Param
type IntentDetectorLLMParam = map[string]any
type QALLMParam struct {
type SimpleLLMParam struct {
GenerationDiversity string `json:"generationDiversity"`
MaxTokens int `json:"maxTokens"`
ModelName string `json:"modelName"`
@ -248,7 +308,7 @@ type CodeRunner struct {
Language int64 `json:"language"`
}
type KnowledgeIndexer struct {
type Knowledge struct {
DatasetParam []*Param `json:"datasetParam,omitempty"`
StrategyParam StrategyParam `json:"strategyParam,omitempty"`
}
@ -384,23 +444,54 @@ type ChatHistorySetting struct {
type Intent struct {
Name string `json:"name"`
}
// Param is a node's field with type and source info.
type Param struct {
Name string `json:"name,omitempty"`
Input *BlockInput `json:"input,omitempty"`
Left *BlockInput `json:"left,omitempty"`
Right *BlockInput `json:"right,omitempty"`
// Name is the field's name.
Name string `json:"name,omitempty"`
// Input is the configurations for normal, singular field.
Input *BlockInput `json:"input,omitempty"`
// Left is the configurations for the left half of an expression,
// such as an assignment in NodeTypeVariableAssigner.
Left *BlockInput `json:"left,omitempty"`
// Right is the configuration for the right half of an expression.
Right *BlockInput `json:"right,omitempty"`
// Variables are configurations for a group of fields.
// Only used in NodeTypeVariableAggregator.
Variables []*BlockInput `json:"variables,omitempty"`
}
// Variable is the configuration of a node's field, either input or output.
type Variable struct {
Name string `json:"name"`
Type VariableType `json:"type"`
Required bool `json:"required,omitempty"`
AssistType AssistType `json:"assistType,omitempty"`
Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list)
Description string `json:"description,omitempty"`
ReadOnly bool `json:"readOnly,omitempty"`
DefaultValue any `json:"defaultValue,omitempty"`
// Name is the field's name as defined on canvas.
Name string `json:"name"`
// Type is the field's data type, such as string, integer, number, object, array, etc.
Type VariableType `json:"type"`
// Required is set to true if you checked the 'required box' on this field
Required bool `json:"required,omitempty"`
// AssistType is the 'secondary' type of string fields, such as different types of file and image, or time.
AssistType AssistType `json:"assistType,omitempty"`
// Schema contains detailed info for sub-fields of an object field, or element type of an array.
Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list)
// Description describes the field's intended use. Used on Entry node. Useful for workflow tools.
Description string `json:"description,omitempty"`
// ReadOnly indicates a field is not to be set by Node's business logic.
// e.g. the ErrorBody field when exception strategy is configured.
ReadOnly bool `json:"readOnly,omitempty"`
// DefaultValue configures the 'default value' if this field is missing in input.
// Effective only in Entry node.
DefaultValue any `json:"defaultValue,omitempty"`
}
type BlockInput struct {
@ -436,48 +527,6 @@ type SubWorkflow struct {
SpaceID string `json:"spaceId,omitempty"`
}
// BlockType is the enumeration of node types for front-end canvas schema.
// To add a new BlockType, start from a really big number such as 1000, to avoid conflict with future extensions.
type BlockType string
func (b BlockType) String() string {
return string(b)
}
const (
BlockTypeBotStart BlockType = "1"
BlockTypeBotEnd BlockType = "2"
BlockTypeBotLLM BlockType = "3"
BlockTypeBotAPI BlockType = "4"
BlockTypeBotCode BlockType = "5"
BlockTypeBotDataset BlockType = "6"
BlockTypeCondition BlockType = "8"
BlockTypeBotSubWorkflow BlockType = "9"
BlockTypeDatabase BlockType = "12"
BlockTypeBotMessage BlockType = "13"
BlockTypeBotText BlockType = "15"
BlockTypeQuestion BlockType = "18"
BlockTypeBotBreak BlockType = "19"
BlockTypeBotLoopSetVariable BlockType = "20"
BlockTypeBotLoop BlockType = "21"
BlockTypeBotIntent BlockType = "22"
BlockTypeBotDatasetWrite BlockType = "27"
BlockTypeBotInput BlockType = "30"
BlockTypeBotBatch BlockType = "28"
BlockTypeBotContinue BlockType = "29"
BlockTypeBotComment BlockType = "31"
BlockTypeBotVariableMerge BlockType = "32"
BlockTypeBotAssignVariable BlockType = "40"
BlockTypeDatabaseUpdate BlockType = "42"
BlockTypeDatabaseSelect BlockType = "43"
BlockTypeDatabaseDelete BlockType = "44"
BlockTypeBotHttp BlockType = "45"
BlockTypeDatabaseInsert BlockType = "46"
BlockTypeJsonSerialization BlockType = "58"
BlockTypeJsonDeserialization BlockType = "59"
BlockTypeBotDatasetDelete BlockType = "60"
)
type VariableType string
const (
@ -536,19 +585,31 @@ const (
type ErrorProcessType int
const (
ErrorProcessTypeThrow ErrorProcessType = 1
ErrorProcessTypeDefault ErrorProcessType = 2
ErrorProcessTypeExceptionBranch ErrorProcessType = 3
ErrorProcessTypeThrow ErrorProcessType = 1 // throws the error as usual
ErrorProcessTypeReturnDefaultData ErrorProcessType = 2 // return DataOnErr configured in SettingOnError
ErrorProcessTypeExceptionBranch ErrorProcessType = 3 // executes the exception branch on error
)
// SettingOnError contains common error handling strategy.
type SettingOnError struct {
DataOnErr string `json:"dataOnErr,omitempty"`
Switch bool `json:"switch,omitempty"`
// DataOnErr defines the JSON result to be returned on error.
DataOnErr string `json:"dataOnErr,omitempty"`
// Switch defines whether ANY error handling strategy is active.
// If set to false, it's equivalent to set ProcessType = ErrorProcessTypeThrow
Switch bool `json:"switch,omitempty"`
// ProcessType determines the error handling strategy for this node.
ProcessType *ErrorProcessType `json:"processType,omitempty"`
RetryTimes int64 `json:"retryTimes,omitempty"`
TimeoutMs int64 `json:"timeoutMs,omitempty"`
Ext *struct {
BackupLLMParam string `json:"backupLLMParam,omitempty"` // only for LLM Node, marshaled from QALLMParam
// RetryTimes determines how many times to retry. 0 means no retry.
// If positive, any retries will be executed immediately after error.
RetryTimes int64 `json:"retryTimes,omitempty"`
// TimeoutMs sets the timeout duration in millisecond.
// If any retry happens, ALL retry attempts accumulates to the same timeout threshold.
TimeoutMs int64 `json:"timeoutMs,omitempty"`
// Ext sets any extra settings specific to NodeType
Ext *struct {
// BackupLLMParam is only for LLM Node, marshaled from SimpleLLMParam.
// If retry happens, the backup LLM will be used instead of the main LLM.
BackupLLMParam string `json:"backupLLMParam,omitempty"`
} `json:"ext,omitempty"`
}
@ -597,32 +658,8 @@ const (
LoopTypeInfinite LoopType = "infinite"
)
type WorkflowIdentity struct {
ID string `json:"id"`
Version string `json:"version"`
}
func (c *Canvas) GetAllSubWorkflowIdentities() []*WorkflowIdentity {
workflowEntities := make([]*WorkflowIdentity, 0)
var collectSubWorkFlowEntities func(nodes []*Node)
collectSubWorkFlowEntities = func(nodes []*Node) {
for _, n := range nodes {
if n.Type == BlockTypeBotSubWorkflow {
workflowEntities = append(workflowEntities, &WorkflowIdentity{
ID: n.Data.Inputs.WorkflowID,
Version: n.Data.Inputs.WorkflowVersion,
})
}
if len(n.Blocks) > 0 {
collectSubWorkFlowEntities(n.Blocks)
}
}
}
collectSubWorkFlowEntities(c.Nodes)
return workflowEntities
type InputReceiver struct {
OutputSchema string `json:"outputSchema,omitempty"`
}
func GenerateNodeIDForBatchMode(key string) string {
@ -632,3 +669,163 @@ func GenerateNodeIDForBatchMode(key string) string {
func IsGeneratedNodeForBatchMode(key string, parentKey string) bool {
return key == GenerateNodeIDForBatchMode(parentKey)
}
const defaultZhCNInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "开始"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "结束"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
const defaultEnUSInitCanvasJsonSchema = `{
"nodes": [
{
"id": "100001",
"type": "1",
"meta": {
"position": {
"x": 0,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
"subTitle": "",
"title": "Start"
},
"outputs": [
{
"type": "string",
"name": "input",
"required": false
}
],
"trigger_parameters": [
{
"type": "string",
"name": "input",
"required": false
}
]
}
},
{
"id": "900001",
"type": "2",
"meta": {
"position": {
"x": 1000,
"y": 0
}
},
"data": {
"nodeMeta": {
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
"subTitle": "",
"title": "End"
},
"inputs": {
"terminatePlan": "returnVariables",
"inputParameters": [
{
"name": "output",
"input": {
"type": "string",
"value": {
"type": "ref",
"content": {
"source": "block-output",
"blockID": "",
"name": ""
}
}
}
}
]
}
}
}
],
"edges": [],
"versions": {
"loop": "v2"
}
}`
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema)
}

View File

@ -47,12 +47,6 @@ type FieldSource struct {
Val any `json:"val,omitempty"`
}
type ImplicitNodeDependency struct {
NodeID string
FieldPath compose.FieldPath
TypeInfo *TypeInfo
}
type TypeInfo struct {
Type DataType `json:"type"`
ElemTypeInfo *TypeInfo `json:"elem_type_info,omitempty"`

View File

@ -69,13 +69,6 @@ type IDVersionPair struct {
Version string
}
type Stage uint8
const (
StageDraft Stage = 1
StagePublished Stage = 2
)
type WorkflowBasic struct {
ID int64
Version string

View File

@ -58,6 +58,11 @@ import (
"github.com/coze-dev/coze-studio/backend/types/consts"
)
func TestMain(m *testing.M) {
RegisterAllNodeAdaptors()
m.Run()
}
func TestIntentDetectorAndDatabase(t *testing.T) {
mockey.PatchConvey("intent detector & database custom sql", t, func() {
data, err := os.ReadFile("../examples/intent_detector_database_custom_sql.json")

View File

@ -26,10 +26,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
*compose.WorkflowSchema, error) {
*schema.WorkflowSchema, error) {
var (
n *vo.Node
nodeFinder func(nodes []*vo.Node) *vo.Node
@ -62,35 +63,27 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
n = batchN
}
implicitDependencies, err := extractImplicitDependency(n, c.Nodes)
if err != nil {
return nil, err
}
opts := make([]OptionFn, 0, 1)
if len(implicitDependencies) > 0 {
opts = append(opts, WithImplicitNodeDependencies(implicitDependencies))
}
nsList, hierarchy, err := NodeToNodeSchema(ctx, n, opts...)
nsList, hierarchy, err := NodeToNodeSchema(ctx, n, c)
if err != nil {
return nil, err
}
var (
ns *compose.NodeSchema
innerNodes map[vo.NodeKey]*compose.NodeSchema // inner nodes of the composite node if nodeKey is composite
connections []*compose.Connection
ns *schema.NodeSchema
innerNodes map[vo.NodeKey]*schema.NodeSchema // inner nodes of the composite node if nodeKey is composite
connections []*schema.Connection
)
if len(nsList) == 1 {
ns = nsList[0]
} else {
innerNodes = make(map[vo.NodeKey]*compose.NodeSchema)
innerNodes = make(map[vo.NodeKey]*schema.NodeSchema)
for i := range nsList {
one := nsList[i]
if _, ok := hierarchy[one.Key]; ok {
innerNodes[one.Key] = one
if one.Type == entity.NodeTypeContinue || one.Type == entity.NodeTypeBreak {
connections = append(connections, &compose.Connection{
connections = append(connections, &schema.Connection{
FromNode: one.Key,
ToNode: vo.NodeKey(nodeID),
})
@ -106,13 +99,13 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
}
const inputFillerKey = "input_filler"
connections = append(connections, &compose.Connection{
connections = append(connections, &schema.Connection{
FromNode: einoCompose.START,
ToNode: inputFillerKey,
}, &compose.Connection{
}, &schema.Connection{
FromNode: inputFillerKey,
ToNode: ns.Key,
}, &compose.Connection{
}, &schema.Connection{
FromNode: ns.Key,
ToNode: einoCompose.END,
})
@ -209,7 +202,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
return newOutput, nil
}
inputFiller := &compose.NodeSchema{
inputFiller := &schema.NodeSchema{
Key: inputFillerKey,
Type: entity.NodeTypeLambda,
Lambda: einoCompose.InvokableLambda(i),
@ -227,10 +220,16 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
OutputTypes: startOutputTypes,
}
trimmedSC := &compose.WorkflowSchema{
Nodes: append([]*compose.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...),
branches, err := schema.BuildBranches(connections)
if err != nil {
return nil, err
}
trimmedSC := &schema.WorkflowSchema{
Nodes: append([]*schema.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...),
Connections: connections,
Hierarchy: hierarchy,
Branches: branches,
}
if enabled {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,663 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package convert
import (
"fmt"
"strconv"
"strings"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// CanvasVariableToTypeInfo converts a canvas Variable definition into the
// internal TypeInfo representation, recursing into object properties and
// array element schemas.
func CanvasVariableToTypeInfo(v *vo.Variable) (*vo.TypeInfo, error) {
	result := &vo.TypeInfo{
		Required: v.Required,
		Desc:     v.Description,
	}

	switch v.Type {
	case vo.VariableTypeString:
		if v.AssistType == vo.AssistTypeTime {
			result.Type = vo.DataTypeTime
			break
		}
		if v.AssistType == vo.AssistTypeNotSet {
			result.Type = vo.DataTypeString
			break
		}
		// Any other assist type marks the string as a file of some sub-type.
		fileType, ok := assistTypeToFileType(v.AssistType)
		if !ok {
			return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
		}
		result.Type = vo.DataTypeFile
		result.FileType = &fileType
	case vo.VariableTypeInteger:
		result.Type = vo.DataTypeInteger
	case vo.VariableTypeFloat:
		result.Type = vo.DataTypeNumber
	case vo.VariableTypeBoolean:
		result.Type = vo.DataTypeBoolean
	case vo.VariableTypeObject:
		result.Type = vo.DataTypeObject
		result.Properties = make(map[string]*vo.TypeInfo)
		if v.Schema == nil {
			break
		}
		for _, rawChild := range v.Schema.([]any) {
			child, err := vo.ParseVariable(rawChild)
			if err != nil {
				return nil, err
			}
			childInfo, err := CanvasVariableToTypeInfo(child)
			if err != nil {
				return nil, err
			}
			result.Properties[child.Name] = childInfo
		}
	case vo.VariableTypeList:
		result.Type = vo.DataTypeArray
		elem, err := vo.ParseVariable(v.Schema)
		if err != nil {
			return nil, err
		}
		elemInfo, err := CanvasVariableToTypeInfo(elem)
		if err != nil {
			return nil, err
		}
		result.ElemTypeInfo = elemInfo
	default:
		return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
	}

	return result, nil
}
// CanvasBlockInputToTypeInfo converts a canvas BlockInput into the internal
// TypeInfo representation, recursing into object properties and array element
// schemas. Any error is wrapped with ErrSchemaConversionFail.
func CanvasBlockInputToTypeInfo(b *vo.BlockInput) (tInfo *vo.TypeInfo, err error) {
	defer func() {
		if err != nil {
			err = vo.WrapIfNeeded(errno.ErrSchemaConversionFail, err)
		}
	}()

	// A nil input yields an empty TypeInfo rather than an error.
	tInfo = &vo.TypeInfo{}
	if b == nil {
		return tInfo, nil
	}
	switch b.Type {
	case vo.VariableTypeString:
		switch b.AssistType {
		case vo.AssistTypeTime:
			tInfo.Type = vo.DataTypeTime
		case vo.AssistTypeNotSet:
			tInfo.Type = vo.DataTypeString
		default:
			// Any other assist type marks the string as a file of some sub-type.
			fileType, ok := assistTypeToFileType(b.AssistType)
			if ok {
				tInfo.Type = vo.DataTypeFile
				tInfo.FileType = &fileType
			} else {
				return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
			}
		}
	case vo.VariableTypeInteger:
		tInfo.Type = vo.DataTypeInteger
	case vo.VariableTypeFloat:
		tInfo.Type = vo.DataTypeNumber
	case vo.VariableTypeBoolean:
		tInfo.Type = vo.DataTypeBoolean
	case vo.VariableTypeObject:
		tInfo.Type = vo.DataTypeObject
		tInfo.Properties = make(map[string]*vo.TypeInfo)
		if b.Schema != nil {
			// The schema's entry shape depends on the value kind: refs carry
			// Variable entries, object-refs carry Param entries. Entries of any
			// other value kind are silently skipped here.
			for _, subVAny := range b.Schema.([]any) {
				if b.Value.Type == vo.BlockInputValueTypeRef {
					subV, err := vo.ParseVariable(subVAny)
					if err != nil {
						return nil, err
					}
					subTInfo, err := CanvasVariableToTypeInfo(subV)
					if err != nil {
						return nil, err
					}
					tInfo.Properties[subV.Name] = subTInfo
				} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
					subV, err := parseParam(subVAny)
					if err != nil {
						return nil, err
					}
					subTInfo, err := CanvasBlockInputToTypeInfo(subV.Input)
					if err != nil {
						return nil, err
					}
					tInfo.Properties[subV.Name] = subTInfo
				}
			}
		}
	case vo.VariableTypeList:
		tInfo.Type = vo.DataTypeArray
		subVAny := b.Schema
		subV, err := vo.ParseVariable(subVAny)
		if err != nil {
			return nil, err
		}
		subTInfo, err := CanvasVariableToTypeInfo(subV)
		if err != nil {
			return nil, err
		}
		tInfo.ElemTypeInfo = subTInfo
	default:
		return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
	}
	return tInfo, nil
}
// CanvasBlockInputToFieldInfo converts a canvas BlockInput into field mapping
// information rooted at path. Object refs recurse into their schema params;
// literals are coerced into the Go type declared by b.Type; refs are resolved
// to a FieldSource. parentNode, when non-nil, is used to detect references to
// a composite (loop) parent's intermediate variables.
func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) (sources []*vo.FieldInfo, err error) {
	value := b.Value
	if value == nil {
		return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type)
	}

	switch value.Type {
	case vo.BlockInputValueTypeObjectRef:
		return objectRefToFieldInfo(b, path, parentNode)
	case vo.BlockInputValueTypeLiteral:
		return literalToFieldInfo(b, path)
	case vo.BlockInputValueTypeRef:
		return refToFieldInfo(b, path, parentNode)
	default:
		return nil, fmt.Errorf("unsupported value type: %s for blockInput type= %s", value.Type, b.Type)
	}
}

// objectRefToFieldInfo flattens an object-ref input: every param in the schema
// becomes its own field mapping nested under path.
func objectRefToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) ([]*vo.FieldInfo, error) {
	sc := b.Schema
	if sc == nil {
		return nil, fmt.Errorf("input %v has no schema, type= %s", path, b.Type)
	}
	paramList, ok := sc.([]any)
	if !ok {
		return nil, fmt.Errorf("input %v schema not []any, type= %T", path, sc)
	}

	var sources []*vo.FieldInfo
	for _, paramAny := range paramList {
		param, err := parseParam(paramAny)
		if err != nil {
			return nil, err
		}
		// Copy the path so sibling params don't append into a shared backing array.
		copied := make(einoCompose.FieldPath, len(path), len(path)+1)
		copy(copied, path)
		subFieldInfo, err := CanvasBlockInputToFieldInfo(param.Input, append(copied, param.Name), parentNode)
		if err != nil {
			return nil, err
		}
		sources = append(sources, subFieldInfo...)
	}
	return sources, nil
}

// literalToFieldInfo coerces a literal value into the Go type declared by
// b.Type and wraps it as a constant (Val) field source. Objects and lists
// arrive JSON-encoded in a string; numbers and booleans may arrive as strings
// or as JSON-decoded values.
func literalToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath) ([]*vo.FieldInfo, error) {
	content := b.Value.Content
	if content == nil {
		return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
	}

	var err error
	switch b.Type {
	case vo.VariableTypeObject:
		m := make(map[string]any)
		if err = sonic.UnmarshalString(content.(string), &m); err != nil {
			return nil, err
		}
		content = m
	case vo.VariableTypeList:
		l := make([]any, 0)
		if err = sonic.UnmarshalString(content.(string), &l); err != nil {
			return nil, err
		}
		content = l
	case vo.VariableTypeInteger:
		switch c := content.(type) {
		case string:
			if content, err = strconv.ParseInt(c, 10, 64); err != nil {
				return nil, err
			}
		case int64:
			// already the declared type
		case float64:
			content = int64(c)
		default:
			return nil, fmt.Errorf("unsupported variable type for integer: %s", b.Type)
		}
	case vo.VariableTypeFloat:
		switch c := content.(type) {
		case string:
			if content, err = strconv.ParseFloat(c, 64); err != nil {
				return nil, err
			}
		case int64:
			content = float64(c)
		case float64:
			// already the declared type
		default:
			return nil, fmt.Errorf("unsupported variable type for float: %s", b.Type)
		}
	case vo.VariableTypeBoolean:
		switch c := content.(type) {
		case string:
			if content, err = strconv.ParseBool(c); err != nil {
				return nil, err
			}
		case bool:
			// already the declared type
		default:
			return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type)
		}
	default:
		// Strings (and file-typed strings) pass through unchanged.
	}

	return []*vo.FieldInfo{
		{
			Path: path,
			Source: vo.FieldSource{
				Val: content,
			},
		},
	}, nil
}

// refToFieldInfo resolves a reference value to the node output or global
// variable it points at.
func refToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) ([]*vo.FieldInfo, error) {
	content := b.Value.Content
	if content == nil {
		return nil, fmt.Errorf("input %v is ref but has no value, type= %s", path, b.Type)
	}
	ref, err := parseBlockInputRef(content)
	if err != nil {
		return nil, err
	}
	fieldSource, err := CanvasBlockInputRefToFieldSource(ref)
	if err != nil {
		return nil, err
	}
	// A reference back to the enclosing loop node's own variable is rewritten
	// as a parent-intermediate variable reference instead of a node-output one.
	if parentNode != nil {
		if fieldSource.Ref != nil && len(fieldSource.Ref.FromNodeKey) > 0 && fieldSource.Ref.FromNodeKey == vo.NodeKey(parentNode.ID) {
			varRoot := fieldSource.Ref.FromPath[0]
			if parentNode.Data.Inputs.Loop != nil {
				for _, p := range parentNode.Data.Inputs.VariableParameters {
					if p.Name == varRoot {
						fieldSource.Ref.FromNodeKey = ""
						pi := vo.ParentIntermediate
						fieldSource.Ref.VariableType = &pi
					}
				}
			}
		}
	}
	return []*vo.FieldInfo{
		{
			Path:   path,
			Source: *fieldSource,
		},
	}, nil
}
// parseBlockInputRef normalizes content into a *vo.BlockInputReference,
// accepting either an already-typed value or a generic JSON object map.
func parseBlockInputRef(content any) (*vo.BlockInputReference, error) {
	switch typed := content.(type) {
	case *vo.BlockInputReference:
		return typed, nil
	case map[string]any:
		// Round-trip through JSON to map loosely-typed canvas data onto the
		// strongly-typed reference struct.
		raw, err := sonic.Marshal(typed)
		if err != nil {
			return nil, err
		}
		ref := &vo.BlockInputReference{}
		if err = sonic.Unmarshal(raw, ref); err != nil {
			return nil, err
		}
		return ref, nil
	default:
		return nil, fmt.Errorf("invalid content type: %T when parse BlockInputRef", content)
	}
}
// parseParam normalizes v into a *vo.Param, accepting either an already-typed
// value or a generic JSON object map.
func parseParam(v any) (*vo.Param, error) {
	switch typed := v.(type) {
	case *vo.Param:
		return typed, nil
	case map[string]any:
		// Round-trip through JSON to map loosely-typed canvas data onto the
		// strongly-typed param struct.
		raw, err := sonic.Marshal(typed)
		if err != nil {
			return nil, err
		}
		param := &vo.Param{}
		if err = sonic.Unmarshal(raw, param); err != nil {
			return nil, err
		}
		return param, nil
	default:
		return nil, fmt.Errorf("invalid content type: %T when parse Param", v)
	}
}
// CanvasBlockInputRefToFieldSource maps a canvas reference onto a FieldSource:
// either a reference to another node's output, or a reference to a global
// (app/system/user) variable.
func CanvasBlockInputRefToFieldSource(r *vo.BlockInputReference) (*vo.FieldSource, error) {
	if r.Source == vo.RefSourceTypeBlockOutput {
		if len(r.BlockID) == 0 {
			return nil, fmt.Errorf("invalid BlockInputReference = %+v, BlockID is empty when source is block output", r)
		}
		// An empty r.Name signals an all-to-all mapping.
		fromPath := strings.Split(r.Name, ".")
		return &vo.FieldSource{
			Ref: &vo.Reference{
				FromNodeKey: vo.NodeKey(r.BlockID),
				FromPath:    fromPath,
			},
		}, nil
	}

	var gvType vo.GlobalVarType
	switch r.Source {
	case vo.RefSourceTypeGlobalApp:
		gvType = vo.GlobalAPP
	case vo.RefSourceTypeGlobalSystem:
		gvType = vo.GlobalSystem
	case vo.RefSourceTypeGlobalUser:
		gvType = vo.GlobalUser
	default:
		return nil, fmt.Errorf("unsupported ref source type: %s", r.Source)
	}

	if len(r.Path) == 0 {
		return nil, fmt.Errorf("invalid BlockInputReference = %+v, Path is empty when source is variables", r)
	}
	return &vo.FieldSource{
		Ref: &vo.Reference{
			VariableType: &gvType,
			FromPath:     r.Path,
		},
	}, nil
}
// assistTypeToFileType maps a canvas assist type to the corresponding file
// sub-type. The boolean reports whether the assist type denotes a file at all:
// AssistTypeNotSet and AssistTypeTime are not files, and unknown assist types
// also report false so callers can surface an "unsupported assist type" error
// instead of crashing.
func assistTypeToFileType(a vo.AssistType) (vo.FileSubType, bool) {
	switch a {
	case vo.AssistTypeNotSet, vo.AssistTypeTime:
		return "", false
	case vo.AssistTypeImage:
		return vo.FileTypeImage, true
	case vo.AssistTypeAudio:
		return vo.FileTypeAudio, true
	case vo.AssistTypeVideo:
		return vo.FileTypeVideo, true
	case vo.AssistTypeDefault:
		return vo.FileTypeDefault, true
	case vo.AssistTypeDoc:
		return vo.FileTypeDocument, true
	case vo.AssistTypeExcel:
		return vo.FileTypeExcel, true
	case vo.AssistTypeCode:
		return vo.FileTypeCode, true
	case vo.AssistTypePPT:
		return vo.FileTypePPT, true
	case vo.AssistTypeTXT:
		return vo.FileTypeTxt, true
	case vo.AssistTypeSvg:
		return vo.FileTypeSVG, true
	case vo.AssistTypeVoice:
		return vo.FileTypeVoice, true
	case vo.AssistTypeZip:
		return vo.FileTypeZip, true
	default:
		// Was panic("impossible"): an unknown assist type in user-supplied
		// canvas JSON must not crash the process; degrade to "not a file" so
		// the callers' error paths handle it.
		return "", false
	}
}
// SetInputsForNodeSchema copies a node's input parameters onto ns: each
// parameter contributes a declared input type plus one or more field sources
// describing where its value comes from.
func SetInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
	inputs := n.Data.Inputs
	if inputs == nil || len(inputs.InputParameters) == 0 {
		return nil
	}

	for _, param := range inputs.InputParameters {
		typeInfo, err := CanvasBlockInputToTypeInfo(param.Input)
		if err != nil {
			return err
		}
		ns.SetInputType(param.Name, typeInfo)

		fieldInfos, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{param.Name}, n.Parent())
		if err != nil {
			return err
		}
		ns.AddInputSource(fieldInfos...)
	}
	return nil
}
// SetOutputTypesForNodeSchema declares each canvas output variable on ns as an
// output type. The reserved read-only "errorBody" field (populated when an
// exception occurs) is not declared.
func SetOutputTypesForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
	for _, raw := range n.Data.Outputs {
		variable, err := vo.ParseVariable(raw)
		if err != nil {
			return err
		}
		typeInfo, err := CanvasVariableToTypeInfo(variable)
		if err != nil {
			return err
		}
		if variable.ReadOnly && variable.Name == "errorBody" {
			// Reserved output field for exceptions; skip it.
			continue
		}
		ns.SetOutputType(variable.Name, typeInfo)
	}
	return nil
}
// SetOutputsForNodeSchema copies a node's output params onto ns, declaring
// both the output type and the field sources that produce each value.
func SetOutputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
	for _, raw := range n.Data.Outputs {
		param, err := parseParam(raw)
		if err != nil {
			return err
		}

		typeInfo, convErr := CanvasBlockInputToTypeInfo(param.Input)
		if convErr != nil {
			return convErr
		}
		ns.SetOutputType(param.Name, typeInfo)

		fieldInfos, fieldErr := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{param.Name}, n.Parent())
		if fieldErr != nil {
			return fieldErr
		}
		ns.AddOutputSource(fieldInfos...)
	}
	return nil
}
// BlockInputToNamedTypeInfo converts a canvas BlockInput into a NamedTypeInfo
// carrying the given name, recursing into object properties and array element
// schemas.
func BlockInputToNamedTypeInfo(name string, b *vo.BlockInput) (*vo.NamedTypeInfo, error) {
	tInfo := &vo.NamedTypeInfo{
		Name: name,
	}
	// A nil input yields a NamedTypeInfo with only the name set.
	if b == nil {
		return tInfo, nil
	}
	switch b.Type {
	case vo.VariableTypeString:
		switch b.AssistType {
		case vo.AssistTypeTime:
			tInfo.Type = vo.DataTypeTime
		case vo.AssistTypeNotSet:
			tInfo.Type = vo.DataTypeString
		default:
			// Any other assist type marks the string as a file of some sub-type.
			fileType, ok := assistTypeToFileType(b.AssistType)
			if ok {
				tInfo.Type = vo.DataTypeFile
				tInfo.FileType = &fileType
			} else {
				return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
			}
		}
	case vo.VariableTypeInteger:
		tInfo.Type = vo.DataTypeInteger
	case vo.VariableTypeFloat:
		tInfo.Type = vo.DataTypeNumber
	case vo.VariableTypeBoolean:
		tInfo.Type = vo.DataTypeBoolean
	case vo.VariableTypeObject:
		tInfo.Type = vo.DataTypeObject
		if b.Schema != nil {
			tInfo.Properties = make([]*vo.NamedTypeInfo, 0, len(b.Schema.([]any)))
			// The schema's entry shape depends on the value kind: refs carry
			// Variable entries, object-refs carry Param entries. Entries of any
			// other value kind are silently skipped here.
			for _, subVAny := range b.Schema.([]any) {
				if b.Value.Type == vo.BlockInputValueTypeRef {
					subV, err := vo.ParseVariable(subVAny)
					if err != nil {
						return nil, err
					}
					subNInfo, err := VariableToNamedTypeInfo(subV)
					if err != nil {
						return nil, err
					}
					tInfo.Properties = append(tInfo.Properties, subNInfo)
				} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
					subV, err := parseParam(subVAny)
					if err != nil {
						return nil, err
					}
					subNInfo, err := BlockInputToNamedTypeInfo(subV.Name, subV.Input)
					if err != nil {
						return nil, err
					}
					tInfo.Properties = append(tInfo.Properties, subNInfo)
				}
			}
		}
	case vo.VariableTypeList:
		tInfo.Type = vo.DataTypeArray
		subVAny := b.Schema
		subV, err := vo.ParseVariable(subVAny)
		if err != nil {
			return nil, err
		}
		subNInfo, err := VariableToNamedTypeInfo(subV)
		if err != nil {
			return nil, err
		}
		tInfo.ElemTypeInfo = subNInfo
	default:
		return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
	}
	return tInfo, nil
}
// VariableToNamedTypeInfo converts a canvas Variable into a NamedTypeInfo,
// recursing into object properties and array element schemas.
func VariableToNamedTypeInfo(v *vo.Variable) (*vo.NamedTypeInfo, error) {
	nInfo := &vo.NamedTypeInfo{
		Required: v.Required,
		Name:     v.Name,
		Desc:     v.Description,
	}
	switch v.Type {
	case vo.VariableTypeString:
		switch v.AssistType {
		case vo.AssistTypeTime:
			nInfo.Type = vo.DataTypeTime
		case vo.AssistTypeNotSet:
			nInfo.Type = vo.DataTypeString
		default:
			// Any other assist type marks the string as a file of some sub-type.
			fileType, ok := assistTypeToFileType(v.AssistType)
			if ok {
				nInfo.Type = vo.DataTypeFile
				nInfo.FileType = &fileType
			} else {
				return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
			}
		}
	case vo.VariableTypeInteger:
		nInfo.Type = vo.DataTypeInteger
	case vo.VariableTypeFloat:
		nInfo.Type = vo.DataTypeNumber
	case vo.VariableTypeBoolean:
		nInfo.Type = vo.DataTypeBoolean
	case vo.VariableTypeObject:
		nInfo.Type = vo.DataTypeObject
		if v.Schema != nil {
			nInfo.Properties = make([]*vo.NamedTypeInfo, 0)
			for _, subVAny := range v.Schema.([]any) {
				subV, err := vo.ParseVariable(subVAny)
				if err != nil {
					return nil, err
				}
				subTInfo, err := VariableToNamedTypeInfo(subV)
				if err != nil {
					return nil, err
				}
				nInfo.Properties = append(nInfo.Properties, subTInfo)
			}
		}
	case vo.VariableTypeList:
		nInfo.Type = vo.DataTypeArray
		subVAny := v.Schema
		subV, err := vo.ParseVariable(subVAny)
		if err != nil {
			return nil, err
		}
		subTInfo, err := VariableToNamedTypeInfo(subV)
		if err != nil {
			return nil, err
		}
		nInfo.ElemTypeInfo = subTInfo
	default:
		return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
	}
	return nInfo, nil
}

View File

@ -24,8 +24,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -123,7 +126,7 @@ func (cv *CanvasValidator) ValidateConnections(ctx context.Context) (issues []*I
return issues, nil
}
func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issue, err error) {
func (cv *CanvasValidator) CheckRefVariable(_ context.Context) (issues []*Issue, err error) {
issues = make([]*Issue, 0)
var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error
checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error {
@ -237,7 +240,7 @@ func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issu
return issues, nil
}
func (cv *CanvasValidator) ValidateNestedFlows(ctx context.Context) (issues []*Issue, err error) {
func (cv *CanvasValidator) ValidateNestedFlows(_ context.Context) (issues []*Issue, err error) {
issues = make([]*Issue, 0)
for nodeID, node := range cv.reachability.reachableNodes {
if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 {
@ -265,13 +268,13 @@ func (cv *CanvasValidator) CheckGlobalVariables(ctx context.Context) (issues []*
nVars := make([]*nodeVars, 0)
for _, node := range cv.cfg.Canvas.Nodes {
if node.Type == vo.BlockTypeBotComment {
if node.Type == entity.NodeTypeComment.IDStr() {
continue
}
if node.Type == vo.BlockTypeBotAssignVariable {
if node.Type == entity.NodeTypeVariableAssigner.IDStr() {
v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)}
for _, p := range node.Data.Inputs.InputParameters {
v.vars[p.Name], err = adaptor.CanvasBlockInputToTypeInfo(p.Left)
v.vars[p.Name], err = convert.CanvasBlockInputToTypeInfo(p.Left)
if err != nil {
return nil, err
}
@ -338,7 +341,7 @@ func (cv *CanvasValidator) CheckSubWorkFlowTerminatePlanType(ctx context.Context
var collectSubWorkFlowNodes func(nodes []*vo.Node)
collectSubWorkFlowNodes = func(nodes []*vo.Node) {
for _, n := range nodes {
if n.Type == vo.BlockTypeBotSubWorkflow {
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
subWfMap = append(subWfMap, n)
wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
if err != nil {
@ -465,62 +468,28 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
selectorPorts := make(map[string]map[string]bool)
for nodeID, node := range nodeMap {
switch node.Type {
case vo.BlockTypeCondition:
branches := node.Data.Inputs.Branches
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
node.Data.Inputs.SettingOnError.ProcessType != nil &&
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
selectorPorts[nodeID]["false"] = true
for index := range branches {
if index == 0 {
selectorPorts[nodeID]["true"] = true
} else {
selectorPorts[nodeID][fmt.Sprintf("true_%v", index)] = true
}
}
case vo.BlockTypeBotIntent:
intents := node.Data.Inputs.Intents
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
for index := range intents {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
}
selectorPorts[nodeID]["default"] = true
if node.Data.Inputs.SettingOnError != nil && node.Data.Inputs.SettingOnError.ProcessType != nil &&
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
selectorPorts[nodeID]["branch_error"] = true
}
case vo.BlockTypeQuestion:
if node.Data.Inputs.QA.AnswerType == vo.QAAnswerTypeOption {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
for index := range node.Data.Inputs.QA.Options {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
}
}
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
selectorPorts[nodeID][fmt.Sprintf("branch_%v", 0)] = true
}
}
default:
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
node.Data.Inputs.SettingOnError.ProcessType != nil &&
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
selectorPorts[nodeID]["branch_error"] = true
selectorPorts[nodeID]["default"] = true
} else {
outDegree[node.ID] = 0
}
selectorPorts[nodeID][schema.PortBranchError] = true
selectorPorts[nodeID][schema.PortDefault] = true
}
ba, ok := nodes.GetBranchAdaptor(entity.IDStrToNodeType(node.Type))
if ok {
expects := ba.ExpectPorts(ctx, node)
if len(expects) > 0 {
if _, exists := selectorPorts[nodeID]; !exists {
selectorPorts[nodeID] = make(map[string]bool)
}
for _, e := range expects {
selectorPorts[nodeID][e] = true
}
}
}
}
for _, edge := range c.Edges {
@ -544,8 +513,8 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
for nodeID, node := range nodeMap {
nodeName := node.Data.Meta.Title
switch node.Type {
case vo.BlockTypeBotStart:
switch et := entity.IDStrToNodeType(node.Type); et {
case entity.NodeTypeEntry:
if outDegree[nodeID] == 0 {
issues = append(issues, &Issue{
NodeErr: &NodeErr{
@ -555,13 +524,9 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
Message: `node "start" not connected`,
})
}
case vo.BlockTypeBotEnd:
case entity.NodeTypeExit:
default:
if ports, isSelector := selectorPorts[nodeID]; isSelector {
selectorIssues := &Issue{NodeErr: &NodeErr{
NodeID: node.ID,
NodeName: nodeName,
}}
message := ""
for port := range ports {
if portOutDegree[nodeID][port] == 0 {
@ -569,12 +534,15 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
}
}
if len(message) > 0 {
selectorIssues.Message = message
selectorIssues := &Issue{NodeErr: &NodeErr{
NodeID: node.ID,
NodeName: nodeName,
}, Message: message}
issues = append(issues, selectorIssues)
}
} else {
// Break, continue without checking out degrees
if node.Type == vo.BlockTypeBotBreak || node.Type == vo.BlockTypeBotContinue {
if et == entity.NodeTypeBreak || et == entity.NodeTypeContinue {
continue
}
if outDegree[nodeID] == 0 {
@ -585,7 +553,6 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
},
Message: fmt.Sprintf(`node "%v" not connected`, nodeName),
})
}
}
}
@ -602,7 +569,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
return nil, err
}
startNode, endNode, err := findStartAndEndNodes(c.Nodes)
startNode, _, err := findStartAndEndNodes(c.Nodes)
if err != nil {
return nil, err
}
@ -612,7 +579,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID)
}
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode, endNode)
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode)
if err != nil {
return nil, err
}
@ -635,12 +602,12 @@ func processNestedReachability(c *vo.Canvas, r *reachability) error {
Nodes: append([]*vo.Node{
{
ID: node.ID,
Type: vo.BlockTypeBotStart,
Type: entity.NodeTypeEntry.IDStr(),
Data: node.Data,
},
{
ID: node.ID,
Type: vo.BlockTypeBotEnd,
Type: entity.NodeTypeExit.IDStr(),
},
}, node.Blocks...),
Edges: node.Edges,
@ -663,9 +630,9 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
for _, node := range nodes {
switch node.Type {
case vo.BlockTypeBotStart:
case entity.NodeTypeEntry.IDStr():
startNode = node
case vo.BlockTypeBotEnd:
case entity.NodeTypeExit.IDStr():
endNode = node
}
}
@ -680,7 +647,7 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
return startNode, endNode, nil
}
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node, endNode *vo.Node) (map[string]*vo.Node, error) {
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node) (map[string]*vo.Node, error) {
result := make(map[string]*vo.Node)
result[startNode.ID] = startNode

View File

@ -1,181 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
// OutputPortCount reports how many normal output ports this node exposes and
// whether an extra exception port is present (i.e. the node's error handling
// is configured as an exception branch).
func (s *NodeSchema) OutputPortCount() (int, bool) {
	var hasExceptionPort bool
	if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
		*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
		hasExceptionPort = true
	}
	switch s.Type {
	case entity.NodeTypeSelector:
		// One port per condition clause plus the default branch.
		return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
	case entity.NodeTypeQuestionAnswer:
		if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
			if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
				// One port per fixed choice plus the "other" branch.
				return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
			} else {
				// Dynamic choices: one shared branch plus the "other" branch.
				return 2, hasExceptionPort
			}
		}
		return 1, hasExceptionPort
	case entity.NodeTypeIntentDetector:
		// One port per intent plus the default branch.
		intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
		return len(intents) + 1, hasExceptionPort
	default:
		return 1, hasExceptionPort
	}
}
// BranchMapping describes which downstream nodes each of a node's output
// ports leads to.
type BranchMapping struct {
	Normal    []map[string]bool // successor node sets, one per normal port, in port order
	Exception map[string]bool   // successor nodes of the exception port, if any
}

// Port naming convention used when wiring branches.
const (
	DefaultBranch = "default"
	BranchFmt     = "branch_%d"
)
// GetBranch builds the eino graph branch that routes this node's output to its
// successors, according to bMapping. Routing depends on the node type:
// selectors route by chosen clause index, QA nodes by selected option,
// intent detectors by classification id, and all other nodes by success vs.
// exception.
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
	if bMapping == nil {
		return nil, errors.New("no branch mapping")
	}

	// Union of every possible successor node across all ports.
	endNodes := make(map[string]bool)
	for i := range bMapping.Normal {
		for k := range bMapping.Normal[i] {
			endNodes[k] = true
		}
	}
	if bMapping.Exception != nil {
		for k := range bMapping.Exception {
			endNodes[k] = true
		}
	}

	switch s.Type {
	case entity.NodeTypeSelector:
		// The selector emits the chosen clause index under SelectKey.
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			choice := in[selector.SelectKey].(int)
			if choice < 0 || choice > len(bMapping.Normal) {
				return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
			}
			choices := make(map[string]bool, len((bMapping.Normal)[choice]))
			for k := range (bMapping.Normal)[choice] {
				choices[k] = true
			}
			return choices, nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	case entity.NodeTypeQuestionAnswer:
		conf := s.Configs.(map[string]any)
		if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
			choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
			condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
				optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
				if !ok {
					return nil, fmt.Errorf("failed to take option id from input map: %v", in)
				}
				// The "other" option always maps to the last branch.
				if optionID.(string) == "other" {
					return (bMapping.Normal)[len(bMapping.Normal)-1], nil
				}
				if choiceType == qa.DynamicChoices { // all dynamic choices map to branch 0
					return (bMapping.Normal)[0], nil
				}
				// Fixed choices are labeled A, B, C, ... and map to branch indices.
				optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
				if !ok {
					return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
				}
				if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
					return nil, fmt.Errorf("failed to take option id from input map: %v", in)
				}
				return (bMapping.Normal)[optionIDInt], nil
			}
			return compose.NewGraphMultiBranch(condition, endNodes), nil
		}
		return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
	case entity.NodeTypeIntentDetector:
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			// isSuccess == false routes to the exception branch.
			isSuccess, ok := in["isSuccess"]
			if ok && isSuccess != nil && !isSuccess.(bool) {
				return bMapping.Exception, nil
			}
			classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
			if !ok {
				return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
			}
			// The intent detector's default branch uses classificationId=0, but
			// this implementation stores the default branch as the LAST element
			// of the array. So classificationId=0 maps to the last index, and
			// every other id is shifted down by 1.
			id, err := cast.ToInt64E(classificationId)
			if err != nil {
				return nil, err
			}
			realID := id - 1
			if realID >= int64(len(bMapping.Normal)) {
				return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
			}
			if realID < 0 {
				realID = int64(len(bMapping.Normal)) - 1
			}
			return (bMapping.Normal)[realID], nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	default:
		// Any other node only branches between success (port 0) and exception.
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			isSuccess, ok := in["isSuccess"]
			if ok && isSuccess != nil && !isSuccess.(bool) {
				return bMapping.Exception, nil
			}
			return (bMapping.Normal)[0], nil
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	}
}

View File

@ -1,194 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
)
// selectorCallbackField is one operand of a selector condition as reported to
// callbacks: a human-readable key, the operand's declared type, and its
// runtime value.
type selectorCallbackField struct {
	Key   string      `json:"key"`
	Type  vo.DataType `json:"type"`
	Value any         `json:"value"`
}
// selectorCondition is a single "Left <Operator> Right" comparison within a
// selector branch. Right is optional and omitted when the condition carries
// no right-hand operand.
type selectorCondition struct {
	Left     selectorCallbackField  `json:"left"`
	Operator vo.OperatorType        `json:"operator"`
	Right    *selectorCallbackField `json:"right,omitempty"`
}
// selectorBranch groups the conditions of one selector clause together with
// the logic (AND/OR) used to combine them and the branch's name.
type selectorBranch struct {
	Conditions []*selectorCondition `json:"conditions"`
	Logic      vo.LogicType         `json:"logic"`
	Name       string               `json:"name"`
}
// toSelectorCallbackInput builds the pre-processor that converts a selector
// node's raw input map into the structured "branches" payload reported to
// callbacks: one selectorBranch per configured clause.
//
// Input source paths are expected in one of two shapes:
//   - [branchIndex, left|right]              : a single-condition clause
//   - [branchIndex, clauseIndex, left|right] : one condition of a multi-condition clause
//
// Left-hand operands are given a display key resolved from their source
// reference ("<node name>.<field path>"); right-hand operands carry no key.
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
	return func(_ context.Context, in map[string]any) (map[string]any, error) {
		config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
		count := len(config)
		output := make([]*selectorBranch, count)
		for _, source := range s.InputSources {
			targetPath := source.Path
			if len(targetPath) == 2 {
				// Single-condition clause: path is [branch index, left/right key].
				indexStr := targetPath[0]
				index, err := strconv.Atoi(indexStr)
				if err != nil {
					return nil, err
				}
				// Lazily create the branch the first time its index is seen.
				branch := output[index]
				if branch == nil {
					output[index] = &selectorBranch{
						Conditions: []*selectorCondition{
							{
								Operator: config[index].Single.ToCanvasOperatorType(),
							},
						},
						Logic: selector.ClauseRelationAND.ToVOLogicType(),
					}
				}
				if targetPath[1] == selector.LeftKey {
					leftV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take left value of %s", targetPath)
					}
					if source.Source.Ref.VariableType != nil {
						if *source.Source.Ref.VariableType == vo.ParentIntermediate {
							// Value comes from the parent composite node's intermediate
							// variables: qualify the key with the parent node's name.
							parentNodeKey, ok := sc.Hierarchy[s.Key]
							if !ok {
								return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
							}
							parentNode := sc.GetNode(parentNodeKey)
							output[index].Conditions[0].Left = selectorCallbackField{
								Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
								Value: leftV,
							}
						} else {
							// Other variable types carry no displayable key.
							output[index].Conditions[0].Left = selectorCallbackField{
								Key:   "",
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
								Value: leftV,
							}
						}
					} else {
						// Plain node-output reference: key is "<source node name>.<path>".
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					}
				} else if targetPath[1] == selector.RightKey {
					rightV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take right value of %s", targetPath)
					}
					output[index].Conditions[0].Right = &selectorCallbackField{
						Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
						Value: rightV,
					}
				}
			} else if len(targetPath) == 3 {
				// Multi-condition clause: path is [branch index, clause index, left/right key].
				indexStr := targetPath[0]
				index, err := strconv.Atoi(indexStr)
				if err != nil {
					return nil, err
				}
				multi := config[index].Multi
				// Lazily create the branch with one condition slot per clause.
				branch := output[index]
				if branch == nil {
					output[index] = &selectorBranch{
						Conditions: make([]*selectorCondition, len(multi.Clauses)),
						Logic:      multi.Relation.ToVOLogicType(),
					}
				}
				clauseIndexStr := targetPath[1]
				clauseIndex, err := strconv.Atoi(clauseIndexStr)
				if err != nil {
					return nil, err
				}
				clause := multi.Clauses[clauseIndex]
				if output[index].Conditions[clauseIndex] == nil {
					output[index].Conditions[clauseIndex] = &selectorCondition{
						Operator: clause.ToCanvasOperatorType(),
					}
				}
				if targetPath[2] == selector.LeftKey {
					leftV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take left value of %s", targetPath)
					}
					if source.Source.Ref.VariableType != nil {
						if *source.Source.Ref.VariableType == vo.ParentIntermediate {
							// Same parent-intermediate resolution as the 2-segment case,
							// but the type lives one nesting level deeper.
							parentNodeKey, ok := sc.Hierarchy[s.Key]
							if !ok {
								return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
							}
							parentNode := sc.GetNode(parentNodeKey)
							output[index].Conditions[clauseIndex].Left = selectorCallbackField{
								Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
								Value: leftV,
							}
						} else {
							output[index].Conditions[clauseIndex].Left = selectorCallbackField{
								Key:   "",
								Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
								Value: leftV,
							}
						}
					} else {
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					}
				} else if targetPath[2] == selector.RightKey {
					rightV, ok := nodes.TakeMapValue(in, targetPath)
					if !ok {
						return nil, fmt.Errorf("failed to take right value of %s", targetPath)
					}
					output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
						Type:  s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
						Value: rightV,
					}
				}
			}
		}
		return map[string]any{"branches": output}, nil
	}
}

View File

@ -31,7 +31,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
rootHandler := execute.NewRootWorkflowHandler(
wb,
executeID,
workflowSC.requireCheckPoint,
workflowSC.RequireCheckpoint(),
eventChan,
resumedEvent,
exeCfg,
@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
var nodeOpt einoCompose.Option
if ns.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
ptr.Of(ns.Configs.(*exit.Config).TerminatePlan))
} else if ns.Type != entity.NodeTypeLambda {
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
}
@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
}
}
if workflowSC.requireCheckPoint {
if workflowSC.RequireCheckpoint() {
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
}
@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
parentHandler *execute.WorkflowHandler,
ns *NodeSchema,
ns *schema2.NodeSchema,
pathPrefix ...string) (opts []einoCompose.Option, err error) {
var (
resumeEvent = r.interruptEvent
@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
var nodeOpt einoCompose.Option
if subNS.Type == entity.NodeTypeExit {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan))
} else {
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
}
@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
return opts, nil
}
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
sw *schema.StreamWriter[*entity.Message]) (
opts []einoCompose.Option, err error) {
// this is a LLM node.
@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
}
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
cfg := ns.Configs.(*llm.Config)
fcParams := cfg.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
// TODO: try to avoid getting the workflow tool all over again
@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
}
@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
opt := einoCompose.WithCallbacks(toolHandler)
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
opts = append(opts, opt)
}
}

View File

@ -25,12 +25,13 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// outputValueFiller will fill the output value with nil if the key is not present in the output map.
// If a node emits a stream as output, it must handle these absent keys in the stream itself.
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) {
if len(s.OutputTypes) == 0 {
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
return output, nil
@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) {
if len(s.InputTypes) == 0 {
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
return input, nil
@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri
}
}
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context,
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
newI := make(map[string]any)

View File

@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNodeSchema_OutputValueFiller(t *testing.T) {
@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &NodeSchema{
s := &schema.NodeSchema{
OutputTypes: tt.fields.Outputs,
}
got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
got, err := outputValueFiller(s)(context.Background(), tt.fields.In)
if len(tt.wantErr) > 0 {
assert.Error(t, err)

View File

@ -0,0 +1,118 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// Node is the runtime representation of a workflow node: an eino Lambda
// wrapping the node's implementation.
type Node struct {
	Lambda *compose.Lambda
}
// New instantiates the actual node type from NodeSchema.
//
// Parameters:
//   - s: the schema describing the node to build.
//   - inner: the inner workflow runnable, set only for composite nodes that
//     wrap a sub-graph.
//   - sc: the workflow schema this node belongs to.
//   - deps: dependency info for this node, pre-calculated by the workflow engine.
//
// Any panic raised during building is recovered, and every error is wrapped
// with ErrCreateNodeFail plus the node's name before being returned.
func New(ctx context.Context, s *schema.NodeSchema,
	inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node
	sc *schema.WorkflowSchema, // the workflow this NodeSchema is in
	deps *dependencyInfo, // the dependency for this node pre-calculated by workflow engine
) (_ *Node, err error) {
	defer func() {
		if panicErr := recover(); panicErr != nil {
			err = safego.NewPanicErr(panicErr, debug.Stack())
		}
		if err != nil {
			err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
		}
	}()

	// Input-source-aware nodes need the fully resolved source info
	// (where each input field ultimately comes from) before building.
	var fullSources map[string]*schema.SourceInfo
	if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
		if fullSources, err = GetFullSources(s, sc, deps); err != nil {
			return nil, err
		}
		s.FullSources = fullSources
	}

	// if NodeSchema's Configs implements NodeBuilder, will use it to build the node
	if nb, ok := s.Configs.(schema.NodeBuilder); ok {
		opts := []schema.BuildOption{
			schema.WithWorkflowSchema(sc),
			schema.WithInnerWorkflow(inner),
		}

		// build the actual InvokableNode, etc.
		n, err := nb.Build(ctx, s, opts...)
		if err != nil {
			return nil, err
		}

		// wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda
		return toNode(s, n), nil
	}

	switch s.Type {
	case entity.NodeTypeLambda:
		if s.Lambda == nil {
			return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
		}
		return &Node{Lambda: s.Lambda}, nil
	case entity.NodeTypeSubWorkflow:
		subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint())
		if err != nil {
			return nil, err
		}
		return toNode(s, subWorkflow), nil
	default:
		// Return an error instead of panicking: the deferred recover above would
		// only convert the panic back into an error anyway, so fail directly
		// with a clear message (don't panic in library code).
		return nil, fmt.Errorf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type)
	}
}
// buildSubWorkflow compiles the sub-workflow referenced by s into a
// subworkflow.SubWorkflow node implementation.
//
// requireCheckpoint is propagated from the parent workflow so the
// sub-workflow participates in checkpoint/resume when the parent does.
func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) {
	var opts []WorkflowOption
	// Use the referenced workflow's ID as its name.
	opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID))
	if requireCheckpoint {
		opts = append(opts, WithParentRequireCheckpoint())
	}

	// Renamed the local from 's' to 'staticCfg': the original shadowed the
	// NodeSchema parameter 's', which is error-prone to maintain.
	if staticCfg := execute.GetStaticConfig(); staticCfg != nil && staticCfg.MaxNodeCountPerWorkflow > 0 {
		opts = append(opts, WithMaxNodeCount(staticCfg.MaxNodeCountPerWorkflow))
	}

	wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
	if err != nil {
		return nil, err
	}

	return &subworkflow.SubWorkflow{
		Runner: wf.Runner,
	}, nil
}

View File

@ -33,6 +33,8 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct {
maxRetry int64
errProcessType vo.ErrorProcessType
dataOnErr func(ctx context.Context) map[string]any
callbackEnabled bool
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
streamPreProcessors []func(ctx context.Context,
@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct {
init []func(context.Context) (context.Context, error)
i compose.Invoke[map[string]any, map[string]any, O]
s compose.Stream[map[string]any, map[string]any, O]
c compose.Collect[map[string]any, map[string]any, O]
t compose.Transform[map[string]any, map[string]any, O]
}
func newNodeRunConfig[O any](ns *NodeSchema,
func newNodeRunConfig[O any](ns *schema2.NodeSchema,
i compose.Invoke[map[string]any, map[string]any, O],
s compose.Stream[map[string]any, map[string]any, O],
c compose.Collect[map[string]any, map[string]any, O],
t compose.Transform[map[string]any, map[string]any, O],
opts *newNodeOptions) *nodeRunConfig[O] {
meta := entity.NodeMetaByNodeType(ns.Type)
@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema,
keyFinishedMarkerTrimmer(),
}
if meta.PreFillZero {
preProcessors = append(preProcessors, ns.inputValueFiller())
preProcessors = append(preProcessors, inputValueFiller(ns))
}
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
if meta.PostFillNil {
postProcessors = append(postProcessors, ns.outputValueFiller())
postProcessors = append(postProcessors, outputValueFiller(ns))
}
streamPreProcessors := []func(ctx context.Context,
@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema,
},
}
if meta.PreFillZero {
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns))
}
if meta.UseCtxCache {
opts.init = append([]func(ctx context.Context) (context.Context, error){
func(ctx context.Context) (context.Context, error) {
return ctxcache.Init(ctx), nil
},
}, opts.init...)
}
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema,
maxRetry: maxRetry,
errProcessType: errProcessType,
dataOnErr: dataOnErr,
callbackEnabled: meta.CallbackEnabled,
preProcessors: preProcessors,
postProcessors: postProcessors,
streamPreProcessors: streamPreProcessors,
@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema,
init: opts.init,
i: i,
s: s,
c: c,
t: t,
}
}
func newNodeRunConfigWOOpt(ns *NodeSchema,
func newNodeRunConfigWOOpt(ns *schema2.NodeSchema,
i compose.InvokeWOOpt[map[string]any, map[string]any],
s compose.StreamWOOpt[map[string]any, map[string]any],
c compose.CollectWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any],
opts *newNodeOptions) *nodeRunConfig[any] {
var (
iWO compose.Invoke[map[string]any, map[string]any, any]
sWO compose.Stream[map[string]any, map[string]any, any]
cWO compose.Collect[map[string]any, map[string]any, any]
tWO compose.Transform[map[string]any, map[string]any, any]
)
@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema,
}
}
if c != nil {
cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) {
return c(ctx, in)
}
}
if t != nil {
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
return t(ctx, input)
}
}
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts)
return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts)
}
type newNodeOptions struct {
@ -180,57 +199,100 @@ type newNodeOptions struct {
init []func(context.Context) (context.Context, error)
}
type newNodeOption func(*newNodeOptions)
func toNode(ns *schema2.NodeSchema, r any) *Node {
iWOpt, _ := r.(nodes.InvokableNodeWOpt)
sWOpt, _ := r.(nodes.StreamableNodeWOpt)
cWOpt, _ := r.(nodes.CollectableNodeWOpt)
tWOpt, _ := r.(nodes.TransformableNodeWOpt)
iWOOpt, _ := r.(nodes.InvokableNode)
sWOOpt, _ := r.(nodes.StreamableNode)
cWOOpt, _ := r.(nodes.CollectableNode)
tWOOpt, _ := r.(nodes.TransformableNode)
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.callbackInputConverter = f
var wOpt, wOOpt bool
if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil {
wOpt = true
}
}
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.callbackOutputConverter = f
if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil {
wOOpt = true
}
if wOpt && wOOpt {
panic("a node's different streaming methods needs to be consistent: " +
"they should ALL have NodeOption or None should have them")
}
if !wOpt && !wOOpt {
panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode")
}
}
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
return func(opts *newNodeOptions) {
opts.init = append(opts.init, f)
}
}
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
ci, ok := r.(nodes.CallbackInputConverted)
if ok {
options.callbackInputConverter = ci.ToCallbackInput
}
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode()
}
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
co, ok := r.(nodes.CallbackOutputConverted)
if ok {
options.callbackOutputConverter = co.ToCallbackOutput
}
return newNodeRunConfig(ns, i, nil, nil, options).toNode()
}
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
init, ok := r.(nodes.Initializer)
if ok {
options.init = append(options.init, init.Init)
}
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
}
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
options := &newNodeOptions{}
for _, opt := range opts {
opt(options)
if wOpt {
var (
i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption]
s compose.Stream[map[string]any, map[string]any, nodes.NodeOption]
c compose.Collect[map[string]any, map[string]any, nodes.NodeOption]
t compose.Transform[map[string]any, map[string]any, nodes.NodeOption]
)
if iWOpt != nil {
i = iWOpt.Invoke
}
if sWOpt != nil {
s = sWOpt.Stream
}
if cWOpt != nil {
c = cWOpt.Collect
}
if tWOpt != nil {
t = tWOpt.Transform
}
return newNodeRunConfig(ns, i, s, c, t, options).toNode()
}
return newNodeRunConfig(ns, i, s, nil, options).toNode()
var (
i compose.InvokeWOOpt[map[string]any, map[string]any]
s compose.StreamWOOpt[map[string]any, map[string]any]
c compose.CollectWOOpt[map[string]any, map[string]any]
t compose.TransformWOOpts[map[string]any, map[string]any]
)
if iWOOpt != nil {
i = iWOOpt.Invoke
}
if sWOOpt != nil {
s = sWOOpt.Stream
}
if cWOOpt != nil {
c = cWOOpt.Collect
}
if tWOOpt != nil {
t = tWOOpt.Transform
}
return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode()
}
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.
func (nc *nodeRunConfig[O]) toNode() *Node {
var opts []compose.LambdaOpt
opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
opts = append(opts, compose.WithLambdaCallbackEnable(true))
if nc.callbackEnabled {
opts = append(opts, compose.WithLambdaCallbackEnable(true))
}
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
if err != nil {
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C
}
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
if !r.callbackEnabled {
return ctx, nil
}
if r.callbackInputConverter != nil {
convertedInput, err := r.callbackInputConverter(ctx, input)
if err != nil {
@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
context.Context, *schema.StreamReader[map[string]any], error) {
if !r.callbackEnabled {
return ctx, input, nil
}
if r.callbackInputConverter != nil {
copied := input.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
}
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
output["isSuccess"] = true
}
if !r.callbackEnabled {
return nil
}
if r.callbackOutputConverter != nil {
convertedOutput, err := r.callbackOutputConverter(ctx, output)
if err != nil {
@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
*schema.StreamReader[map[string]any], error) {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
}
if !r.callbackEnabled {
return output, nil
}
if r.callbackOutputConverter != nil {
copied := output.Copy(2)
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
if r.interrupted {
if r.callbackEnabled {
_ = callbacks.OnError(ctx, err)
}
_ = callbacks.OnError(ctx, err)
return nil, false
}
@ -653,22 +696,20 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
msg := sErr.Msg()
switch r.errProcessType {
case vo.ErrorProcessTypeDefault:
case vo.ErrorProcessTypeReturnDefaultData:
d := r.dataOnErr(ctx)
d["errorBody"] = map[string]any{
"errorMessage": msg,
"errorCode": code,
}
d["isSuccess"] = false
if r.callbackEnabled {
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: d,
RawOutput: d,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: d,
RawOutput: d,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
return d, true
case vo.ErrorProcessTypeExceptionBranch:
s := make(map[string]any)
@ -677,20 +718,16 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
"errorCode": code,
}
s["isSuccess"] = false
if r.callbackEnabled {
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: s,
RawOutput: s,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
sOutput := &nodes.StructuredCallbackOutput{
Output: s,
RawOutput: s,
Error: sErr,
}
_ = callbacks.OnEnd(ctx, sOutput)
return s, true
default:
if r.callbackEnabled {
_ = callbacks.OnError(ctx, sErr)
}
_ = callbacks.OnError(ctx, sErr)
return nil, false
}
}

View File

@ -1,580 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
)
// NodeSchema is the serializable definition of a single node within a
// workflow graph: its identity, type, internal configuration and the
// field mappings connecting it to other nodes.
type NodeSchema struct {
	Key  vo.NodeKey      `json:"key"`  // unique key of this node within the workflow
	Name string          `json:"name"` // display name of the node
	Type entity.NodeType `json:"type"` // node type; decides which implementation New instantiates

	// Configs are node specific configurations with pre-defined config key and config value.
	// Will not participate in request-time field mapping, nor as node's static values.
	// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
	Configs any `json:"configs,omitempty"`

	InputTypes   map[string]*vo.TypeInfo `json:"input_types,omitempty"`   // declared types of the node's input fields
	InputSources []*vo.FieldInfo         `json:"input_sources,omitempty"` // where each input field is mapped from

	OutputTypes   map[string]*vo.TypeInfo `json:"output_types,omitempty"`   // declared types of the node's output fields
	OutputSources []*vo.FieldInfo         `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop

	ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
	StreamConfigs    *StreamConfig    `json:"stream_configs,omitempty"`    // streaming capabilities of this node

	SubWorkflowBasic  *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`  // basic info of the referenced sub workflow, if any
	SubWorkflowSchema *WorkflowSchema       `json:"sub_workflow_schema,omitempty"` // full schema of the referenced sub workflow, if any

	Lambda *compose.Lambda // not serializable, used for internal test.
}
// ExceptionConfig holds the generic error-handling options shared by most
// node types: timeout, retry and what to do when the node ultimately fails.
type ExceptionConfig struct {
	TimeoutMS   int64                `json:"timeout_ms,omitempty"`   // timeout in milliseconds, 0 means no timeout
	MaxRetry    int64                `json:"max_retry,omitempty"`    // max retry times, 0 means no retry
	ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
	DataOnErr   string               `json:"data_on_err,omitempty"`  // data to return when error, effective when ProcessType==Default occurs
}
// StreamConfig describes a node's streaming capabilities.
type StreamConfig struct {
	// whether this node has the ability to produce genuine streaming output.
	// not include nodes that only passes stream down as they receives them
	CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
	// whether this node prioritize streaming input over none-streaming input.
	// not include nodes that can accept both and does not have preference.
	// NOTE(review): the JSON tag "can_process_stream" does not match the field
	// name; kept as-is for serialization compatibility — confirm before renaming.
	RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
// Node is the runtime artifact produced from a NodeSchema by New: an
// executable Eino lambda ready to be added to a workflow graph.
type Node struct {
	Lambda *compose.Lambda
}
// New instantiates the runtime Node described by this NodeSchema.
//
// inner is the already-composed inner workflow for composite nodes (Batch,
// Loop) and must be non-nil for them; sc is the workflow schema this node
// belongs to; deps carries the node's resolved dependency information.
//
// Any panic raised during construction is recovered and converted into an
// error, and every error is wrapped with ErrCreateNodeFail plus the node
// name for easier diagnosis.
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
	sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
	defer func() {
		// convert a panic during node construction into a normal error
		if panicErr := recover(); panicErr != nil {
			err = safego.NewPanicErr(panicErr, debug.Stack())
		}

		// uniformly wrap any construction error with the node's name
		if err != nil {
			err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
		}
	}()

	// nodes that are aware of their input sources need the REAL, fully
	// resolved sources computed before construction
	if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
		if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
			return nil, err
		}
	}

	switch s.Type {
	case entity.NodeTypeLambda:
		if s.Lambda == nil {
			return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
		}
		return &Node{Lambda: s.Lambda}, nil
	case entity.NodeTypeLLM:
		conf, err := s.ToLLMConfig(ctx)
		if err != nil {
			return nil, err
		}
		l, err := llm.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
	case entity.NodeTypeSelector:
		conf := s.ToSelectorConfig()
		sl, err := selector.NewSelector(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
	case entity.NodeTypeBatch:
		// composite node: executes the inner workflow once per batch element
		if inner == nil {
			return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
		}
		conf, err := s.ToBatchConfig(inner)
		if err != nil {
			return nil, err
		}
		b, err := batch.NewBatch(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
	case entity.NodeTypeVariableAggregator:
		conf, err := s.ToVariableAggregatorConfig()
		if err != nil {
			return nil, err
		}
		va, err := variableaggregator.NewVariableAggregator(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, va.Invoke, va.Transform,
			withCallbackInputConverter(va.ToCallbackInput),
			withCallbackOutputConverter(va.ToCallbackOutput),
			withInit(va.Init)), nil
	case entity.NodeTypeTextProcessor:
		conf, err := s.ToTextProcessorConfig()
		if err != nil {
			return nil, err
		}
		tp, err := textprocessor.NewTextProcessor(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, tp.Invoke), nil
	case entity.NodeTypeHTTPRequester:
		conf, err := s.ToHTTPRequesterConfig()
		if err != nil {
			return nil, err
		}
		hr, err := httprequester.NewHTTPRequester(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
	case entity.NodeTypeContinue:
		// Continue is a pure marker node: it emits an empty output
		i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
			return map[string]any{}, nil
		}
		return invokableNode(s, i), nil
	case entity.NodeTypeBreak:
		b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
		if err != nil {
			return nil, err
		}
		return invokableNode(s, b.DoBreak), nil
	case entity.NodeTypeVariableAssigner:
		handler := variable.GetVariableHandler()
		conf, err := s.ToVariableAssignerConfig(handler)
		if err != nil {
			return nil, err
		}
		va, err := variableassigner.NewVariableAssigner(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, va.Assign), nil
	case entity.NodeTypeVariableAssignerWithinLoop:
		conf, err := s.ToVariableAssignerInLoopConfig()
		if err != nil {
			return nil, err
		}
		va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, va.Assign), nil
	case entity.NodeTypeLoop:
		// composite node: repeatedly executes the inner workflow
		conf, err := s.ToLoopConfig(inner)
		if err != nil {
			return nil, err
		}
		l, err := loop.NewLoop(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
	case entity.NodeTypeQuestionAnswer:
		conf, err := s.ToQAConfig(ctx)
		if err != nil {
			return nil, err
		}
		qA, err := qa.NewQuestionAnswer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
	case entity.NodeTypeInputReceiver:
		conf, err := s.ToInputReceiverConfig()
		if err != nil {
			return nil, err
		}
		inputR, err := receiver.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
	case entity.NodeTypeOutputEmitter:
		conf, err := s.ToOutputEmitterConfig(sc)
		if err != nil {
			return nil, err
		}
		e, err := emitter.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
	case entity.NodeTypeEntry:
		conf, err := s.ToEntryConfig(ctx)
		if err != nil {
			return nil, err
		}
		e, err := entry.NewEntry(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, e.Invoke), nil
	case entity.NodeTypeExit:
		// when the workflow returns variables, Exit is a pass-through;
		// otherwise it behaves like an output emitter
		terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
		if terminalPlan == vo.ReturnVariables {
			i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
				if in == nil {
					return map[string]any{}, nil
				}
				return in, nil
			}
			return invokableNode(s, i), nil
		}
		conf, err := s.ToOutputEmitterConfig(sc)
		if err != nil {
			return nil, err
		}
		e, err := emitter.New(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
	case entity.NodeTypeDatabaseCustomSQL:
		conf, err := s.ToDatabaseCustomSQLConfig()
		if err != nil {
			return nil, err
		}
		sqlER, err := database.NewCustomSQL(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, sqlER.Execute), nil
	case entity.NodeTypeDatabaseQuery:
		conf, err := s.ToDatabaseQueryConfig()
		if err != nil {
			return nil, err
		}
		query, err := database.NewQuery(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseInsert:
		conf, err := s.ToDatabaseInsertConfig()
		if err != nil {
			return nil, err
		}
		insert, err := database.NewInsert(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseUpdate:
		conf, err := s.ToDatabaseUpdateConfig()
		if err != nil {
			return nil, err
		}
		update, err := database.NewUpdate(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
	case entity.NodeTypeDatabaseDelete:
		conf, err := s.ToDatabaseDeleteConfig()
		if err != nil {
			return nil, err
		}
		del, err := database.NewDelete(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
	case entity.NodeTypeKnowledgeIndexer:
		conf, err := s.ToKnowledgeIndexerConfig()
		if err != nil {
			return nil, err
		}
		w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, w.Store), nil
	case entity.NodeTypeKnowledgeRetriever:
		conf, err := s.ToKnowledgeRetrieveConfig()
		if err != nil {
			return nil, err
		}
		r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Retrieve), nil
	case entity.NodeTypeKnowledgeDeleter:
		conf, err := s.ToKnowledgeDeleterConfig()
		if err != nil {
			return nil, err
		}
		r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Delete), nil
	case entity.NodeTypeCodeRunner:
		conf, err := s.ToCodeRunnerConfig()
		if err != nil {
			return nil, err
		}
		r, err := code.NewCodeRunner(ctx, conf)
		if err != nil {
			return nil, err
		}
		// code runner stores per-execution data in a ctx-scoped cache
		initFn := func(ctx context.Context) (context.Context, error) {
			return ctxcache.Init(ctx), nil
		}
		return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
	case entity.NodeTypePlugin:
		conf, err := s.ToPluginConfig()
		if err != nil {
			return nil, err
		}
		r, err := plugin.NewPlugin(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Invoke), nil
	case entity.NodeTypeCreateConversation:
		conf, err := s.ToCreateConversationConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewCreateConversation(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Create), nil
	case entity.NodeTypeMessageList:
		conf, err := s.ToMessageListConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewMessageList(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.List), nil
	case entity.NodeTypeClearMessage:
		conf, err := s.ToClearMessageConfig()
		if err != nil {
			return nil, err
		}
		r, err := conversation.NewClearMessage(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Clear), nil
	case entity.NodeTypeIntentDetector:
		conf, err := s.ToIntentDetectorConfig(ctx)
		if err != nil {
			return nil, err
		}
		r, err := intentdetector.NewIntentDetector(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, r.Invoke), nil
	case entity.NodeTypeSubWorkflow:
		conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
		if err != nil {
			return nil, err
		}
		r, err := subworkflow.NewSubWorkflow(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
	case entity.NodeTypeJsonSerialization:
		conf, err := s.ToJsonSerializationConfig()
		if err != nil {
			return nil, err
		}
		js, err := json.NewJsonSerializer(ctx, conf)
		if err != nil {
			return nil, err
		}
		return invokableNode(s, js.Invoke), nil
	case entity.NodeTypeJsonDeserialization:
		conf, err := s.ToJsonDeserializationConfig()
		if err != nil {
			return nil, err
		}
		jd, err := json.NewJsonDeserializer(ctx, conf)
		if err != nil {
			return nil, err
		}
		initFn := func(ctx context.Context) (context.Context, error) {
			return ctxcache.Init(ctx), nil
		}
		return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
	default:
		// previously panicked with "not implemented"; return a descriptive
		// error instead so callers get a clean failure (still wrapped with
		// ErrCreateNodeFail by the deferred handler) rather than a recovered
		// panic with a stack trace.
		return nil, fmt.Errorf("node type not implemented: %v", s.Type)
	}
}
// IsEnableUserQuery reports whether this node is an Entry node that exposes
// a user-query output field (BOT_USER_INPUT or USER_INPUT).
func (s *NodeSchema) IsEnableUserQuery() bool {
	if s == nil || s.Type != entity.NodeTypeEntry {
		return false
	}

	for _, src := range s.OutputSources {
		p := src.Path
		if len(p) != 1 {
			continue
		}
		switch p[0] {
		case "BOT_USER_INPUT", "USER_INPUT":
			return true
		}
	}

	return false
}
// IsEnableChatHistory reports whether this node carries LLM parameters with
// chat history enabled. Only LLM and IntentDetector nodes can do so; every
// other node type returns false.
func (s *NodeSchema) IsEnableChatHistory() bool {
	if s == nil {
		return false
	}

	switch s.Type {
	// both node types store their LLM parameters under the same config key,
	// so the previously duplicated branches are merged into one case
	case entity.NodeTypeLLM, entity.NodeTypeIntentDetector:
		llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
		return llmParam.EnableChatHistory
	default:
		return false
	}
}
// IsRefGlobalVariable reports whether any input or output source of this
// node references a global variable.
func (s *NodeSchema) IsRefGlobalVariable() bool {
	hasGlobalRef := func(infos []*vo.FieldInfo) bool {
		for _, fi := range infos {
			if fi.IsRefGlobalVariable() {
				return true
			}
		}
		return false
	}

	return hasGlobalRef(s.InputSources) || hasGlobalRef(s.OutputSources)
}
// requireCheckpoint reports whether executing this node may interrupt and
// therefore needs checkpointing support.
func (s *NodeSchema) requireCheckpoint() bool {
	switch s.Type {
	case entity.NodeTypeQuestionAnswer, entity.NodeTypeInputReceiver:
		// these nodes always pause to wait for external input
		return true
	case entity.NodeTypeLLM:
		// an LLM node may interrupt only when it can call workflows as tools
		fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
		return fcParams != nil && fcParams.WorkflowFCParam != nil
	case entity.NodeTypeSubWorkflow:
		// delegate to the sub workflow's own checkpoint requirement
		s.SubWorkflowSchema.Init()
		return s.SubWorkflowSchema.requireCheckPoint
	default:
		return false
	}
}

View File

@ -32,8 +32,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -46,9 +48,9 @@ type State struct {
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
@ -71,8 +73,8 @@ func init() {
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
return s.GroupChoices[nodeKey]
}
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
}
choice, ok := choices[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
}
if choice == -1 { // this group picks none of the elements
return nodes.FieldNotStream, nil
return schema2.FieldNotStream, nil
}
sInfos, ok := s.SourceInfos[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
}
groupInfo, ok := sInfos[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
}
if groupInfo.SubSources == nil {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
}
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
}
if subInfo.FieldType != nodes.FieldMaybeStream {
if subInfo.FieldType != schema2.FieldMaybeStream {
return subInfo.FieldType, nil
}
@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
}
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
result := make(map[string]nodes.FieldStreamType)
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
result := make(map[string]schema2.FieldStreamType)
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return result, nil
@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
ExecutedNodes: make(map[vo.NodeKey]bool),
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
GroupChoices: make(map[vo.NodeKey]map[string]int),
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
LLMToResumeData: make(map[vo.NodeKey]string),
@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
}
}
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePreHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
return in, nil
})
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
state.Inputs[s.Key] = in
@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
if len(handlers) > 0 || !stream {
handlerForVars := s.statePreHandlerForVars()
handlerForVars := statePreHandlerForVars(s)
if handlerForVars != nil {
handlers = append(handlers, handlerForVars)
}
@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
if s.Type == entity.NodeTypeVariableAggregator {
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
state.SourceInfos[s.Key] = s.FullSources
return in, nil
})
}
handlerForVars := s.streamStatePreHandlerForVars()
handlerForVars := streamStatePreHandlerForVars(s)
if handlerForVars != nil {
streamHandlers = append(streamHandlers, handlerForVars)
}
@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
return nil
}
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
}
}
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
}
}
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// if it does not have source info, do not add this pre handler
if s.Configs == nil {
return nil
@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
return nil
case entity.NodeTypeExit:
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
if terminatePlan != vo.ReturnVariables {
return nil
}
@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
}
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
sourceInfo := s.FullSources
if len(sourceInfo) == 0 {
return nil
}
@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
var (
anyStream bool
checker func(source *nodes.SourceInfo) bool
checker func(source *schema2.SourceInfo) bool
)
checker = func(source *nodes.SourceInfo) bool {
if source.FieldType != nodes.FieldNotStream {
checker = func(source *schema2.SourceInfo) bool {
if source.FieldType != schema2.FieldNotStream {
return true
}
for _, subSource := range source.SubSources {
@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
resolved := map[string]resolvedStreamSource{}
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
if source.IsIntermediate {
result = &resolvedStreamSource{
intermediate: true,
@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
streamType := source.FieldType
if streamType == nodes.FieldMaybeStream {
if streamType == schema2.FieldMaybeStream {
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
if err != nil {
return nil, err
}
}
if streamType == nodes.FieldNotStream {
if streamType == schema2.FieldNotStream {
return nil, nil
}
@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
}
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePostHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.streamStatePostHandlerForVars()
forVars := streamStatePostHandlerForVars(s)
if forVars != nil {
streamHandlers = append(streamHandlers, forVars)
}
@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.statePostHandlerForVars()
forVars := statePostHandlerForVars(s)
if forVars != nil {
handlers = append(handlers, forVars)
}
@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return compose.WithStatePostHandler(handler)
}
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variable,
// use the state's variableHandler to get the variables and set them to the output
var vars []*vo.FieldInfo
@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
}
}
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
var vars []*vo.FieldInfo
for _, output := range s.OutputSources {

View File

@ -21,19 +21,20 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// SetFullSources calculates REAL input sources for a node.
// GetFullSources calculates REAL input sources for a node.
// It may be different from a NodeSchema's InputSources because of the following reasons:
// 1. a inner node under composite node may refer to a field from a node in its parent workflow,
// this is instead routed to and sourced from the inner workflow's start node.
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
fullSource := make(map[string]*nodes.SourceInfo)
func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) (
map[string]*schema.SourceInfo, error) {
fullSource := make(map[string]*schema.SourceInfo)
var fieldInfos []vo.FieldInfo
for _, s := range dep.staticValues {
fieldInfos = append(fieldInfos, vo.FieldInfo{
@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
tInfo = tInfo.Properties[path[j]]
}
if current, ok := currentSource[path[j]]; !ok {
currentSource[path[j]] = &nodes.SourceInfo{
currentSource[path[j]] = &schema.SourceInfo{
IsIntermediate: true,
FieldType: nodes.FieldNotStream,
FieldType: schema.FieldNotStream,
TypeInfo: tInfo,
SubSources: make(map[string]*nodes.SourceInfo),
SubSources: make(map[string]*schema.SourceInfo),
}
} else if !current.IsIntermediate {
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
}
currentSource = currentSource[path[j]].SubSources
@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
// static values or variables
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
currentSource[lastPath] = &nodes.SourceInfo{
currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false,
FieldType: nodes.FieldNotStream,
FieldType: schema.FieldNotStream,
TypeInfo: tInfo,
}
continue
@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
fromNodeKey := fInfo.Source.Ref.FromNodeKey
var (
streamType nodes.FieldStreamType
streamType schema.FieldStreamType
err error
)
if len(fromNodeKey) > 0 {
if fromNodeKey == compose.START {
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
} else {
fromNode, ok := allNS[fromNodeKey]
if !ok {
return fmt.Errorf("node %s not found", fromNodeKey)
fromNode := sc.GetNode(fromNodeKey)
if fromNode == nil {
return nil, fmt.Errorf("node %s not found", fromNodeKey)
}
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil {
return err
return nil, err
}
}
}
currentSource[lastPath] = &nodes.SourceInfo{
currentSource[lastPath] = &schema.SourceInfo{
IsIntermediate: false,
FieldType: streamType,
FromNodeKey: fromNodeKey,
@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
}
}
s.Configs.(map[string]any)["FullSources"] = fullSource
return nil
}
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
if s.Type == entity.NodeTypeExit {
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
if len(path) == 1 && path[0] == "output" {
return nodes.FieldIsStream, nil
}
}
return nodes.FieldNotStream, nil
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
subSC := s.SubWorkflowSchema
subExit := subSC.GetNode(entity.ExitNodeKey)
subStreamType, err := subExit.IsStreamingField(path, nil)
if err != nil {
return nodes.FieldNotStream, err
}
return subStreamType, nil
} else if s.Type == entity.NodeTypeVariableAggregator {
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range s.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return nodes.FieldNotStream, nil
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode, ok := allNS[fromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range s.InputSources {
if fInfo.Path[0] == path[0] { // belong to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
if err != nil {
return nodes.FieldNotStream, err
}
if subStreamType == nodes.FieldMaybeStream {
return nodes.FieldMaybeStream, nil
} else if subStreamType == nodes.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return nodes.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return nodes.FieldNotStream, nil
}
return nodes.FieldMaybeStream, nil
}
}
if s.Type != entity.NodeTypeLLM {
return nodes.FieldNotStream, nil
}
if len(path) != 1 {
return nodes.FieldNotStream, nil
}
outputs := s.OutputTypes
if len(outputs) != 1 && len(outputs) != 2 {
return nodes.FieldNotStream, nil
}
var outputKey string
for key, output := range outputs {
if output.Type != vo.DataTypeString {
return nodes.FieldNotStream, nil
}
if key != "reasoning_content" {
if len(outputKey) > 0 {
return nodes.FieldNotStream, nil
}
outputKey = key
}
}
field := path[0]
if field == "reasoning_content" || field == outputKey {
return nodes.FieldIsStream, nil
}
return nodes.FieldNotStream, nil
return fullSource, nil
}

View File

@ -28,6 +28,9 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestBatch(t *testing.T) {
@ -52,7 +55,7 @@ func TestBatch(t *testing.T) {
return in, nil
}
lambdaNode1 := &compose2.NodeSchema{
lambdaNode1 := &schema.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1),
@ -86,7 +89,7 @@ func TestBatch(t *testing.T) {
},
},
}
lambdaNode2 := &compose2.NodeSchema{
lambdaNode2 := &schema.NodeSchema{
Key: "index",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2),
@ -103,7 +106,7 @@ func TestBatch(t *testing.T) {
},
}
lambdaNode3 := &compose2.NodeSchema{
lambdaNode3 := &schema.NodeSchema{
Key: "consumer",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3),
@ -135,23 +138,22 @@ func TestBatch(t *testing.T) {
},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
Key: "batch_node_key",
Type: entity.NodeTypeBatch,
ns := &schema.NodeSchema{
Key: "batch_node_key",
Type: entity.NodeTypeBatch,
Configs: &batch.Config{},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"array_1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_1"},
},
},
@ -160,7 +162,7 @@ func TestBatch(t *testing.T) {
Path: compose.FieldPath{"array_2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"array_2"},
},
},
@ -214,11 +216,11 @@ func TestBatch(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -246,18 +248,18 @@ func TestBatch(t *testing.T) {
return map[string]any{"success": true}, nil
}
parentLambdaNode := &compose2.NodeSchema{
parentLambdaNode := &schema.NodeSchema{
Key: "parent_predecessor_1",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(parentLambda),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
parentLambdaNode,
ns,
exit,
exitN,
lambdaNode1,
lambdaNode2,
lambdaNode3,
@ -267,7 +269,7 @@ func TestBatch(t *testing.T) {
"index": "batch_node_key",
"consumer": "batch_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entity.EntryNodeKey,
ToNode: "parent_predecessor_1",

View File

@ -40,7 +40,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/internal/testutil"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
@ -108,22 +112,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "{{sys_prompt}}",
"UserPrompt": "{{query}}",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "{{sys_prompt}}",
UserPrompt: "{{query}}",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@ -132,7 +134,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"sys_prompt"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"sys_prompt"},
},
},
@ -141,7 +143,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"query"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@ -162,11 +164,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -181,20 +183,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -228,27 +230,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?",
"OutputFormat": llm.FormatJSON,
"IgnoreException": true,
"DefaultOutput": map[string]any{
"country_name": "unknown",
"area_size": int64(0),
},
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "what's the largest country in the world and it's area size in square kilometers?",
OutputFormat: llm.FormatJSON,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@ -264,11 +259,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -292,20 +287,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -337,22 +332,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
llmNode := &compose2.NodeSchema{
llmNode := &schema2.NodeSchema{
Key: "llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "list the top 5 largest countries in the world",
"OutputFormat": llm.FormatMarkdown,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "list the top 5 largest countries in the world",
OutputFormat: llm.FormatMarkdown,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@ -363,11 +356,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -382,20 +375,20 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
llmNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: llmNode.Key,
},
{
FromNode: llmNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -456,22 +449,20 @@ func TestLLM(t *testing.T) {
}
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
openaiNode := &compose2.NodeSchema{
openaiNode := &schema2.NodeSchema{
Key: "openai_llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "plan a 10 day family visit to China.",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "plan a 10 day family visit to China.",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@ -482,14 +473,14 @@ func TestLLM(t *testing.T) {
},
}
deepseekNode := &compose2.NodeSchema{
deepseekNode := &schema2.NodeSchema{
Key: "deepseek_llm_node_key",
Type: entity.NodeTypeLLM,
Configs: map[string]any{
"SystemPrompt": "you are a helpful assistant",
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
"OutputFormat": llm.FormatText,
"LLMParams": &model.LLMParams{
Configs: &llm.Config{
SystemPrompt: "you are a helpful assistant",
UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
OutputFormat: llm.FormatText,
LLMParams: &model.LLMParams{
ModelName: modelName,
},
},
@ -503,12 +494,11 @@ func TestLLM(t *testing.T) {
},
}
emitterNode := &compose2.NodeSchema{
emitterNode := &schema2.NodeSchema{
Key: "emitter_node_key",
Type: entity.NodeTypeOutputEmitter,
Configs: map[string]any{
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
"Mode": nodes.Streaming,
Configs: &emitter.Config{
Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
},
InputSources: []*vo.FieldInfo{
{
@ -542,7 +532,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"inputObj"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"inputObj"},
},
},
@ -551,7 +541,7 @@ func TestLLM(t *testing.T) {
Path: compose.FieldPath{"input2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"input2"},
},
},
@ -559,11 +549,11 @@ func TestLLM(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.UseAnswerContent,
Configs: &exit.Config{
TerminatePlan: vo.UseAnswerContent,
},
InputSources: []*vo.FieldInfo{
{
@ -596,17 +586,17 @@ func TestLLM(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
openaiNode,
deepseekNode,
emitterNode,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: openaiNode.Key,
},
{
@ -614,7 +604,7 @@ func TestLLM(t *testing.T) {
ToNode: emitterNode.Key,
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: deepseekNode.Key,
},
{
@ -623,7 +613,7 @@ func TestLLM(t *testing.T) {
},
{
FromNode: emitterNode.Key,
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}

View File

@ -26,15 +26,20 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
_continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestLoop(t *testing.T) {
t.Run("by iteration", func(t *testing.T) {
// start-> loop_node_key[innerNode->continue] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@ -54,31 +59,30 @@ func TestLoop(t *testing.T) {
},
}
continueNode := &compose2.NodeSchema{
Key: "continueNode",
Type: entity.NodeTypeContinue,
continueNode := &schema.NodeSchema{
Key: "continueNode",
Type: entity.NodeTypeContinue,
Configs: &_continue.Config{},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByIteration,
Configs: &loop.Config{
LoopType: loop.ByIteration,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{loop.Count},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"count"},
},
},
@ -97,11 +101,11 @@ func TestLoop(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -116,11 +120,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
continueNode,
},
@ -128,7 +132,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"continueNode": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@ -142,12 +146,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -168,7 +172,7 @@ func TestLoop(t *testing.T) {
t.Run("infinite", func(t *testing.T) {
// start-> loop_node_key[innerNode->break] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@ -188,24 +192,23 @@ func TestLoop(t *testing.T) {
},
}
breakNode := &compose2.NodeSchema{
Key: "breakNode",
Type: entity.NodeTypeBreak,
breakNode := &schema.NodeSchema{
Key: "breakNode",
Type: entity.NodeTypeBreak,
Configs: &_break.Config{},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.Infinite,
Configs: &loop.Config{
LoopType: loop.Infinite,
},
OutputSources: []*vo.FieldInfo{
{
@ -220,11 +223,11 @@ func TestLoop(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -239,11 +242,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
breakNode,
},
@ -251,7 +254,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"breakNode": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@ -265,12 +268,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -290,14 +293,14 @@ func TestLoop(t *testing.T) {
t.Run("by array", func(t *testing.T) {
// start-> loop_node_key[innerNode->variable_assign] -> end
innerNode := &compose2.NodeSchema{
innerNode := &schema.NodeSchema{
Key: "innerNode",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
item1 := in["item1"].(string)
item2 := in["item2"].(string)
count := in["count"].(int)
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
return map[string]any{"total": count + len(item1) + len(item2)}, nil
}),
InputSources: []*vo.FieldInfo{
{
@ -330,16 +333,18 @@ func TestLoop(t *testing.T) {
},
}
assigner := &compose2.NodeSchema{
assigner := &schema.NodeSchema{
Key: "assigner",
Type: entity.NodeTypeVariableAssignerWithinLoop,
Configs: []*variableassigner.Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
Configs: &variableassigner.InLoopConfig{
Pairs: []*variableassigner.Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"count"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"total"},
},
Right: compose.FieldPath{"total"},
},
},
InputSources: []*vo.FieldInfo{
@ -355,19 +360,17 @@ func TestLoop(t *testing.T) {
},
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -382,12 +385,13 @@ func TestLoop(t *testing.T) {
},
}
loopNode := &compose2.NodeSchema{
loopNode := &schema.NodeSchema{
Key: "loop_node_key",
Type: entity.NodeTypeLoop,
Configs: map[string]any{
"LoopType": loop.ByArray,
"IntermediateVars": map[string]*vo.TypeInfo{
Configs: &loop.Config{
LoopType: loop.ByArray,
InputArrays: []string{"items1", "items2"},
IntermediateVars: map[string]*vo.TypeInfo{
"count": {
Type: vo.DataTypeInteger,
},
@ -408,7 +412,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items1"},
},
},
@ -417,7 +421,7 @@ func TestLoop(t *testing.T) {
Path: compose.FieldPath{"items2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"items2"},
},
},
@ -442,11 +446,11 @@ func TestLoop(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
loopNode,
exit,
exitN,
innerNode,
assigner,
},
@ -454,7 +458,7 @@ func TestLoop(t *testing.T) {
"innerNode": "loop_node_key",
"assigner": "loop_node_key",
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: "loop_node_key",
ToNode: "innerNode",
@ -468,12 +472,12 @@ func TestLoop(t *testing.T) {
ToNode: "loop_node_key",
},
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "loop_node_key",
},
{
FromNode: "loop_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}

View File

@ -43,8 +43,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) {
mockey.Mock(workflow.GetRepository).Return(repo).Build()
t.Run("answer directly, no structured output", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
}}
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerDirectly,
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerDirectly,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerByChoices,
"ChoiceType": qa.FixedChoices,
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
"LLMParams": &model.LLMParams{},
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerByChoices,
ChoiceType: qa.FixedChoices,
FixedChoices: []string{"{{choice1}}", "{{choice2}}"},
LLMParams: &model.LLMParams{},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice1"},
},
},
@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"choice2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choice2"},
},
},
@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) {
},
}
lambda := &compose2.NodeSchema{
lambda := &schema2.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) {
}),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
lambda,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"),
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_1"),
},
{
@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) {
},
{
FromNode: "lambda",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) {
})
t.Run("answer with dynamic choices", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerByChoices,
"ChoiceType": qa.DynamicChoices,
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerByChoices,
ChoiceType: qa.DynamicChoices,
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{qa.DynamicChoicesKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"choices"},
},
},
@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) {
},
}
lambda := &compose2.NodeSchema{
lambda := &schema2.NodeSchema{
Key: "lambda",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) {
}),
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
lambda,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("branch_0"),
},
{
FromNode: "lambda",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "qa_node_key",
@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) {
},
}
branches, err := schema2.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
wf, err := compose2.NewWorkflow(context.Background(), ws)
@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) {
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
}
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
ns := &compose2.NodeSchema{
ns := &schema2.NodeSchema{
Key: "qa_node_key",
Type: entity.NodeTypeQuestionAnswer,
Configs: map[string]any{
"QuestionTpl": "{{input}}",
"AnswerType": qa.AnswerDirectly,
"ExtractFromAnswer": true,
"AdditionalSystemPromptTpl": "{{prompt}}",
"MaxAnswerCount": 2,
"LLMParams": &model.LLMParams{},
Configs: &qa.Config{
QuestionTpl: "{{input}}",
AnswerType: qa.AnswerDirectly,
ExtractFromAnswer: true,
AdditionalSystemPromptTpl: "{{prompt}}",
MaxAnswerCount: 2,
LLMParams: &model.LLMParams{},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"input"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"query"},
},
},
@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) {
Path: compose.FieldPath{"prompt"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"prompt"},
},
},
@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) {
},
}
exit := &compose2.NodeSchema{
exitN := &schema2.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema2.WorkflowSchema{
Nodes: []*schema2.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema2.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "qa_node_key",
},
{
FromNode: "qa_node_key",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}

View File

@ -26,26 +26,28 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func TestAddSelector(t *testing.T) {
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
}}
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) {
}, nil
}
lambdaNode1 := &compose2.NodeSchema{
lambdaNode1 := &schema.NodeSchema{
Key: "lambda1",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda1),
@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) {
}, nil
}
LambdaNode2 := &compose2.NodeSchema{
LambdaNode2 := &schema.NodeSchema{
Key: "lambda2",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda2),
@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) {
}, nil
}
lambdaNode3 := &compose2.NodeSchema{
lambdaNode3 := &schema.NodeSchema{
Key: "lambda3",
Type: entity.NodeTypeLambda,
Lambda: compose.InvokableLambda(lambda3),
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "selector",
Type: entity.NodeTypeSelector,
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{
{
Single: ptr.Of(selector.OperatorEqual),
},
@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"0", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key1"},
},
},
@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key2"},
},
},
@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "0", selector.RightKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key3"},
},
},
@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) {
Path: compose.FieldPath{"1", "1", selector.LeftKey},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"key4"},
},
},
@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
ns,
lambdaNode1,
LambdaNode2,
lambdaNode3,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "selector",
},
{
@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) {
},
{
FromNode: "selector",
ToNode: exit.Key,
ToNode: exitN.Key,
FromPort: ptr.Of("default"),
},
{
FromNode: "lambda1",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "lambda2",
ToNode: exit.Key,
ToNode: exitN.Key,
},
{
FromNode: "lambda3",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
branches, err := schema.BuildBranches(ws.Connections)
assert.NoError(t, err)
ws.Branches = branches
ws.Init()
ctx := context.Background()
@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) {
}
func TestVariableAggregator(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "va",
Type: entity.NodeTypeVariableAggregator,
Configs: map[string]any{
"MergeStrategy": variableaggregator.FirstNotNullValue,
"GroupToLen": map[string]int{
Configs: &variableaggregator.Config{
MergeStrategy: variableaggregator.FirstNotNullValue,
GroupLen: map[string]int{
"Group1": 1,
"Group2": 1,
},
"GroupOrder": []string{
GroupOrder: []string{
"Group1",
"Group2",
},
@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group1", "0"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"},
},
},
@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) {
Path: compose.FieldPath{"Group2", "0"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Int1"},
},
},
@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
entry,
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
entryN,
ns,
exit,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "va",
},
{
FromNode: "va",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) {
func TestTextProcessor(t *testing.T) {
t.Run("split", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "tp",
Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{
"Type": textprocessor.SplitText,
"Separators": []string{"|"},
Configs: &textprocessor.Config{
Type: textprocessor.SplitText,
Separators: []string{"|"},
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"String"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str"},
},
},
@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
ns,
entry,
exit,
entryN,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "tp",
},
{
FromNode: "tp",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}
@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) {
})
t.Run("concat", func(t *testing.T) {
entry := &compose2.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: map[string]any{
"DefaultValues": map[string]any{},
},
entryN := &schema.NodeSchema{
Key: entity.EntryNodeKey,
Type: entity.NodeTypeEntry,
Configs: &entry.Config{},
}
exit := &compose2.NodeSchema{
exitN := &schema.NodeSchema{
Key: entity.ExitNodeKey,
Type: entity.NodeTypeExit,
Configs: map[string]any{
"TerminalPlan": vo.ReturnVariables,
Configs: &exit.Config{
TerminatePlan: vo.ReturnVariables,
},
InputSources: []*vo.FieldInfo{
{
@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ns := &compose2.NodeSchema{
ns := &schema.NodeSchema{
Key: "tp",
Type: entity.NodeTypeTextProcessor,
Configs: map[string]any{
"Type": textprocessor.ConcatText,
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
"ConcatChar": "\t",
Configs: &textprocessor.Config{
Type: textprocessor.ConcatText,
Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
ConcatChar: "\t",
},
InputSources: []*vo.FieldInfo{
{
Path: compose.FieldPath{"String1"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str1"},
},
},
@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String2"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str2"},
},
},
@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) {
Path: compose.FieldPath{"String3"},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: entry.Key,
FromNodeKey: entryN.Key,
FromPath: compose.FieldPath{"Str3"},
},
},
@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) {
},
}
ws := &compose2.WorkflowSchema{
Nodes: []*compose2.NodeSchema{
ws := &schema.WorkflowSchema{
Nodes: []*schema.NodeSchema{
ns,
entry,
exit,
entryN,
exitN,
},
Connections: []*compose2.Connection{
Connections: []*schema.Connection{
{
FromNode: entry.Key,
FromNode: entryN.Key,
ToNode: "tp",
},
{
FromNode: "tp",
ToNode: exit.Key,
ToNode: exitN.Key,
},
},
}

View File

@ -1,652 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strconv"
"time"
einomodel "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
// ToEntryConfig assembles the entry node's configuration from this schema,
// carrying over any declared default values and the node's output types.
// The context is unused.
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
	cfg := &entry.Config{
		OutputTypes:   s.OutputTypes,
		DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
	}
	return cfg, nil
}
// ToLLMConfig assembles the LLM node's configuration: prompts, the (possibly
// fallback-wrapped) chat model, optional function-call tools resolved from
// workflows and plugins, and optional knowledge-recall settings.
//
// The "LLMParams" config entry is mandatory. "BackupLLMParams" is only
// consulted when retries are enabled via s.ExceptionConfigs.MaxRetry.
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
	llmConf := &llm.Config{
		SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
		UserPrompt:   getKeyOrZero[string]("UserPrompt", s.Configs),
		OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
		InputFields:  s.InputTypes,
		OutputFields: s.OutputTypes,
		FullSources:  getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}

	llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
	if llmParams == nil {
		return nil, fmt.Errorf("llm node llmParams is required")
	}

	var (
		err                  error
		chatModel, fallbackM einomodel.BaseChatModel
		info, fallbackI      *modelmgr.Model
		modelWithInfo        llm.ModelWithInfo
	)
	chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
	if err != nil {
		return nil, err
	}

	// A backup model is only resolved when the node is configured to retry.
	metaConfigs := s.ExceptionConfigs
	if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
		backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
		if backupModelParams != nil {
			fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
			if err != nil {
				return nil, err
			}
		}
	}

	if fallbackM == nil {
		modelWithInfo = llm.NewModel(chatModel, info)
	} else {
		modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
	}

	llmConf.ChatModel = modelWithInfo

	fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
	if fcParams != nil {
		// Workflows exposed to the LLM as tools.
		if fcParams.WorkflowFCParam != nil {
			for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
				wfIDStr := wf.WorkflowID
				wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
				}

				workflowToolConfig := vo.WorkflowToolConfig{}
				if wf.FCSetting != nil {
					workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
					workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
				}

				// A pinned version selects that exact version; otherwise the draft is used.
				locator := vo.FromDraft
				if wf.WorkflowVersion != "" {
					locator = vo.FromSpecificVersion
				}

				wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
					ID:      wfID,
					QType:   locator,
					Version: wf.WorkflowVersion,
				}, workflowToolConfig)
				if err != nil {
					return nil, err
				}
				llmConf.Tools = append(llmConf.Tools, wfTool)

				// Tools whose workflow terminates by answering directly must
				// short-circuit the LLM, so mark them as return-directly.
				if wfTool.TerminatePlan() == vo.UseAnswerContent {
					if llmConf.ToolsReturnDirectly == nil {
						llmConf.ToolsReturnDirectly = make(map[string]bool)
					}

					toolInfo, err := wfTool.Info(ctx)
					if err != nil {
						return nil, err
					}

					llmConf.ToolsReturnDirectly[toolInfo.Name] = true
				}
			}
		}

		// Plugin tools exposed to the LLM, grouped by plugin ID so each plugin
		// is resolved with a single request.
		if fcParams.PluginFCParam != nil {
			pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
			for _, p := range fcParams.PluginFCParam.PluginList {
				pid, err := strconv.ParseInt(p.PluginID, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
				}
				toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
				if err != nil {
					// Fixed: this parse failure concerns the API (tool) ID,
					// not the plugin ID, so report the offending value.
					return nil, fmt.Errorf("invalid api id: %s", p.ApiId)
				}
				var (
					requestParameters  []*workflow3.APIParameter
					responseParameters []*workflow3.APIParameter
				)
				if p.FCSetting != nil {
					requestParameters = p.FCSetting.RequestParameters
					responseParameters = p.FCSetting.ResponseParameters
				}
				if req, ok := pluginToolsInvokableReq[pid]; ok {
					req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
						ToolID:                      toolID,
						RequestAPIParametersConfig:  requestParameters,
						ResponseAPIParametersConfig: responseParameters,
					}
				} else {
					pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
						PluginEntity: crossplugin.Entity{
							PluginID:      pid,
							PluginVersion: ptr.Of(p.PluginVersion),
						},
						ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
							toolID: {
								ToolID:                      toolID,
								RequestAPIParametersConfig:  requestParameters,
								ResponseAPIParametersConfig: responseParameters,
							},
						},
						IsDraft: p.IsDraft,
					}
					pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
				}
			}

			inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
			for _, req := range pluginToolsInvokableReq {
				toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
				if err != nil {
					return nil, err
				}
				for _, t := range toolMap {
					inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
				}
			}

			if len(inInvokableTools) > 0 {
				// NOTE(review): this assignment replaces any workflow tools
				// appended above instead of adding to them — confirm the two
				// tool sources are meant to be mutually exclusive.
				llmConf.Tools = inInvokableTools
			}
		}

		// Knowledge recall: let the builtin chat model rewrite/route queries
		// against the selected knowledge bases before the main LLM call.
		if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
			kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel()
			if kwChatModel == nil {
				return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
			}
			knowledgeOperator := crossknowledge.GetKnowledgeOperator()

			setting := fcParams.KnowledgeFCParam.GlobalSetting
			cfg := &llm.KnowledgeRecallConfig{
				ChatModel: kwChatModel,
				Retriever: knowledgeOperator,
			}

			searchType, err := totRetrievalSearchType(setting.SearchMode)
			if err != nil {
				return nil, err
			}

			cfg.RetrievalStrategy = &llm.RetrievalStrategy{
				RetrievalStrategy: &crossknowledge.RetrievalStrategy{
					TopK:               ptr.Of(setting.TopK),
					MinScore:           ptr.Of(setting.MinScore),
					SearchType:         searchType,
					EnableNL2SQL:       setting.UseNL2SQL,
					EnableQueryRewrite: setting.UseRewrite,
					EnableRerank:       setting.UseRerank,
				},
				NoReCallReplyMode:            llm.NoReCallReplyMode(setting.NoRecallReplyMode),
				NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
			}

			knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
			for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
				kid, err := strconv.ParseInt(kw.ID, 10, 64)
				if err != nil {
					return nil, err
				}
				knowledgeIDs = append(knowledgeIDs, kid)
			}

			detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
				KnowledgeIDs: knowledgeIDs,
			})
			if err != nil {
				return nil, err
			}
			cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
			llmConf.KnowledgeRecallConfig = cfg
		}
	}

	return llmConf, nil
}
// ToSelectorConfig builds the selector node config from the mandatory
// "Clauses" entry of the schema's Configs.
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
	clauses := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
	return &selector.Config{Clauses: clauses}
}
// SelectorInputConverter flattens the node's nested input map into the
// ordered operant list the selector expects: one entry per configured
// clause, addressed by the clause index (and sub-clause index for multi
// clauses). A missing left operant is an error; the right operant is
// optional.
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
	clauses := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
	for idx, clause := range clauses {
		switch {
		case clause.Single != nil:
			left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(idx), selector.LeftKey})
			if !ok {
				return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, idx)
			}
			operants := selector.Operants{Left: left}
			if right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(idx), selector.RightKey}); ok {
				operants.Right = right
			}
			out = append(out, operants)
		case clause.Multi != nil:
			subOperants := make([]*selector.Operants, 0, len(clause.Multi.Clauses))
			for sub := range clause.Multi.Clauses {
				left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(idx), strconv.Itoa(sub), selector.LeftKey})
				if !ok {
					return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, idx, sub)
				}
				one := &selector.Operants{Left: left}
				if right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(idx), strconv.Itoa(sub), selector.RightKey}); ok {
					one.Right = right
				}
				subOperants = append(subOperants, one)
			}
			out = append(out, selector.Operants{Multi: subOperants})
		default:
			return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", clause)
		}
	}
	return out, nil
}
// ToBatchConfig builds the batch node config around the pre-built inner
// workflow runnable. Every array-typed input is registered as a batch
// iteration array.
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
	cfg := &batch.Config{
		BatchNodeKey:  s.Key,
		InnerWorkflow: inner,
		Outputs:       s.OutputSources,
	}

	// Only array inputs participate in batching.
	for name, info := range s.InputTypes {
		if info.Type == vo.DataTypeArray {
			cfg.InputArrays = append(cfg.InputArrays, name)
		}
	}

	return cfg, nil
}
// ToVariableAggregatorConfig builds the variable aggregator node config.
// "MergeStrategy", "GroupToLen" and "GroupOrder" are mandatory; "FullSources"
// is optional. Uses the file's mustGetKey/getKeyOrZero helpers instead of
// raw type assertions for consistency with the other To*Config builders.
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
	return &variableaggregator.Config{
		MergeStrategy: mustGetKey[variableaggregator.MergeStrategy]("MergeStrategy", s.Configs),
		GroupLen:      mustGetKey[map[string]int]("GroupToLen", s.Configs),
		FullSources:   getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
		NodeKey:       s.Key,
		InputSources:  s.InputSources,
		GroupOrder:    mustGetKey[[]string]("GroupOrder", s.Configs),
	}, nil
}
// variableAggregatorInputConverter re-keys each group's members from string
// indices ("0", "1", ...) to int indices, as the aggregator requires.
// It panics on malformed input; the stream variant recovers such panics
// into stream errors.
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
	converted = make(map[string]map[int]any)
	for group, raw := range in {
		members, ok := raw.(map[string]any)
		if !ok {
			panic(errors.New("value is not a map[string]any"))
		}

		indexed := make(map[int]any, len(members))
		for key, member := range members {
			idx, err := strconv.Atoi(key)
			if err != nil {
				panic(fmt.Errorf(" converting %s to int failed, err=%v", key, err))
			}
			indexed[idx] = member
		}
		converted[group] = indexed
	}
	return converted
}
// variableAggregatorStreamInputConverter is the streaming counterpart of
// variableAggregatorInputConverter: it converts each chunk as it arrives,
// turning conversion panics into errors on the resulting stream.
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
	convert := func(chunk map[string]any) (out map[string]map[int]any, err error) {
		defer func() {
			if r := recover(); r != nil {
				err = safego.NewPanicErr(r, debug.Stack())
			}
		}()
		return s.variableAggregatorInputConverter(chunk), nil
	}
	return schema.StreamReaderWithConvert(in, convert)
}
// ToTextProcessorConfig builds the text processor node config. "Type" is
// mandatory; the template, concat character, separators and full sources are
// optional. Uses the mustGetKey/getKeyOrZero helpers uniformly instead of
// mixing raw map assertions, matching the other To*Config builders.
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
	return &textprocessor.Config{
		Type:        mustGetKey[textprocessor.Type]("Type", s.Configs),
		Tpl:         getKeyOrZero[string]("Tpl", s.Configs),
		ConcatChar:  getKeyOrZero[string]("ConcatChar", s.Configs),
		Separators:  getKeyOrZero[[]string]("Separators", s.Configs),
		FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}, nil
}
// ToJsonSerializationConfig builds the JSON serialization node config from
// the node's declared input types.
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
	cfg := &json.SerializationConfig{InputTypes: s.InputTypes}
	return cfg, nil
}
// ToJsonDeserializationConfig builds the JSON deserialization node config
// from the node's declared output types.
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
	cfg := &json.DeserializationConfig{OutputFields: s.OutputTypes}
	return cfg, nil
}
// ToHTTPRequesterConfig builds the HTTP requester node config. All entries
// except "AuthConfig" are mandatory in the schema's Configs.
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
	cfg := &httprequester.Config{
		Method:          mustGetKey[string]("Method", s.Configs),
		URLConfig:       mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
		BodyConfig:      mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
		AuthConfig:      getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
		Timeout:         mustGetKey[time.Duration]("Timeout", s.Configs),
		RetryTimes:      mustGetKey[uint64]("RetryTimes", s.Configs),
		MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
	}
	return cfg, nil
}
// ToVariableAssignerConfig builds the variable assigner node config.
// For this node type, Configs holds the pair list directly (not a map),
// and the caller supplies the variable handler.
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
	pairs := s.Configs.([]*variableassigner.Pair)
	return &variableassigner.Config{
		Pairs:   pairs,
		Handler: handler,
	}, nil
}
// ToVariableAssignerInLoopConfig builds the in-loop variable assigner
// config; unlike ToVariableAssignerConfig it needs no external handler.
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
	pairs := s.Configs.([]*variableassigner.Pair)
	return &variableassigner.Config{Pairs: pairs}, nil
}
// ToLoopConfig builds the loop node config around the pre-built inner
// workflow. Array-typed inputs become the loop's iteration arrays, except
// those declared as intermediate variables.
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
	conf := &loop.Config{
		LoopNodeKey:      s.Key,
		LoopType:         mustGetKey[loop.Type]("LoopType", s.Configs),
		IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
		Outputs:          s.OutputSources,
		Inner:            inner,
	}

	for name, info := range s.InputTypes {
		if info.Type != vo.DataTypeArray {
			continue
		}
		if _, isVar := conf.IntermediateVars[name]; isVar { // exclude arrays in intermediate vars
			continue
		}
		conf.InputArrays = append(conf.InputArrays, name)
	}

	return conf, nil
}
// ToQAConfig builds the question-answer node config. A chat model is
// resolved only when "LLMParams" is present in Configs; otherwise
// conf.Model is left nil.
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
	conf := &qa.Config{
		NodeKey:                   s.Key,
		OutputFields:              s.OutputTypes,
		QuestionTpl:               mustGetKey[string]("QuestionTpl", s.Configs),
		AnswerType:                mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
		ChoiceType:                getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
		FixedChoices:              getKeyOrZero[[]string]("FixedChoices", s.Configs),
		ExtractFromAnswer:         getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
		MaxAnswerCount:            getKeyOrZero[int]("MaxAnswerCount", s.Configs),
		AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
	}

	if llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs); llmParams != nil {
		chatModel, _, err := model.GetManager().GetModel(ctx, llmParams)
		if err != nil {
			return nil, err
		}
		conf.Model = chatModel
	}

	return conf, nil
}
// ToInputReceiverConfig builds the input receiver node config; the
// "OutputSchema" entry of Configs is mandatory.
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
	cfg := &receiver.Config{
		NodeKey:      s.Key,
		OutputTypes:  s.OutputTypes,
		OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
	}
	return cfg, nil
}
// ToOutputEmitterConfig builds the output emitter node config from the
// optional "Template" and "FullSources" config entries. The workflow
// schema parameter is currently unused.
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
	cfg := &emitter.Config{
		Template:    getKeyOrZero[string]("Template", s.Configs),
		FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
	}
	return cfg, nil
}
// ToDatabaseCustomSQLConfig builds the custom-SQL database node config,
// wiring in the cross-domain database operator as the executor.
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
	cfg := &database.CustomSQLConfig{
		DatabaseInfoID:    mustGetKey[int64]("DatabaseInfoID", s.Configs),
		SQLTemplate:       mustGetKey[string]("SQLTemplate", s.Configs),
		OutputConfig:      s.OutputTypes,
		CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseQueryConfig builds the database query node config.
// "DatabaseInfoID" and "Limit" are mandatory; fields, ordering and the
// where-clause group are optional.
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
	cfg := &database.QueryConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		Limit:          mustGetKey[int64]("Limit", s.Configs),
		QueryFields:    getKeyOrZero[[]string]("QueryFields", s.Configs),
		OrderClauses:   getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
		ClauseGroup:    getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		OutputConfig:   s.OutputTypes,
		Op:             crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseInsertConfig builds the database insert node config, wiring in
// the cross-domain database operator as the inserter.
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
	cfg := &database.InsertConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		OutputConfig:   s.OutputTypes,
		Inserter:       crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseDeleteConfig builds the database delete node config; both the
// database ID and the where-clause group are mandatory.
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
	cfg := &database.DeleteConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		ClauseGroup:    mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		OutputConfig:   s.OutputTypes,
		Deleter:        crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToDatabaseUpdateConfig builds the database update node config; both the
// database ID and the where-clause group are mandatory.
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
	cfg := &database.UpdateConfig{
		DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
		ClauseGroup:    mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
		OutputConfig:   s.OutputTypes,
		Updater:        crossdatabase.GetDatabaseOperator(),
	}
	return cfg, nil
}
// ToKnowledgeIndexerConfig builds the knowledge indexer node config with
// its mandatory parsing and chunking strategies.
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
	cfg := &knowledge.IndexerConfig{
		KnowledgeID:      mustGetKey[int64]("KnowledgeID", s.Configs),
		ParsingStrategy:  mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
		ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
		KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToKnowledgeRetrieveConfig builds the knowledge retrieval node config
// from the mandatory knowledge IDs and retrieval strategy.
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
	cfg := &knowledge.RetrieveConfig{
		KnowledgeIDs:      mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
		RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
		Retriever:         crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToKnowledgeDeleterConfig builds the knowledge deleter node config for
// the mandatory knowledge ID.
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
	cfg := &knowledge.DeleterConfig{
		KnowledgeID:      mustGetKey[int64]("KnowledgeID", s.Configs),
		KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
	}
	return cfg, nil
}
// ToPluginConfig builds the plugin node configuration from this schema's
// Configs map, wiring in the cross-domain plugin service.
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
	cfg := &plugin.Config{
		PluginID:      mustGetKey[int64]("PluginID", s.Configs),
		ToolID:        mustGetKey[int64]("ToolID", s.Configs),
		PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
		PluginService: crossplugin.GetPluginService(),
	}
	return cfg, nil
}
// ToCodeRunnerConfig builds the code-runner node configuration from this
// schema's Configs map, wiring in the cross-domain code runner.
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
	cfg := &code.Config{
		Code:         mustGetKey[string]("Code", s.Configs),
		Language:     mustGetKey[coderunner.Language]("Language", s.Configs),
		OutputConfig: s.OutputTypes,
		Runner:       crosscode.GetCodeRunner(),
	}
	return cfg, nil
}
// ToCreateConversationConfig builds the create-conversation node configuration,
// backed by the cross-domain conversation manager.
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
	cfg := &conversation.CreateConversationConfig{
		Creator: crossconversation.ConversationManagerImpl,
	}
	return cfg, nil
}
// ToClearMessageConfig builds the clear-message node configuration,
// backed by the cross-domain conversation manager.
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
	cfg := &conversation.ClearMessageConfig{
		Clearer: crossconversation.ConversationManagerImpl,
	}
	return cfg, nil
}
// ToMessageListConfig builds the message-list node configuration,
// backed by the cross-domain conversation manager.
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
	cfg := &conversation.MessageListConfig{
		Lister: crossconversation.ConversationManagerImpl,
	}
	return cfg, nil
}
// ToIntentDetectorConfig builds the intent-detector node configuration from
// this schema's Configs map, resolving the chat model via the model manager.
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
	cfg := &intentdetector.Config{
		Intents:      mustGetKey[[]string]("Intents", s.Configs),
		SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
		IsFastMode:   getKeyOrZero[bool]("IsFastMode", s.Configs),
	}

	params := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
	chatModel, _, err := model.GetManager().GetModel(ctx, params)
	if err != nil {
		return nil, err
	}
	cfg.ChatModel = chatModel

	return cfg, nil
}
// ToSubWorkflowConfig compiles the sub-workflow referenced by this node and
// wraps its runner in a subworkflow.Config.
//
// requireCheckpoint propagates the parent workflow's checkpoint requirement
// into the compiled sub-workflow.
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
	var opts []WorkflowOption
	opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
	if requireCheckpoint {
		opts = append(opts, WithParentRequireCheckpoint())
	}
	// Named `static` rather than `s`: the original shadowed the receiver in
	// this if-initializer, which is an easy source of bugs as the body grows.
	if static := execute.GetStaticConfig(); static != nil && static.MaxNodeCountPerWorkflow > 0 {
		opts = append(opts, WithMaxNodeCount(static.MaxNodeCountPerWorkflow))
	}

	wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
	if err != nil {
		return nil, err
	}

	return &subworkflow.Config{
		Runner: wf.Runner,
	}, nil
}
// totRetrievalSearchType maps the canvas-side integer code for a retrieval
// search type to the knowledge domain's SearchType. An unknown code yields
// an error.
func totRetrievalSearchType(code int64) (crossknowledge.SearchType, error) {
	var st crossknowledge.SearchType
	switch code {
	case 0:
		st = crossknowledge.SearchTypeSemantic
	case 1:
		st = crossknowledge.SearchTypeHybrid
	case 20:
		st = crossknowledge.SearchTypeFullText
	default:
		return "", fmt.Errorf("invalid retrieval search type %v", code)
	}
	return st, nil
}

View File

@ -1,107 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package compose
import (
"fmt"
"reflect"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// getKeyOrZero looks up key in cfg (expected to be a map[string]any) and
// returns its value as T. It returns T's zero value when cfg is nil or the
// key is absent. It panics when cfg is a non-nil value that is not a
// map[string]any, or when the stored value is not assignable to T.
func getKeyOrZero[T any](key string, cfg any) T {
	var zero T
	if cfg == nil {
		return zero
	}
	m, ok := cfg.(map[string]any)
	if !ok {
		panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
	}
	v, found := m[key]
	if !found {
		return zero
	}
	return v.(T)
}
// mustGetKey looks up key in cfg (which must be a map[string]any) and returns
// its value as T. It panics when cfg is nil, cfg is not a map[string]any, the
// key is missing, or the stored value is not a T — config keys are populated
// at schema-construction time, so any miss here is a programming error.
func mustGetKey[T any](key string, cfg any) T {
	if cfg == nil {
		panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
	}
	m, ok := cfg.(map[string]any)
	if !ok {
		panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
	}
	if _, ok := m[key]; !ok {
		panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
	}
	v, ok := m[key].(T)
	if !ok {
		// Name the expected type T directly. The previous code called
		// reflect.TypeOf on the zero value of T, which prints <nil> whenever
		// T is an interface type, hiding what the caller actually expected.
		panic(fmt.Sprintf("key %s is not a %v, actual type: %v",
			key, reflect.TypeOf((*T)(nil)).Elem(), reflect.TypeOf(m[key])))
	}
	return v
}
// SetConfigKV stores a single config entry, lazily allocating the Configs
// map on first use. It panics if Configs is non-nil but not a map[string]any.
func (s *NodeSchema) SetConfigKV(key string, value any) {
	if s.Configs == nil {
		s.Configs = map[string]any{}
	}
	cfg := s.Configs.(map[string]any)
	cfg[key] = value
}
// SetInputType records the type info for one input field, lazily allocating
// the InputTypes map on first use.
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
	if s.InputTypes == nil {
		s.InputTypes = map[string]*vo.TypeInfo{}
	}
	s.InputTypes[key] = t
}
// AddInputSource appends field-source descriptors for this node's inputs.
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
	for _, fi := range info {
		s.InputSources = append(s.InputSources, fi)
	}
}
// SetOutputType records the type info for one output field, lazily allocating
// the OutputTypes map on first use.
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
	if s.OutputTypes == nil {
		s.OutputTypes = map[string]*vo.TypeInfo{}
	}
	s.OutputTypes[key] = t
}
// AddOutputSource appends field-source descriptors for this node's outputs.
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
	for _, fi := range info {
		s.OutputSources = append(s.OutputSources, fi)
	}
}
// GetSubWorkflowIdentity reports the referenced sub-workflow's ID and version.
// The boolean is false when this node is not a sub-workflow node.
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
	if s.Type != entity.NodeTypeSubWorkflow {
		return 0, "", false
	}
	id := mustGetKey[int64]("WorkflowID", s.Configs)
	version := mustGetKey[string]("WorkflowVersion", s.Configs)
	return id, version, true
}

View File

@ -29,6 +29,8 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
*workflow
hierarchy map[vo.NodeKey]vo.NodeKey
connections []*Connection
connections []*schema.Connection
requireCheckpoint bool
entry *compose.WorkflowNode
inner bool
@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
input map[string]*vo.TypeInfo
output map[string]*vo.TypeInfo
terminatePlan vo.TerminatePlan
schema *WorkflowSchema
schema *schema.WorkflowSchema
}
type workflowOptions struct {
@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
}
}
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
sc.Init()
wf := &Workflow{
@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
schema: sc,
}
wf.streamRun = sc.requireStreaming
wf.requireCheckpoint = sc.requireCheckPoint
wf.streamRun = sc.RequireStreaming()
wf.requireCheckpoint = sc.RequireCheckpoint()
wfOpts := &workflowOptions{}
for _, opt := range opts {
@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
processedNodeKey[child.Key] = struct{}{}
}
}
// add all nodes other than composite nodes and their children
for _, ns := range sc.Nodes {
if _, ok := processedNodeKey[ns.Key]; !ok {
@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
}
if ns.Type == entity.NodeTypeExit {
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
}
}
@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
}
fanInConfigs := sc.fanInMergeConfigs()
fanInConfigs := sc.FanInMergeConfigs()
if len(fanInConfigs) > 0 {
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
}
@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
carryOvers map[vo.NodeKey][]*compose.FieldMapping
}
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
_, err := w.addNodeInternal(ctx, ns, nil)
return err
}
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
inner, err := w.getInnerWorkflow(ctx, cNode)
if err != nil {
return err
@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
return err
}
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
return w.addNodeInternal(ctx, cNode, nil)
}
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
key := ns.Key
var deps *dependencyInfo
@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
innerWorkflow = inner.inner
}
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
if err != nil {
return nil, err
}
@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
var opts []compose.GraphAddNodeOpt
opts = append(opts, compose.WithNodeName(string(ns.Key)))
preHandler := ns.StatePreHandler(w.streamRun)
preHandler := statePreHandler(ns, w.streamRun)
if preHandler != nil {
opts = append(opts, preHandler)
}
postHandler := ns.StatePostHandler(w.streamRun)
postHandler := statePostHandler(ns, w.streamRun)
if postHandler != nil {
opts = append(opts, postHandler)
}
@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
w.entry = wNode
}
outputPortCount, hasExceptionPort := ns.OutputPortCount()
if outputPortCount > 1 || hasExceptionPort {
bMapping, err := w.resolveBranch(key, outputPortCount)
if err != nil {
return nil, err
}
b := w.schema.GetBranch(ns.Key)
if b != nil {
if b.OnlyException() {
_ = w.AddBranch(string(key), b.GetExceptionBranch())
} else {
bb, ok := ns.Configs.(schema.BranchBuilder)
if !ok {
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
}
branch, err := ns.GetBranch(bMapping)
if err != nil {
return nil, err
}
br, err := b.GetFullBranch(ctx, bb)
if err != nil {
return nil, err
}
_ = w.AddBranch(string(key), branch)
_ = w.AddBranch(string(key), br)
}
}
return deps.inputsForParent, nil
@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
return w.workflow.Compile(ctx, opts...)
}
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*NodeSchema)
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
for _, n := range cNode.Children {
innerNodes[n.Key] = n
}
// trim the connections, only keep the connections that are related to the inner workflow
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
innerConnections := make([]*Connection, 0)
innerConnections := make([]*schema.Connection, 0)
for i := range w.schema.Connections {
conn := w.schema.Connections[i]
if _, ok := innerNodes[conn.FromNode]; ok {
@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
// we will not try to drill, instead, just take value using a.b.c.
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
for nKey, fms := range d.inputs {
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
continue
@ -638,55 +643,6 @@ type variableInfo struct {
toPath compose.FieldPath
}
// resolveBranch collects, for node n, the set of destination nodes reachable
// through each outgoing port. Ports are named "branch_<i>" (normal branch i),
// "default" (mapped to the last normal slot), or "branch_error" (the
// exception branch). portCount is the number of normal ports.
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
	normal := make([]map[string]bool, portCount)
	var exception map[string]bool

	for _, conn := range w.connections {
		if conn.FromNode != n || conn.FromPort == nil {
			continue
		}

		port := *conn.FromPort
		switch {
		case port == "default":
			// The default condition always occupies the last normal slot.
			if normal[portCount-1] == nil {
				normal[portCount-1] = map[string]bool{}
			}
			normal[portCount-1][string(conn.ToNode)] = true
		case port == "branch_error":
			if exception == nil {
				exception = map[string]bool{}
			}
			exception[string(conn.ToNode)] = true
		default:
			if !strings.HasPrefix(port, "branch_") {
				return nil, fmt.Errorf("outgoing connections has invalid port= %s", port)
			}
			i, err := strconv.Atoi(strings.TrimPrefix(port, "branch_"))
			if err != nil {
				return nil, fmt.Errorf("outgoing connections has invalid port index= %s", port)
			}
			if i < 0 || i >= portCount {
				return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
			}
			if normal[i] == nil {
				normal[i] = map[string]bool{}
			}
			normal[i][string(conn.ToNode)] = true
		}
	}

	return &BranchMapping{
		Normal:    normal,
		Exception: exception,
	}, nil
}
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
var (
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
if inputFull == nil {
@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
}
}
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
firstNodesInInnerWorkflow := true
for _, conn := range connMap {
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
firstNodesInInnerWorkflow = false
break
@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
continue
}
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
fromNodeKey = compose.START
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
continue
}
@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
variableInfos []*variableInfo
)
connMap := make(map[vo.NodeKey]Connection)
connMap := make(map[vo.NodeKey]schema.Connection)
for _, conn := range w.connections {
if conn.ToNode != n {
continue
}
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
continue
}
@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
swp.Source.Ref.FromPath, swp.Path)
}
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
if _, ok := connMap[fromNode]; ok { // direct dependency
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
} else { // indirect dependency

View File

@ -23,9 +23,10 @@ import (
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
*Workflow, error) {
sc.Init()
ns := sc.GetNode(nodeKey)
@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod
schema: sc,
fromNode: true,
streamRun: false, // single node run can only invoke
requireCheckpoint: sc.requireCheckPoint,
requireCheckpoint: sc.RequireCheckpoint(),
input: ns.InputTypes,
output: ns.OutputTypes,
terminatePlan: vo.ReturnVariables,

View File

@ -32,6 +32,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@ -42,7 +43,7 @@ type WorkflowRunner struct {
basic *entity.WorkflowBasic
input string
resumeReq *entity.ResumeRequest
schema *WorkflowSchema
schema *schema2.WorkflowSchema
streamWriter *schema.StreamWriter[*entity.Message]
config vo.ExecuteConfig
@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp
}
}
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
options := &workflowRunOptions{}
for _, opt := range opts {
opt(options)

View File

@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -41,7 +42,7 @@ type invokableWorkflow struct {
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *WorkflowSchema
sc *schema2.WorkflowSchema
repo wf.Repository
}
@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow,
sc *WorkflowSchema,
sc *schema2.WorkflowSchema,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &invokableWorkflow{
@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
return "", err
}
var entryNode *NodeSchema
var entryNode *schema2.NodeSchema
for _, node := range i.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node
@ -190,7 +191,7 @@ type streamableWorkflow struct {
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
terminatePlan vo.TerminatePlan
wfEntity *entity.Workflow
sc *WorkflowSchema
sc *schema2.WorkflowSchema
repo wf.Repository
}
@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
terminatePlan vo.TerminatePlan,
wfEntity *entity.Workflow,
sc *WorkflowSchema,
sc *schema2.WorkflowSchema,
repo wf.Repository,
) wf.ToolFromWorkflow {
return &streamableWorkflow{
@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
return nil, err
}
var entryNode *NodeSchema
var entryNode *schema2.NodeSchema
for _, node := range s.sc.Nodes {
if node.Type == entity.NodeTypeEntry {
entryNode = node

View File

@ -30,50 +30,108 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type Batch struct {
config *Config
outputs map[string]*vo.FieldSource
outputs map[string]*vo.FieldSource
innerWorkflow compose.Runnable[map[string]any, map[string]any]
key vo.NodeKey
inputArrays []string
}
type Config struct {
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
type Config struct{}
InputArrays []string `json:"input_arrays"`
Outputs []*vo.FieldInfo `json:"outputs"`
}
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
if config == nil {
return nil, errors.New("config is required")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
}
if len(config.InputArrays) == 0 {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeBatch,
Name: n.Data.Meta.Title,
Configs: c,
}
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
compose.FieldPath{MaxBatchSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(batchSizeField...)
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
compose.FieldPath{ConcurrentSizeKey}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(concurrentSizeField...)
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
if err != nil {
return nil, err
}
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
if err != nil {
return nil, err
}
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
var inputArrays []string
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
inputArrays = append(inputArrays, key)
}
if len(inputArrays) == 0 {
return nil, errors.New("need to have at least one incoming array for batch")
}
if len(config.Outputs) == 0 {
if len(ns.OutputSources) == 0 {
return nil, errors.New("need to have at least one output variable for batch")
}
b := &Batch{
config: config,
outputs: make(map[string]*vo.FieldSource),
bo := schema.GetBuildOptions(opts...)
if bo.Inner == nil {
return nil, errors.New("need to have inner workflow for batch")
}
for i := range config.Outputs {
source := config.Outputs[i]
b := &Batch{
outputs: make(map[string]*vo.FieldSource),
innerWorkflow: bo.Inner,
key: ns.Key,
inputArrays: inputArrays,
}
for i := range ns.OutputSources {
source := ns.OutputSources[i]
path := source.Path
if len(path) != 1 {
return nil, fmt.Errorf("invalid path %q", path)
}
// from which inner node's which field does the batch's output fields come from
b.outputs[path[0]] = &source.Source
}
@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
return out
}
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) {
arrays := make(map[string]any, len(b.config.InputArrays))
arrays := make(map[string]any, len(b.inputArrays))
minLen := math.MaxInt64
for _, arrayKey := range b.config.InputArrays {
for _, arrayKey := range b.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
}
}
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
input[string(b.key)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey, array := range arrays {
ele := reflect.ValueOf(array).Index(i).Interface()
items[arrayKey] = []any{ele}
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
currentKey := string(b.key) + "#" + arrayKey
// Recursively expand map[string]any elements
var expand func(prefix string, val interface{})
@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
var existingCState *nodes.NestedWorkflowState
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
existingCState, _, e = getter.GetNestedWorkflowState(b.key)
if e != nil {
return e
}
@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
mu.Unlock()
if subCheckpointID != "" {
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
i, b.config.BatchNodeKey, subCheckpointID)
i, b.key, subCheckpointID)
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
}
@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
// the output then needs to be concatenated.
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
NodeKey: b.key,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e
}
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
return setter.SetInterruptEvent(b.key, iEvent)
})
if err != nil {
return nil, err
@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
return e
}
@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
// although this invocation does not have new interruptions,
// this batch node previously have interrupts yet to be resumed.
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
NodeKey: b.config.BatchNodeKey,
return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
NodeKey: b.key,
NodeType: entity.NodeTypeBatch,
NestedInterruptInfo: existingCState.Index2InterruptInfo,
})
@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
}
@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
}
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(b.config.InputArrays))
for _, arrayKey := range b.config.InputArrays {
trimmed := make(map[string]any, len(b.inputArrays))
for _, arrayKey := range b.inputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}

View File

@ -25,6 +25,10 @@ import (
"golang.org/x/exp/maps"
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
@ -113,50 +117,77 @@ var pythonThirdPartyWhitelist = map[string]struct{}{
}
type Config struct {
Code string
Language coderunner.Language
OutputConfig map[string]*vo.TypeInfo
Runner coderunner.Runner
Code string
Language coderunner.Language
Runner coderunner.Runner
}
type CodeRunner struct {
config *Config
importError error
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeCodeRunner,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
code := inputs.Code
c.Code = code
language, err := convertCodeLanguage(inputs.Language)
if err != nil {
return nil, err
}
c.Language = language
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
// convertCodeLanguage maps the canvas-side integer language code to the
// code-runner language. Unknown codes yield an error.
func convertCodeLanguage(l int64) (coderunner.Language, error) {
	switch l {
	case 3:
		return coderunner.Python, nil
	case 5:
		return coderunner.JavaScript, nil
	}
	return "", fmt.Errorf("invalid language: %d", l)
}
if cfg.Language == "" {
return nil, errors.New("language is required")
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if cfg.Code == "" {
return nil, errors.New("code is required")
}
if cfg.Language != coderunner.Python {
if c.Language != coderunner.Python {
return nil, errors.New("only support python language")
}
if len(cfg.OutputConfig) == 0 {
return nil, errors.New("output config is required")
}
importErr := validatePythonImports(c.Code)
if cfg.Runner == nil {
return nil, errors.New("run coder is required")
}
importErr := validatePythonImports(cfg.Code)
return &CodeRunner{
config: cfg,
importError: importErr,
return &Runner{
code: c.Code,
language: c.Language,
outputConfig: ns.OutputTypes,
runner: code2.GetCodeRunner(),
importError: importErr,
}, nil
}
// Runner executes a user-authored code snippet through the configured
// coderunner backend and converts the raw result into typed outputs.
type Runner struct {
	outputConfig map[string]*vo.TypeInfo // expected output fields and their types
	code         string                  // source code to execute
	language     coderunner.Language     // execution language; only Python passes Config.Build
	runner       coderunner.Runner       // backend that actually runs the code
	importError  error                   // deferred result of validatePythonImports, surfaced on Invoke
}
func validatePythonImports(code string) error {
imports := parsePythonImports(code)
importErrors := make([]string, 0)
@ -191,11 +222,11 @@ func validatePythonImports(code string) error {
return nil
}
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
if c.importError != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
}
response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input})
if err != nil {
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
result := response.Result
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
}
@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
}
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
if !ok {
return nil, errors.New("raw output config is required")

View File

@ -75,30 +75,29 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Language: coderunner.Python,
Code: codeTpl,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
},
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
c := &Runner{
language: coderunner.Python,
code: codeTpl,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
}},
},
Runner: mockRunner,
},
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})
@ -145,38 +144,36 @@ async def main(args:Args)->Output:
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
ctx := t.Context()
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: coderunner.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
}},
c := &Runner{
code: codeTpl,
language: coderunner.Python,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
}},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
},
}},
}},
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
},
}},
},
Runner: mockRunner,
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})
@ -219,30 +216,28 @@ async def main(args:Args)->Output:
}
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
c := &CodeRunner{
config: &Config{
Code: codeTpl,
Language: coderunner.Python,
OutputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
}},
},
},
c := &Runner{
code: codeTpl,
language: coderunner.Python,
outputConfig: map[string]*vo.TypeInfo{
"key0": {Type: vo.DataTypeInteger},
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key31": {Type: vo.DataTypeString},
"key32": {Type: vo.DataTypeString},
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"key341": {Type: vo.DataTypeString},
"key342": {Type: vo.DataTypeString},
"key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
}},
},
},
Runner: mockRunner,
},
runner: mockRunner,
}
ret, err := c.RunCode(ctx, map[string]any{
ret, err := c.Invoke(ctx, map[string]any{
"input": "1123",
})

View File

@ -0,0 +1,236 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// setDatabaseInputsForNodeSchema wires the condition and field-setting inputs
// of a database canvas node (select / insert / delete / update params) into
// the node schema's input types and input sources.
func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
	inputs := n.Data.Inputs

	if p := inputs.SelectParam; p != nil {
		if err := applyDBConditionToSchema(ns, p.Condition, n.Parent()); err != nil {
			return err
		}
	}

	if p := inputs.InsertParam; p != nil {
		if err := applyInsetFieldInfoToSchema(ns, p.FieldInfo, n.Parent()); err != nil {
			return err
		}
	}

	if p := inputs.DeleteParam; p != nil {
		if err := applyDBConditionToSchema(ns, &p.Condition, n.Parent()); err != nil {
			return err
		}
	}

	if p := inputs.UpdateParam; p != nil {
		if err := applyDBConditionToSchema(ns, &p.Condition, n.Parent()); err != nil {
			return err
		}
		if err := applyInsetFieldInfoToSchema(ns, p.FieldInfo, n.Parent()); err != nil {
			return err
		}
	}

	return nil
}
// applyDBConditionToSchema registers, for every clause in the condition list,
// the clause's "right" operand as a synthetic typed input on the node schema
// (named "__condition_right_<idx>"), together with its field-mapping sources.
//
// A nil condition or a missing condition list means there is nothing to wire
// and is not an error. Clauses without a "right" param (e.g. unary operators
// such as IS_NULL / IS_NOT_NULL) are skipped.
func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error {
	// Guard against a nil condition pointer: SelectParam.Condition is optional
	// on the canvas (QueryConfig.Adapt checks it, but other callers may not).
	if condition == nil || condition.ConditionList == nil {
		return nil
	}

	for idx, params := range condition.ConditionList {
		// Locate the "right" operand among this clause's params.
		var right *vo.Param
		for _, param := range params {
			if param == nil {
				continue
			}
			if param.Name == "right" {
				right = param
				break
			}
		}

		if right == nil {
			// Unary clause — no right operand to wire.
			continue
		}

		// Synthetic input name keyed by clause index, e.g. "__condition_right_0".
		name := fmt.Sprintf("__condition_right_%d", idx)

		tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input)
		if err != nil {
			return err
		}
		ns.SetInputType(name, tInfo)

		sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode)
		if err != nil {
			return err
		}
		ns.AddInputSource(sources...)
	}

	return nil
}
// applyInsetFieldInfoToSchema registers each field to be written (insert /
// update) as a synthetic typed input on the node schema, named
// "__setting_field_<fieldName>", with its field-mapping sources.
//
// Each fieldInfo entry is a param list with two elements: the first names the
// field, the second carries its value.
func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error {
	if len(fieldInfo) == 0 {
		return nil
	}

	for _, params := range fieldInfo {
		// Validate shape instead of indexing blindly so malformed canvas data
		// surfaces as an error rather than an index-out-of-range panic.
		if len(params) < 2 || params[0] == nil || params[1] == nil {
			return fmt.Errorf("field info requires a name param and a value param, got %d params", len(params))
		}
		p0 := params[0]
		p1 := params[1]

		// The field name must be a string; report a type error instead of
		// panicking on the assertion.
		fieldName, ok := p0.Input.Value.Content.(string)
		if !ok {
			return fmt.Errorf("field name must be a string, got %T", p0.Input.Value.Content)
		}

		tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input)
		if err != nil {
			return err
		}

		name := "__setting_field_" + fieldName
		ns.SetInputType(name, tInfo)

		sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode)
		if err != nil {
			return err
		}
		ns.AddInputSource(sources...)
	}

	return nil
}
// buildClauseGroupFromCondition converts a canvas DB condition into a
// database.ClauseGroup: a single clause when the condition list has exactly
// one entry, otherwise a multi-clause group joined by the condition's logic
// relation (AND / OR).
func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) {
	group := &database.ClauseGroup{}

	if len(condition.ConditionList) == 1 {
		single, err := buildClauseFromParams(condition.ConditionList[0])
		if err != nil {
			return nil, err
		}
		group.Single = single
		return group, nil
	}

	relation, err := convertLogicTypeToRelation(condition.Logic)
	if err != nil {
		return nil, err
	}

	clauses := make([]*database.Clause, 0, len(condition.ConditionList))
	for _, params := range condition.ConditionList {
		clause, err := buildClauseFromParams(params)
		if err != nil {
			return nil, err
		}
		clauses = append(clauses, clause)
	}

	group.Multi = &database.MultiClause{
		Clauses:  clauses,
		Relation: relation,
	}
	return group, nil
}
// buildClauseFromParams extracts the "left" operand and the "operation" from a
// clause's param list and builds a database.Clause. Both params are required;
// their Content values must be strings (reported as errors rather than
// panicking on a failed type assertion).
func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
	var left, operation *vo.Param
	for _, p := range params {
		if p == nil {
			continue
		}
		switch p.Name {
		case "left":
			left = p
		case "operation":
			operation = p
		}
	}

	if left == nil {
		return nil, fmt.Errorf("left clause is required")
	}
	if operation == nil {
		return nil, fmt.Errorf("operation clause is required")
	}

	// Checked assertions: malformed canvas data yields an error, not a panic.
	opStr, ok := operation.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("operation content must be a string, got %T", operation.Input.Value.Content)
	}
	operator, err := operationToOperator(opStr)
	if err != nil {
		return nil, err
	}

	leftStr, ok := left.Input.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("left content must be a string, got %T", left.Input.Value.Content)
	}

	return &database.Clause{
		Left:     leftStr,
		Operator: operator,
	}, nil
}
// convertLogicTypeToRelation maps the canvas logic type (AND / OR) to the
// crossdomain database.ClauseRelation; any other value is an error.
func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) {
	switch logicType {
	case vo.DatabaseLogicAnd:
		return database.ClauseRelationAND, nil
	case vo.DatabaseLogicOr:
		return database.ClauseRelationOR, nil
	default:
		return "", fmt.Errorf("logic type %v is invalid", logicType)
	}
}
// operationToOperator maps a canvas operation identifier (e.g. "EQUAL",
// "NOT_IN") to its crossdomain database.Operator. Unrecognized identifiers
// are rejected with an error that includes the offending value.
func operationToOperator(s string) (database.Operator, error) {
	switch s {
	case "EQUAL":
		return database.OperatorEqual, nil
	case "NOT_EQUAL":
		return database.OperatorNotEqual, nil
	case "GREATER_THAN":
		return database.OperatorGreater, nil
	case "LESS_THAN":
		return database.OperatorLesser, nil
	case "GREATER_EQUAL":
		return database.OperatorGreaterOrEqual, nil
	case "LESS_EQUAL":
		return database.OperatorLesserOrEqual, nil
	case "IN":
		return database.OperatorIn, nil
	case "NOT_IN":
		return database.OperatorNotIn, nil
	case "IS_NULL":
		return database.OperatorIsNull, nil
	case "IS_NOT_NULL":
		return database.OperatorIsNotNull, nil
	case "LIKE":
		return database.OperatorLike, nil
	case "NOT_LIKE":
		return database.OperatorNotLike, nil
	default:
		// Include the input so failures are diagnosable, mirroring the error
		// style of convertLogicTypeToRelation.
		return "", fmt.Errorf("not a valid Operation string: %q", s)
	}
}

View File

@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.
return ret, nil
}
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
var (
rightValue any
ok bool
@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa
return conditionGroup, nil
}
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
if err != nil {
return nil, err
}
fields := parseToInput(input)
inventory := &UpdateInventory{
inventory := &updateInventory{
ConditionGroup: conditionGroup,
Fields: fields,
}

View File

@ -19,48 +19,89 @@ package database
import (
"context"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type CustomSQLConfig struct {
DatabaseInfoID int64
SQLTemplate string
OutputConfig map[string]*vo.TypeInfo
CustomSQLExecutor database.DatabaseOperator
DatabaseInfoID int64
SQLTemplate string
}
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseCustomSQL,
Name: n.Data.Meta.Title,
Configs: c,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
c.DatabaseInfoID = dsID
sql := n.Data.Inputs.SQL
if len(sql) == 0 {
return nil, fmt.Errorf("sql is requird")
}
c.SQLTemplate = sql
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.SQLTemplate == "" {
if c.SQLTemplate == "" {
return nil, errors.New("sql template is required")
}
if cfg.CustomSQLExecutor == nil {
return nil, errors.New("custom sqler is required")
}
return &CustomSQL{
config: cfg,
databaseInfoID: c.DatabaseInfoID,
sqlTemplate: c.SQLTemplate,
outputTypes: ns.OutputTypes,
customSQLExecutor: database.GetDatabaseOperator(),
}, nil
}
type CustomSQL struct {
config *CustomSQLConfig
databaseInfoID int64
sqlTemplate string
outputTypes map[string]*vo.TypeInfo
customSQLExecutor database.DatabaseOperator
}
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
req := &database.CustomSQLRequest{
DatabaseInfoID: c.config.DatabaseInfoID,
DatabaseInfoID: c.databaseInfoID,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
}
templateSQL := ""
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
templateParts := nodes.ParseTemplate(c.sqlTemplate)
sqlParams := make([]database.SQLParam, 0, len(templateParts))
var nilError = errors.New("field is nil")
for _, templatePart := range templateParts {
@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
req.SQL = templateSQL
req.Params = sqlParams
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
response, err := c.customSQLExecutor.Execute(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(c.config.OutputConfig, response)
ret, err := responseFormatted(c.outputTypes, response)
if err != nil {
return nil, err
}

View File

@ -28,6 +28,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type mockCustomSQLer struct {
@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C
m.validate(request)
r := &database.Response{
Objects: []database.Object{
database.Object{
{
"v1": "v1_ret",
"v2": "v2_ret",
},
@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) {
validate: func(req *database.CustomSQLRequest) {
assert.Equal(t, int64(111), req.DatabaseInfoID)
ps := []database.SQLParam{
database.SQLParam{Value: "v1_value"},
database.SQLParam{Value: "v2_value"},
database.SQLParam{Value: "v3_value"},
{Value: "v1_value"},
{Value: "v2_value"},
{Value: "v3_value"},
}
assert.Equal(t, ps, req.Params)
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
cfg := &CustomSQLConfig{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
CustomSQLExecutor: mockDatabaseOperator,
OutputConfig: map[string]*vo.TypeInfo{
DatabaseInfoID: 111,
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
}
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
}}},
"rowNum": {Type: vo.DataTypeInteger},
},
}
cl := &CustomSQL{
config: cfg,
}
})
assert.NoError(t, err)
ret, err := cl.Execute(t.Context(), map[string]any{
ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{
"v1": "v1_value",
"v2": "v2_value",
"v3": "v3_value",

View File

@ -20,61 +20,102 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type DeleteConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Deleter database.DatabaseOperator
}
type Delete struct {
config *DeleteConfig
}
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseDelete,
Name: n.Data.Meta.Title,
Configs: d,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
d.DatabaseInfoID = dsID
deleteParam := n.Data.Inputs.DeleteParam
clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition)
if err != nil {
return nil, err
}
d.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if d.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
if d.ClauseGroup == nil {
return nil, errors.New("clauseGroup is required")
}
if cfg.Deleter == nil {
return nil, errors.New("deleter is required")
}
return &Delete{
config: cfg,
databaseInfoID: d.DatabaseInfoID,
clauseGroup: d.ClauseGroup,
outputTypes: ns.OutputTypes,
deleter: database.GetDatabaseOperator(),
}, nil
}
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
type Delete struct {
databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
deleter database.DatabaseOperator
}
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in)
if err != nil {
return nil, err
}
request := &database.DeleteRequest{
DatabaseInfoID: d.config.DatabaseInfoID,
DatabaseInfoID: d.databaseInfoID,
ConditionGroup: conditionGroup,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := d.config.Deleter.Delete(ctx, request)
response, err := d.deleter.Delete(ctx, request)
if err != nil {
return nil, err
}
ret, err := responseFormatted(d.config.OutputConfig, response)
ret, err := responseFormatted(d.outputTypes, response)
if err != nil {
return nil, err
}
@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any,
}
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in)
if err != nil {
return nil, err
}
@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri
}
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
databaseID := d.config.DatabaseInfoID
databaseID := d.databaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@ -20,54 +20,84 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type InsertConfig struct {
DatabaseInfoID int64
OutputConfig map[string]*vo.TypeInfo
Inserter database.DatabaseOperator
}
type Insert struct {
config *InsertConfig
}
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseInsert,
Name: n.Data.Meta.Title,
Configs: i,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
i.DatabaseInfoID = dsID
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if i.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Inserter == nil {
return nil, errors.New("inserter is required")
}
return &Insert{
config: cfg,
databaseInfoID: i.DatabaseInfoID,
outputTypes: ns.OutputTypes,
inserter: database.GetDatabaseOperator(),
}, nil
}
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {
type Insert struct {
databaseInfoID int64
outputTypes map[string]*vo.TypeInfo
inserter database.DatabaseOperator
}
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fields := parseToInput(input)
req := &database.InsertRequest{
DatabaseInfoID: is.config.DatabaseInfoID,
DatabaseInfoID: is.databaseInfoID,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := is.config.Inserter.Insert(ctx, req)
response, err := is.inserter.Insert(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(is.config.OutputConfig, response)
ret, err := responseFormatted(is.outputTypes, response)
if err != nil {
return nil, err
}
@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]
}
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
databaseID := is.config.DatabaseInfoID
databaseID := is.databaseInfoID
fs := parseToInput(input)
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}

View File

@ -20,68 +20,137 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type QueryConfig struct {
DatabaseInfoID int64
QueryFields []string
OrderClauses []*database.OrderClause
OutputConfig map[string]*vo.TypeInfo
ClauseGroup *database.ClauseGroup
Limit int64
Op database.DatabaseOperator
}
type Query struct {
config *QueryConfig
}
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeDatabaseQuery,
Name: n.Data.Meta.Title,
Configs: q,
}
if cfg.DatabaseInfoID == 0 {
dsList := n.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return nil, fmt.Errorf("database info is requird")
}
databaseInfo := dsList[0]
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
if err != nil {
return nil, err
}
q.DatabaseInfoID = dsID
selectParam := n.Data.Inputs.SelectParam
q.Limit = selectParam.Limit
queryFields := make([]string, 0)
for _, v := range selectParam.FieldList {
queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10))
}
q.QueryFields = queryFields
orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList))
for _, o := range selectParam.OrderByList {
orderClauses = append(orderClauses, &database.OrderClause{
FieldID: strconv.FormatInt(o.FieldID, 10),
IsAsc: o.IsAsc,
})
}
q.OrderClauses = orderClauses
clauseGroup := &database.ClauseGroup{}
if selectParam.Condition != nil {
clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition)
if err != nil {
return nil, err
}
}
q.ClauseGroup = clauseGroup
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if q.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.Limit == 0 {
if q.Limit == 0 {
return nil, errors.New("limit is required and greater than 0")
}
if cfg.Op == nil {
return nil, errors.New("op is required")
}
return &Query{config: cfg}, nil
return &Query{
databaseInfoID: q.DatabaseInfoID,
queryFields: q.QueryFields,
orderClauses: q.OrderClauses,
outputTypes: ns.OutputTypes,
clauseGroup: q.ClauseGroup,
limit: q.Limit,
op: database.GetDatabaseOperator(),
}, nil
}
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
type Query struct {
databaseInfoID int64
queryFields []string
orderClauses []*database.OrderClause
outputTypes map[string]*vo.TypeInfo
clauseGroup *database.ClauseGroup
limit int64
op database.DatabaseOperator
}
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil {
return nil, err
}
req := &database.QueryRequest{
DatabaseInfoID: ds.config.DatabaseInfoID,
OrderClauses: ds.config.OrderClauses,
SelectFields: ds.config.QueryFields,
Limit: ds.config.Limit,
DatabaseInfoID: ds.databaseInfoID,
OrderClauses: ds.orderClauses,
SelectFields: ds.queryFields,
Limit: ds.limit,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
req.ConditionGroup = conditionGroup
response, err := ds.config.Op.Query(ctx, req)
response, err := ds.op.Query(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(ds.config.OutputConfig, response)
ret, err := responseFormatted(ds.outputTypes, response)
if err != nil {
return nil, err
}
@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool {
}
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
if err != nil {
return nil, err
}
return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
return ds.toDatabaseQueryCallbackInput(conditionGroup)
}
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
result := make(map[string]any)
databaseID := config.DatabaseInfoID
databaseID := ds.databaseInfoID
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["selectParam"] = map[string]any{}
@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
FieldID string `json:"fieldId"`
IsDistinct bool `json:"isDistinct"`
}
fieldList := make([]Field, 0, len(config.QueryFields))
for _, f := range config.QueryFields {
fieldList := make([]Field, 0, len(ds.queryFields))
for _, f := range ds.queryFields {
fieldList = append(fieldList, Field{FieldID: f})
}
type Order struct {
@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
}
OrderList := make([]Order, 0)
for _, c := range config.OrderClauses {
for _, c := range ds.orderClauses {
OrderList = append(OrderList, Order{
FieldID: c.FieldID,
IsAsc: c.IsAsc,
@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
result["selectParam"] = map[string]any{
"condition": condition,
"fieldList": fieldList,
"limit": config.Limit,
"limit": ds.limit,
"orderByList": OrderList,
}
return result, nil
}
type ConditionItem struct {
@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) {
return "AND", nil
default:
return "", fmt.Errorf("unknown clause relation %v", rel)
}
}

View File

@ -30,6 +30,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type mockDsSelect struct {
@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]interface{}{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) {
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeString},
"v2": {Type: vo.DataTypeString},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
"__condition_right_1": 2,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.NoError(t, err)
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, map[string]any{
@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
objects = append(objects, database.Object{
@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeInteger},
"v3": {Type: vo.DataTypeInteger},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{"__condition_right_0": 1}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
fmt.Println(result)
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds := Query{
config: cfg,
}
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
"v1": {Type: vo.DataTypeInteger},
"v2": {Type: vo.DataTypeNumber},
"v3": {Type: vo.DataTypeBoolean},
"v4": {Type: vo.DataTypeBoolean},
"v5": {Type: vo.DataTypeTime},
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
},
}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
object := result["outputList"].([]any)[0].(database.Object)
@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) {
},
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
OutputConfig: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
Limit: 10,
}
objects := make([]database.Object, 0)
@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) {
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
cfg.Op = mockDatabaseOperator
ds := Query{
config: cfg,
}
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
"rowNum": {Type: vo.DataTypeInteger},
},
})
assert.NoError(t, err)
in := map[string]any{
"__condition_right_0": 1,
}
result, err := ds.Query(t.Context(), in)
result, err := ds.(*Query).Invoke(t.Context(), in)
assert.NoError(t, err)
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
"v1": "1",

View File

@ -20,47 +20,93 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type UpdateConfig struct {
DatabaseInfoID int64
ClauseGroup *database.ClauseGroup
OutputConfig map[string]*vo.TypeInfo
Updater database.DatabaseOperator
}
// Adapt converts the canvas node definition into a NodeSchema for the
// database-update node: it resolves the target database ID, builds the
// update condition clause group from the node's condition, and wires the
// node's inputs and output types into the schema.
func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeDatabaseUpdate,
		Name:    n.Data.Meta.Title,
		Configs: u,
	}

	dsList := n.Data.Inputs.DatabaseInfoList
	if len(dsList) == 0 {
		return nil, fmt.Errorf("database info is required") // fixed typo: "requird"
	}
	// Only the first database info entry is used.
	databaseInfo := dsList[0]

	dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
	if err != nil {
		return nil, err
	}
	u.DatabaseInfoID = dsID

	updateParam := n.Data.Inputs.UpdateParam
	if updateParam == nil {
		return nil, fmt.Errorf("update param is required") // fixed typo: "requird"
	}

	clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition)
	if err != nil {
		return nil, err
	}
	u.ClauseGroup = clauseGroup

	if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build validates the adapted configuration and constructs the runnable
// Update executor. ns supplies the node's resolved output type information.
func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	if u.DatabaseInfoID == 0 {
		return nil, errors.New("database info id is required and greater than 0")
	}
	if u.ClauseGroup == nil {
		// Fixed misleading message: the old "greater than 0" wording was
		// copy-pasted from the numeric ID check and is meaningless here.
		return nil, errors.New("clause group is required")
	}
	return &Update{
		databaseInfoID: u.DatabaseInfoID,
		clauseGroup:    u.ClauseGroup,
		outputTypes:    ns.OutputTypes,
		updater:        database.GetDatabaseOperator(),
	}, nil
}
type Update struct {
config *UpdateConfig
databaseInfoID int64
clauseGroup *database.ClauseGroup
outputTypes map[string]*vo.TypeInfo
updater database.DatabaseOperator
}
type UpdateInventory struct {
type updateInventory struct {
ConditionGroup *database.ConditionGroup
Fields map[string]any
}
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
if cfg == nil {
return nil, errors.New("config is required")
}
if cfg.DatabaseInfoID == 0 {
return nil, errors.New("database info id is required and greater than 0")
}
if cfg.ClauseGroup == nil {
return nil, errors.New("clause group is required and greater than 0")
}
if cfg.Updater == nil {
return nil, errors.New("updater is required")
}
return &Update{config: cfg}, nil
}
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in)
if err != nil {
return nil, err
}
@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
}
req := &database.UpdateRequest{
DatabaseInfoID: u.config.DatabaseInfoID,
DatabaseInfoID: u.databaseInfoID,
ConditionGroup: inventory.ConditionGroup,
Fields: fields,
IsDebugRun: isDebugExecute(ctx),
UserID: getExecUserID(ctx),
}
response, err := u.config.Updater.Update(ctx, req)
response, err := u.updater.Update(ctx, req)
if err != nil {
return nil, err
}
ret, err := responseFormatted(u.config.OutputConfig, response)
ret, err := responseFormatted(u.outputTypes, response)
if err != nil {
return nil, err
}
@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
}
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in)
if err != nil {
return nil, err
}
return u.toDatabaseUpdateCallbackInput(inventory)
}
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
databaseID := u.config.DatabaseInfoID
func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) {
databaseID := u.databaseInfoID
result := make(map[string]any)
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
result["updateParam"] = map[string]any{}
@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[
"condition": condition,
"fieldInfo": fieldInfo,
}
return result, nil
return result, nil
}

View File

@ -18,7 +18,6 @@ package emitter
import (
"context"
"errors"
"fmt"
"io"
"strings"
@ -26,28 +25,77 @@ import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type OutputEmitter struct {
cfg *Config
Template string
FullSources map[string]*schema2.SourceInfo
}
type Config struct {
Template string
FullSources map[string]*nodes.SourceInfo
Template string
}
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
if cfg == nil {
return nil, errors.New("config is required")
// Adapt converts the canvas node definition into a NodeSchema for the
// output-emitter node. It validates the node's (optional) literal string
// content into the answer template, records whether the node requires
// streaming input, and wires the node's input fields into the schema.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeOutputEmitter,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	content := n.Data.Inputs.Content
	streamingOutput := n.Data.Inputs.StreamingOutput

	// The original if/else branches only differed in the boolean value,
	// so assign the flag directly.
	ns.StreamConfigs = &schema2.StreamConfig{
		RequireStreamingInput: streamingOutput,
	}

	if content != nil {
		if content.Type != vo.VariableTypeString {
			return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
		}
		if content.Value.Type != vo.BlockInputValueTypeLiteral {
			return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
		}
		if content.Value.Content == nil {
			c.Template = ""
		} else {
			template, ok := content.Value.Content.(string)
			if !ok {
				return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
			}
			c.Template = template
		}
	}

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build constructs the runnable OutputEmitter from the adapted config,
// pairing the answer template with the schema's resolved field sources.
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
	e := &OutputEmitter{
		Template:    c.Template,
		FullSources: ns.FullSources,
	}
	return e, nil
}
@ -59,10 +107,10 @@ type cachedVal struct {
type cacheStore struct {
store map[string]*cachedVal
infos map[string]*nodes.SourceInfo
infos map[string]*schema2.SourceInfo
}
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
return &cacheStore{
store: make(map[string]*cachedVal),
infos: infos,
@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
}
if !sInfo.IsIntermediate { // this is not an intermediate object container
isStream := sInfo.FieldType == nodes.FieldIsStream
isStream := sInfo.FieldType == schema2.FieldIsStream
if !isStream {
_, ok := c.store[k]
if !ok {
@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
func (c *cacheStore) finished(k string) bool {
cached, ok := c.store[k]
if !ok {
return c.infos[k].FieldType == nodes.FieldSkipped
return c.infos[k].FieldType == schema2.FieldSkipped
}
if cached.finished {
@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
return true
}
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
actualPath []string,
) {
rootCached, ok := c.store[part.Root]
@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
hasErr bool, partFinished bool) {
cachedRoot, subCache, sourceInfo, _ := c.find(part)
if cachedRoot != nil && subCache != nil {
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
if hasErr {
return true, false
@ -315,14 +363,14 @@ func merge(a, b any) any {
const outputKey = "output"
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
if err != nil {
return nil, err
}
sr, sw := schema.Pipe[map[string]any](0)
parts := nodes.ParseTemplate(e.cfg.Template)
parts := nodes.ParseTemplate(e.Template)
safego.Go(ctx, func() {
hasErr := false
defer func() {
@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
shouldChangePart = true
}
} else {
if sourceInfo.FieldType == nodes.FieldIsStream {
if sourceInfo.FieldType == schema2.FieldIsStream {
currentV := v
for i := 0; i < len(actualPath)-1; i++ {
currentM, ok := currentV.(map[string]any)
@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
return sr, nil
}
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
if err != nil {
return nil, err
}

View File

@ -20,41 +20,74 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Config struct {
DefaultValues map[string]any
OutputTypes map[string]*vo.TypeInfo
}
type Entry struct {
cfg *Config
defaultValues map[string]any
}
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID)
}
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if n.ID != entity.EntryNodeKey {
return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID)
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Name: n.Data.Meta.Title,
Type: entity.NodeTypeEntry,
}
defaultValues := make(map[string]any, len(n.Data.Outputs))
for _, v := range n.Data.Outputs {
variable, err := vo.ParseVariable(v)
if err != nil {
return nil, err
}
if variable.DefaultValue != nil {
defaultValues[variable.Name] = variable.DefaultValue
}
}
c.DefaultValues = defaultValues
ns.Configs = c
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
return &Entry{
cfg: cfg,
defaultValues: defaultValues,
outputTypes: ns.OutputTypes,
}, nil
}
type Entry struct {
defaultValues map[string]any
outputTypes map[string]*vo.TypeInfo
}
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
for k, v := range e.defaultValues {
if val, ok := in[k]; ok {
tInfo := e.cfg.OutputTypes[k]
tInfo := e.outputTypes[k]
switch tInfo.Type {
case vo.DataTypeString:
if len(val.(string)) == 0 {

View File

@ -0,0 +1,113 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package exit
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// Config is the adapted configuration for the workflow exit node.
type Config struct {
	// Template is the literal answer content extracted from the node during
	// Adapt; Build renders it via an OutputEmitter unless the terminate plan
	// returns variables directly.
	Template      string
	// TerminatePlan selects how the workflow terminates; when it equals
	// vo.ReturnVariables, Build produces a pass-through Exit node instead of
	// an emitter.
	TerminatePlan vo.TerminatePlan
}
// Adapt converts the canvas node definition into a NodeSchema for the exit
// node. The exit node must be a top-level node carrying the well-known exit
// ID. It extracts the optional literal answer template, the streaming flag,
// and the mandatory terminate plan, then wires the node's inputs.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	if n.Parent() != nil {
		return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID)
	}
	if n.ID != entity.ExitNodeKey {
		return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID)
	}

	ns := &schema.NodeSchema{
		Key:     entity.ExitNodeKey,
		Type:    entity.NodeTypeExit,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	var (
		content         *vo.BlockInput
		streamingOutput bool
	)
	if n.Data.Inputs.OutputEmitter != nil {
		content = n.Data.Inputs.Content
		streamingOutput = n.Data.Inputs.StreamingOutput
	}

	// Both original branches built the same struct; assign the flag directly.
	ns.StreamConfigs = &schema.StreamConfig{
		RequireStreamingInput: streamingOutput,
	}

	if content != nil {
		if content.Type != vo.VariableTypeString {
			return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
		}
		if content.Value.Type != vo.BlockInputValueTypeLiteral {
			return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
		}
		// Guard the assertion: the previous unchecked
		// content.Value.Content.(string) panicked on a nil or non-string
		// literal. Mirrors the handling in emitter.Config.Adapt.
		if content.Value.Content == nil {
			c.Template = ""
		} else {
			template, ok := content.Value.Content.(string)
			if !ok {
				return nil, fmt.Errorf("exit node's content value must be string, got %v", content.Value.Content)
			}
			c.Template = template
		}
	}

	if n.Data.Inputs.TerminatePlan == nil {
		return nil, fmt.Errorf("exit node requires a TerminatePlan")
	}
	c.TerminatePlan = *n.Data.Inputs.TerminatePlan

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build returns the runnable node for the exit: a pass-through Exit when the
// workflow terminates by returning variables, otherwise an OutputEmitter that
// renders the answer template against the schema's resolved field sources.
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	if c.TerminatePlan != vo.ReturnVariables {
		return &emitter.OutputEmitter{
			Template:    c.Template,
			FullSources: ns.FullSources,
		}, nil
	}
	return &Exit{}, nil
}
// Exit is the terminal node used when the workflow returns variables
// directly; it forwards its input map unchanged.
type Exit struct{}

// Invoke returns the input map as-is, substituting an empty (non-nil) map
// when the input is nil.
func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
	if in != nil {
		return in, nil
	}
	return map[string]any{}, nil
}

View File

@ -0,0 +1,340 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package httprequester
import (
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
// extractBracesRegexp matches {{...}} placeholders, capturing the inner text
// non-greedily so adjacent placeholders stay separate.
var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`)

// extractBracesContent returns the text captured inside every {{...}}
// placeholder in s, in order of appearance. It returns nil when s contains
// no placeholders.
func extractBracesContent(s string) []string {
	var result []string
	for _, m := range extractBracesRegexp.FindAllStringSubmatch(s, -1) {
		if len(m) >= 2 {
			result = append(result, m[1])
		}
	}
	return result
}
// ImplicitNodeDependency records a reference to another node's output that
// appears only inside a template string (a "block_output_<nodeID>.<path>"
// variable) rather than as an explicit input mapping.
type ImplicitNodeDependency struct {
	NodeID    string            // ID of the node whose output is referenced
	FieldPath compose.FieldPath // path to the referenced field within that node's output
	TypeInfo  *vo.TypeInfo      // resolved type of the referenced field; presumably filled after type lookup — confirm against getTypeInfoByPath caller
}
// extractImplicitDependency scans an HTTP-requester node's URL and body
// templates for implicit "block_output_<nodeID>.<field path>" placeholders and
// resolves each referenced field's type info from the canvas.
//
// A placeholder whose source node cannot be found on the canvas is kept with a
// nil TypeInfo; a placeholder whose source node exists but lacks the
// referenced field is an error.
func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) {
	dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes))
	url := node.Data.Inputs.APIInfo.URL
	urlVars := extractBracesContent(url)
	hasReferred := make(map[string]bool)
	// extractDependenciesFromVars appends one dependency per distinct
	// block_output_ variable; other variables are ignored here.
	extractDependenciesFromVars := func(vars []string) error {
		for _, v := range vars {
			if strings.HasPrefix(v, "block_output_") {
				paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
				if len(paths) < 2 {
					return fmt.Errorf("invalid block_output_ variable: %s", v)
				}
				if hasReferred[v] {
					continue
				}
				hasReferred[v] = true
				dependencies = append(dependencies, &ImplicitNodeDependency{
					NodeID:    paths[0],
					FieldPath: paths[1:],
				})
			}
		}
		return nil
	}
	if err := extractDependenciesFromVars(urlVars); err != nil {
		return nil, err
	}
	if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) {
		jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
		if err := extractDependenciesFromVars(jsonVars); err != nil {
			return nil, err
		}
	}
	if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) {
		// Fix: raw-text placeholders live in BodyData.RawText; this previously
		// read BodyData.Json (see the matching RawText handling in
		// setHttpRequesterInputsForNodeSchema).
		rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.RawText)
		if err := extractDependenciesFromVars(rawTextVars); err != nil {
			return nil, err
		}
	}
	// nodeFinder looks a node up by ID, descending into composite nodes'
	// nested blocks.
	var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node
	nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node {
		for i := range nodes {
			if nodes[i].ID == nodeID {
				return nodes[i]
			}
			if len(nodes[i].Blocks) > 0 {
				if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil {
					return n
				}
			}
		}
		return nil
	}
	for _, ds := range dependencies {
		fNode := nodeFinder(canvas.Nodes, ds.NodeID)
		if fNode == nil {
			// Referenced node is not on the canvas: leave TypeInfo nil and let
			// downstream wiring decide how to treat the dangling reference.
			continue
		}
		// Fix: size the map by the found node's outputs (was node's outputs).
		tInfoMap := make(map[string]*vo.TypeInfo, len(fNode.Data.Outputs))
		for _, vAny := range fNode.Data.Outputs {
			v, err := vo.ParseVariable(vAny)
			if err != nil {
				return nil, err
			}
			tInfo, err := convert.CanvasVariableToTypeInfo(v)
			if err != nil {
				return nil, err
			}
			tInfoMap[v.Name] = tInfo
		}
		tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap)
		if !ok {
			return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath)
		}
		ds.TypeInfo = tInfo
	}
	return dependencies, nil
}
// getTypeInfoByPath walks the nested Properties maps starting at root and
// returns the type info reached by following properties in order. The boolean
// result reports whether every segment of the path resolved.
func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) {
	info, found := tInfoMap[root]
	if !found {
		return nil, false
	}
	if len(properties) == 0 {
		return info, true
	}
	// Descend one level: the next segment is looked up in this field's
	// Properties map.
	return getTypeInfoByPath(properties[0], properties[1:], info.Properties)
}
// globalVariableRegex captures the quoted key inside a canvas global-variable
// reference such as `global_variable_app["some_key"]`.
var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`)
// setHttpRequesterInputsForNodeSchema wires all inputs of an HTTP-requester
// canvas node onto its NodeSchema: implicit {{...}} template variables in the
// URL and body, header/query params, auth credentials, and body-type-specific
// fields. Input field names are a section prefix plus the MD5 of the variable
// name, matching the keys the runtime HTTPRequester parses back out of its
// input map.
func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) {
inputs := n.Data.Inputs
implicitPathVars := make(map[string]bool)
// addImplicitVarsSources registers an input (type + source reference) for
// each template variable: block_output_ vars are matched against the
// pre-resolved implicitNodeDependencies, global_variable_ vars become
// variable references. Duplicate paths are registered once.
addImplicitVarsSources := func(prefix string, vars []string) error {
for _, v := range vars {
if strings.HasPrefix(v, "block_output_") {
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
if len(paths) < 2 {
return fmt.Errorf("invalid implicit var : %s", v)
}
for _, dep := range implicitNodeDependencies {
if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") {
pathValue := prefix + crypto.MD5HexValue(v)
if _, visited := implicitPathVars[pathValue]; visited {
continue
}
implicitPathVars[pathValue] = true
ns.SetInputType(pathValue, dep.TypeInfo)
ns.AddInputSource(&vo.FieldInfo{
Path: []string{pathValue},
Source: vo.FieldSource{
Ref: &vo.Reference{
FromNodeKey: vo.NodeKey(dep.NodeID),
FromPath: dep.FieldPath,
},
},
})
}
}
}
if strings.HasPrefix(v, "global_variable_") {
matches := globalVariableRegex.FindStringSubmatch(v)
if len(matches) < 2 {
// Not in the expected global_variable_x["key"] shape; skip silently.
continue
}
// Classify the variable scope (app / user / system) from its prefix.
var varType vo.GlobalVarType
if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) {
varType = vo.GlobalAPP
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) {
varType = vo.GlobalUser
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) {
varType = vo.GlobalSystem
} else {
return fmt.Errorf("invalid global variable type: %s", v)
}
source := vo.FieldSource{
Ref: &vo.Reference{
VariableType: &varType,
FromPath: []string{matches[1]},
},
}
ns.AddInputSource(&vo.FieldInfo{
Path: []string{prefix + crypto.MD5HexValue(v)},
Source: source,
})
}
}
return nil
}
// URL template variables.
urlVars := extractBracesContent(inputs.APIInfo.URL)
err = addImplicitVarsSources("__apiInfo_url_", urlVars)
if err != nil {
return err
}
// Header and query-string parameters.
err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent())
if err != nil {
return err
}
err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent())
if err != nil {
return err
}
// Authentication inputs, only when auth is enabled on the node.
if inputs.Auth != nil && inputs.Auth.AuthOpen {
authData := inputs.Auth.AuthData
const bearerTokenKey = "__auth_authData_bearerTokenData_token"
if inputs.Auth.AuthType == "BEARER_AUTH" {
// NOTE(review): assumes the canvas always supplies at least one
// BearerTokenData entry — confirm; otherwise this panics.
bearTokenParam := authData.BearerTokenData[0]
tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input)
if err != nil {
return err
}
ns.SetInputType(bearerTokenKey, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
if inputs.Auth.AuthType == "CUSTOM_AUTH" {
const (
customDataDataKey = "__auth_authData_customData_data_Key"
customDataDataValue = "__auth_authData_customData_data_Value"
)
// NOTE(review): assumes CustomData.Data always has exactly a key param
// at [0] and a value param at [1] — confirm against the canvas format.
dataParams := authData.CustomData.Data
keyParam := dataParams[0]
keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataKey, keyTypeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
valueParam := dataParams[1]
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input)
if err != nil {
return err
}
ns.SetInputType(customDataDataValue, valueTypeInfo)
sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
}
}
// Body inputs depend on the configured body type: form-style bodies reuse
// applyParamsToSchema, binary takes a single file URL, and the template
// bodies (JSON / raw text) register their {{...}} variables.
switch BodyType(inputs.Body.BodyType) {
case BodyTypeFormData:
err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent())
if err != nil {
return err
}
case BodyTypeFormURLEncoded:
err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent())
if err != nil {
return err
}
case BodyTypeBinary:
const fileURLName = "__body_bodyData_binary_fileURL"
fileURLInput := inputs.Body.BodyData.Binary.FileURL
ns.SetInputType(fileURLName, &vo.TypeInfo{
Type: vo.DataTypeString,
})
sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent())
if err != nil {
return err
}
ns.AddInputSource(sources...)
case BodyTypeJSON:
jsonVars := extractBracesContent(inputs.Body.BodyData.Json)
err = addImplicitVarsSources("__body_bodyData_json_", jsonVars)
if err != nil {
return err
}
case BodyTypeRawText:
rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText)
err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars)
if err != nil {
return err
}
}
return nil
}
// applyParamsToSchema registers each param on the node schema, typing and
// wiring its input under the key prefix+MD5(param name) — the same MD5-based
// key scheme the HTTP requester uses to read its inputs back at runtime.
func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error {
	for _, p := range params {
		typeInfo, err := convert.CanvasBlockInputToTypeInfo(p.Input)
		if err != nil {
			return err
		}
		key := prefix + crypto.MD5HexValue(p.Name)
		ns.SetInputType(key, typeInfo)
		fieldInfos, err := convert.CanvasBlockInputToFieldInfo(p.Input, compose.FieldPath{key}, parentNode)
		if err != nil {
			return err
		}
		ns.AddInputSource(fieldInfos...)
	}
	return nil
}

View File

@ -31,9 +31,14 @@ import (
"strings"
"time"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -129,7 +134,7 @@ type Request struct {
FileURL *string
}
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`)
type MD5FieldMapping struct {
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
@ -184,49 +189,188 @@ type Config struct {
Timeout time.Duration
RetryTimes uint64
IgnoreException bool
DefaultOutput map[string]any
MD5FieldMapping
}
type HTTPRequester struct {
client *http.Client
config *Config
}
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
if cfg == nil {
return nil, fmt.Errorf("config is requried")
func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
options := nodes.GetAdaptOptions(opts...)
if options.Canvas == nil {
return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node")
}
if len(cfg.Method) == 0 {
implicitDeps, err := extractImplicitDependency(n, options.Canvas)
if err != nil {
return nil, err
}
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeHTTPRequester,
Name: n.Data.Meta.Title,
Configs: c,
}
inputs := n.Data.Inputs
md5FieldMapping := &MD5FieldMapping{}
method := inputs.APIInfo.Method
c.Method = method
reqURL := inputs.APIInfo.URL
c.URLConfig = URLConfig{
Tpl: strings.TrimSpace(reqURL),
}
urlVars := extractBracesContent(reqURL)
md5FieldMapping.SetURLFields(urlVars...)
md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string {
return a.Name
})...)
md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string {
return a.Name
})...)
if inputs.Auth != nil && inputs.Auth.AuthOpen {
auth := &AuthenticationConfig{}
ty, err := convertAuthType(inputs.Auth.AuthType)
if err != nil {
return nil, err
}
auth.Type = ty
location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo)
if err != nil {
return nil, err
}
auth.Location = location
c.AuthConfig = auth
}
bodyConfig := BodyConfig{}
bodyConfig.BodyType = BodyType(inputs.Body.BodyType)
switch BodyType(inputs.Body.BodyType) {
case BodyTypeJSON:
jsonTpl := inputs.Body.BodyData.Json
bodyConfig.TextJsonConfig = &TextJsonConfig{
Tpl: jsonTpl,
}
jsonVars := extractBracesContent(jsonTpl)
md5FieldMapping.SetBodyFields(jsonVars...)
case BodyTypeFormData:
bodyConfig.FormDataConfig = &FormDataConfig{
FileTypeMapping: map[string]bool{},
}
formDataVars := make([]string, 0)
for i := range inputs.Body.BodyData.FormData.Data {
p := inputs.Body.BodyData.FormData.Data[i]
formDataVars = append(formDataVars, p.Name)
if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime {
bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true
}
}
md5FieldMapping.SetBodyFields(formDataVars...)
case BodyTypeRawText:
TextTpl := inputs.Body.BodyData.RawText
bodyConfig.TextPlainConfig = &TextPlainConfig{
Tpl: TextTpl,
}
textPlainVars := extractBracesContent(TextTpl)
md5FieldMapping.SetBodyFields(textPlainVars...)
case BodyTypeFormURLEncoded:
formURLEncodedVars := make([]string, 0)
for _, p := range inputs.Body.BodyData.FormURLEncoded {
formURLEncodedVars = append(formURLEncodedVars, p.Name)
}
md5FieldMapping.SetBodyFields(formURLEncodedVars...)
}
c.BodyConfig = bodyConfig
c.MD5FieldMapping = *md5FieldMapping
if inputs.Setting != nil {
c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second
c.RetryTimes = uint64(inputs.Setting.RetryTimes)
}
if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
// convertAuthType maps the canvas auth-type string to the internal AuthType.
// Only "CUSTOM_AUTH" and "BEARER_AUTH" are recognized.
func convertAuthType(auth string) (AuthType, error) {
	if auth == "CUSTOM_AUTH" {
		return Custom, nil
	}
	if auth == "BEARER_AUTH" {
		return BearToken, nil
	}
	return AuthType(0), fmt.Errorf("invalid auth type")
}
// convertLocation maps the canvas "addTo" string to the internal Location:
// "header" places the credential in request headers, "query" in the query
// string.
func convertLocation(l string) (Location, error) {
	if l == "header" {
		return Header, nil
	}
	if l == "query" {
		return QueryParam, nil
	}
	return 0, fmt.Errorf("invalid location")
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if len(c.Method) == 0 {
return nil, fmt.Errorf("method is requried")
}
hg := &HTTPRequester{}
hg := &HTTPRequester{
urlConfig: c.URLConfig,
method: c.Method,
retryTimes: c.RetryTimes,
authConfig: c.AuthConfig,
bodyConfig: c.BodyConfig,
md5FieldMapping: c.MD5FieldMapping,
}
client := http.DefaultClient
if cfg.Timeout > 0 {
client.Timeout = cfg.Timeout
if c.Timeout > 0 {
client.Timeout = c.Timeout
}
hg.client = client
hg.config = cfg
return hg, nil
}
type HTTPRequester struct {
client *http.Client
urlConfig URLConfig
authConfig *AuthenticationConfig
bodyConfig BodyConfig
method string
retryTimes uint64
md5FieldMapping MD5FieldMapping
}
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
var (
req = &Request{}
method = hg.config.Method
retryTimes = hg.config.RetryTimes
method = hg.method
retryTimes = hg.retryTimes
body io.ReadCloser
contentType string
response *http.Response
)
req, err = hg.config.parserToRequest(input)
req, err = hg.parserToRequest(input)
if err != nil {
return nil, err
}
@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
Header: http.Header{},
}
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars)
if err != nil {
return nil, err
}
@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
params.Set(key, value)
}
if hg.config.AuthConfig != nil {
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if hg.authConfig != nil {
httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
if err != nil {
return nil, err
}
@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
u.RawQuery = params.Encode()
httpRequest.URL = u
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req)
if err != nil {
return nil, err
}
@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) {
}
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
var (
request = &Request{}
config = hg.config
)
request, err := hg.config.parserToRequest(input)
var request = &Request{}
request, err := hg.parserToRequest(input)
if err != nil {
return nil, err
}
result := make(map[string]any)
result["method"] = config.Method
result["method"] = hg.method
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars)
if err != nil {
return nil, err
}
@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["header"] = headers
result["auth"] = nil
if config.AuthConfig != nil {
if config.AuthConfig.Type == Custom {
if hg.authConfig != nil {
if hg.authConfig.Type == Custom {
result["auth"] = map[string]interface{}{
"Key": request.Authentication.Key,
"Value": request.Authentication.Value,
}
} else if config.AuthConfig.Type == BearToken {
} else if hg.authConfig.Type == BearToken {
result["auth"] = map[string]interface{}{
"token": request.Authentication.Token,
}
@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["body"] = nil
switch config.BodyConfig.BodyType {
switch hg.bodyConfig.BodyType {
case BodyTypeJSON:
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars)
if err != nil {
return nil, err
}
@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
}
result["body"] = ret
case BodyTypeRawText:
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
if err != nil {
return nil, err
@ -569,7 +711,7 @@ const (
bodyBinaryFileURLPrefix = "binary_fileURL"
)
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) {
request := &Request{
URLVars: make(map[string]any),
Headers: make(map[string]string),
@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
for key, value := range input {
if strings.HasPrefix(key, apiInfoURLPrefix) {
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok {
if strings.HasPrefix(urlKey, "global_variable_") {
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
}
@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
}
if strings.HasPrefix(key, headersPrefix) {
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok {
request.Headers[headerKey] = value.(string)
}
}
if strings.HasPrefix(key, paramsPrefix) {
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok {
request.Params[paramKey] = value.(string)
}
}
@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok {
if strings.HasPrefix(jsonKey, "global_variable_") {
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
}
@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
}
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok {
request.FormDataVars[formDataKey] = value.(string)
}
@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
}
}
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok {
if strings.HasPrefix(rawTextKey, "global_variable_") {
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
}

View File

@ -28,6 +28,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
)
@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) {
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) {
"__params_" + crypto.MD5HexValue("p2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) {
},
},
}
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])
@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) {
}
// Create an HTTPRequest instance
hg, err := NewHTTPRequester(context.Background(), cfg)
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
assert.NoError(t, err)
m := map[string]any{
@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) {
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
}
result, err := hg.Invoke(context.Background(), m)
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
assert.NoError(t, err)
assert.Equal(t, `{"message":"success"}`, result["body"])
assert.Equal(t, int64(200), result["statusCode"])

View File

@ -18,26 +18,167 @@ package intentdetector
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
type Config struct {
Intents []string
SystemPrompt string
IsFastMode bool
ChatModel model.BaseChatModel
LLMParams *model.LLMParams
}
// Adapt converts the canvas intent-detector node definition into a NodeSchema,
// translating the canvas LLM parameters into model.LLMParams and collecting
// the configured intent names onto the Config.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeIntentDetector,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	param := n.Data.Inputs.LLMParam
	if param == nil {
		return nil, fmt.Errorf("intent detector node's llmParam is nil")
	}
	// Validate the dynamic type up front. Fix: report the actual supplied
	// type — the original message printed the zero value left behind by the
	// failed assertion (always empty) and misnamed the node as "llm node".
	if _, ok := param.(vo.IntentDetectorLLMParam); !ok {
		return nil, fmt.Errorf("intent detector node's llmParam must be IntentDetectorLLMParam, got %T", param)
	}

	// Round-trip through JSON to project the loosely-typed canvas param onto
	// the strongly-typed IntentDetectorLLMConfig.
	paramBytes, err := sonic.Marshal(param)
	if err != nil {
		return nil, err
	}
	intentDetectorConfig := &vo.IntentDetectorLLMConfig{}
	if err = sonic.Unmarshal(paramBytes, intentDetectorConfig); err != nil {
		return nil, err
	}

	// Guard the assertion instead of panicking on malformed canvas data.
	sysPrompt, ok := intentDetectorConfig.SystemPrompt.Value.Content.(string)
	if !ok {
		return nil, fmt.Errorf("intent detector node's system prompt must be a string, got %T",
			intentDetectorConfig.SystemPrompt.Value.Content)
	}

	c.LLMParams = &model.LLMParams{
		ModelType:      int64(intentDetectorConfig.ModelType),
		ModelName:      intentDetectorConfig.ModelName,
		TopP:           intentDetectorConfig.TopP,
		Temperature:    intentDetectorConfig.Temperature,
		MaxTokens:      intentDetectorConfig.MaxTokens,
		ResponseFormat: model.ResponseFormat(intentDetectorConfig.ResponseFormat),
		SystemPrompt:   sysPrompt,
	}
	c.SystemPrompt = sysPrompt

	intents := make([]string, 0, len(n.Data.Inputs.Intents))
	for _, it := range n.Data.Inputs.Intents {
		intents = append(intents, it.Name)
	}
	c.Intents = intents

	// "top_speed" selects fast mode, where the model replies with a bare
	// classification id (see parseToNodeOut) instead of a JSON object.
	if n.Data.Inputs.Mode == "top_speed" {
		c.IsFastMode = true
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
// Build compiles the intent-detector runnable: a two-message chat chain
// (system prompt listing the intents, then the user query) over the chat model
// resolved from c.LLMParams.
func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
// NOTE(review): only non-fast mode requires LLMParams here, yet GetModel is
// called with c.LLMParams unconditionally below — confirm fast mode always
// has LLMParams populated by Adapt.
if !c.IsFastMode && c.LLMParams == nil {
return nil, errors.New("config chat model is required")
}
if len(c.Intents) == 0 {
return nil, errors.New("config intents is required")
}
m, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
chain := compose.NewChain[map[string]any, *schema.Message]()
// Fast mode uses the prompt that asks for a bare classification id; full
// mode asks for a JSON object (see parseToNodeOut for the two parsers).
spt := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
intents, err := toIntentString(c.Intents)
if err != nil {
return nil, err
}
// Render the intent list into the system prompt template.
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": intents,
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(m).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
isFastMode: c.IsFastMode,
systemPrompt: c.SystemPrompt,
runner: r,
}, nil
}
// BuildBranch returns the branch-selection function for the intent detector
// and reports (via the second result) that the node supports branching. The
// selector maps the node's classification id to a 0-based branch port, with
// id 0 meaning "no intent matched" → take the default branch.
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	selector := func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		raw, ok := nodeOutput[classificationID]
		if !ok {
			return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput)
		}
		id, ok := raw.(int64)
		if !ok {
			return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", raw)
		}
		if id == 0 {
			// Zero means none of the configured intents matched.
			return -1, true, nil
		}
		// Intent ids are 1-based in the model output; ports are 0-based.
		return id - 1, false, nil
	}
	return selector, true
}
// ExpectPorts lists the output ports the intent detector exposes: the default
// port followed by one numbered branch port per configured intent.
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string {
	ports := make([]string, 0, len(n.Data.Inputs.Intents)+1)
	ports = append(ports, schema2.PortDefault)
	for i := range n.Data.Inputs.Intents {
		ports = append(ports, fmt.Sprintf(schema2.PortBranchFormat, i))
	}
	return ports
}
const SystemIntentPrompt = `
@ -95,71 +236,39 @@ Note:
##Limit
- Please do not reply in text.`
const classificationID = "classificationId"
type IntentDetector struct {
config *Config
runner compose.Runnable[map[string]any, *schema.Message]
}
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
}
if !cfg.IsFastMode && cfg.ChatModel == nil {
return nil, errors.New("config chat model is required")
}
if len(cfg.Intents) == 0 {
return nil, errors.New("config intents is required")
}
chain := compose.NewChain[map[string]any, *schema.Message]()
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
"intents": toIntentString(cfg.Intents),
})
if err != nil {
return nil, err
}
prompts := prompt.FromMessages(schema.Jinja2,
&schema.Message{Content: sptTemplate, Role: schema.System},
&schema.Message{Content: "{{query}}", Role: schema.User})
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
if err != nil {
return nil, err
}
return &IntentDetector{
config: cfg,
runner: r,
}, nil
isFastMode bool
systemPrompt string
runner compose.Runnable[map[string]any, *schema.Message]
}
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
nodeOutput := make(map[string]any)
nodeOutput["classificationId"] = 0
if content == "" {
return nodeOutput, errors.New("content is empty")
return nil, errors.New("intent detector's LLM output content is empty")
}
if id.config.IsFastMode {
if id.isFastMode {
cid, err := strconv.ParseInt(content, 10, 64)
if err != nil {
return nodeOutput, err
return nil, err
}
nodeOutput["classificationId"] = cid
return nodeOutput, nil
return map[string]any{
classificationID: cid,
}, nil
}
leftIndex := strings.Index(content, "{")
rightIndex := strings.Index(content, "}")
if leftIndex == -1 || rightIndex == -1 {
return nodeOutput, errors.New("content is invalid")
return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content)
}
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
var nodeOutput map[string]any
err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput)
if err != nil {
return nodeOutput, err
return nil, err
}
return nodeOutput, nil
@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
vars := make(map[string]any)
vars["query"] = queryStr
if !id.config.IsFastMode {
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
if !id.isFastMode {
ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query})
if err != nil {
return nil, err
}
@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
return id.parseToNodeOut(o.Content)
}
func toIntentString(its []string) string {
func toIntentString(its []string) (string, error) {
type IntentVariableItem struct {
ClassificationID int64 `json:"classificationId"`
Content string `json:"content"`
@ -207,6 +316,6 @@ func toIntentString(its []string) string {
Content: it,
})
}
itsBytes, _ := json.Marshal(vs)
return string(itsBytes)
return sonic.MarshalString(vs)
}

View File

@ -1,88 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package intentdetector
import (
"context"
"fmt"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
)
// mockChatModel is a stub chat model for intent-detector tests.
// topSeed switches the canned reply between fast-mode output (a bare
// classification id string) and full-mode output (a JSON object).
type mockChatModel struct {
	topSeed bool
}
// Generate returns a canned response: a bare "1" when topSeed is set
// (mimicking fast-mode classification output), otherwise a JSON payload
// carrying the classification id and reason.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	if !m.topSeed {
		return &schema.Message{
			Content: `{"classificationId":1,"reason":"高兴"}`,
		}, nil
	}
	return &schema.Message{Content: "1"}, nil
}
// Stream is unused by these tests; it exists only to satisfy the chat
// model interface and always returns a nil stream.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	return nil, nil
}
// BindTools is a no-op; tool binding is irrelevant for these tests.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	return nil
}
// TestNewIntentDetector exercises the intent detector in both fast mode
// (the model replies with a bare classification id) and full mode (the
// model replies with a JSON object containing id and reason).
func TestNewIntentDetector(t *testing.T) {
	ctx := context.Background()

	t.Run("fast mode", func(t *testing.T) {
		detector, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: true,
			ChatModel:  &mockChatModel{topSeed: true},
		})
		assert.Nil(t, err)

		out, err := detector.Invoke(ctx, map[string]any{"query": "我考了100分"})
		assert.Nil(t, err)
		assert.Equal(t, out["classificationId"], int64(1))
	})

	t.Run("full mode", func(t *testing.T) {
		detector, err := NewIntentDetector(ctx, &Config{
			Intents:    []string{"高兴", "悲伤"},
			IsFastMode: false,
			ChatModel:  &mockChatModel{},
		})
		assert.Nil(t, err)

		out, err := detector.Invoke(ctx, map[string]any{"query": "我考了100分"})
		fmt.Println(err)
		assert.Nil(t, err)
		fmt.Println(out)
		// Full mode parses the model's JSON reply, so numbers decode as float64.
		assert.Equal(t, out["classificationId"], float64(1))
		assert.Equal(t, out["reason"], "高兴")
	})
}

View File

@ -20,8 +20,11 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@ -34,32 +37,42 @@ const (
warningsKey = "deserialization_warnings"
)
type DeserializationConfig struct {
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
type DeserializationConfig struct{}
// Adapt converts the frontend Node definition into a NodeSchema for the
// JSON deserialization node, wiring up the input mappings and output types.
func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (
	*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeJsonDeserialization,
		Name:    n.Data.Meta.Title,
		Configs: d,
	}

	// Populate input field types/mappings, then output field types.
	err := convert.SetInputsForNodeSchema(n, ns)
	if err == nil {
		err = convert.SetOutputTypesForNodeSchema(n, ns)
	}
	if err != nil {
		return nil, err
	}
	return ns, nil
}
type Deserializer struct {
config *DeserializationConfig
typeInfo *vo.TypeInfo
}
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.OutputFields == nil {
return nil, fmt.Errorf("OutputFields is required for deserialization")
}
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
if typeInfo == nil {
func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization]
if !ok {
return nil, fmt.Errorf("no output field specified in deserialization config")
}
return &Deserializer{
config: cfg,
typeInfo: typeInfo,
}, nil
}
type Deserializer struct {
typeInfo *vo.TypeInfo
}
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
jsonStrValue := input[InputKeyDeserialization]

View File

@ -24,6 +24,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -31,19 +32,9 @@ import (
func TestNewJsonDeserializer(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonDeserializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing OutputFields config
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "OutputFields is required")
// Test with missing output key in OutputFields
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
_, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: vo.DataTypeString},
},
})
@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) {
assert.Contains(t, err.Error(), "no output field specified in deserialization config")
// Test with valid config
validConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
validConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString},
},
}
processor, err := NewJsonDeserializer(ctx, validConfig)
processor, err := (&DeserializationConfig{}).Build(ctx, validConfig)
assert.NoError(t, err)
assert.NotNil(t, processor)
}
@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
ctx := context.Background()
// Base type test config
baseTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeString},
baseTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeString},
},
}
// Object type test config
objectTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
objectTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"name": {Type: vo.DataTypeString, Required: true},
@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
}
// Array type test config
arrayTypeConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
arrayTypeConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
}
// Nested array object test config
nestedArrayConfig := &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
nestedArrayConfig := &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{
Type: vo.DataTypeObject,
@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
// Test cases
tests := []struct {
name string
config *DeserializationConfig
config *schema.NodeSchema
inputJSON string
expectedOutput any
expectErr bool
@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test integer deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123`,
@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test boolean deserialization",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
},
},
inputJSON: `true`,
@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test type mismatch warning",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"not a number"`,
@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test string to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123"`,
@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test float to integer conversion (integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.0`,
@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test float to integer conversion (non-integer part)",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `123.5`,
@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test boolean to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `true`,
@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test string to boolean conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeBoolean},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
},
},
inputJSON: `"true"`,
@ -252,10 +243,11 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion in nested object",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
name: "Test string to integer conversion in nested object",
inputJSON: `{"age":"456"}`,
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"age": {Type: vo.DataTypeInteger},
@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
},
},
},
inputJSON: `{"age":"456"}`,
expectedOutput: map[string]any{"age": 456},
expectErr: false,
expectWarnings: 0,
}, {
name: "Test string to integer conversion for array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 0,
}, {
name: "Test string with non-numeric characters to integer conversion",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {Type: vo.DataTypeInteger},
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
},
},
inputJSON: `"123abc"`,
@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test type mismatch in nested object field",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeObject,
Properties: map[string]*vo.TypeInfo{
"score": {Type: vo.DataTypeInteger},
@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
expectWarnings: 1,
}, {
name: "Test partial conversion failure in array elements",
config: &DeserializationConfig{
OutputFields: map[string]*vo.TypeInfo{
"output": {
config: &schema.NodeSchema{
OutputTypes: map[string]*vo.TypeInfo{
OutputKeyDeserialization: {
Type: vo.DataTypeArray,
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
},
@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
processor, err := NewJsonDeserializer(ctx, tt.config)
processor, err := (&DeserializationConfig{}).Build(ctx, tt.config)
assert.NoError(t, err)
ctxWithCache := ctxcache.Init(ctx)
input := map[string]any{"input": tt.inputJSON}
result, err := processor.Invoke(ctxWithCache, input)
result, err := processor.(*Deserializer).Invoke(ctxWithCache, input)
if tt.expectErr {
assert.Error(t, err)

View File

@ -20,7 +20,11 @@ import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -29,28 +33,57 @@ const (
OutputKeySerialization = "output"
)
// SerializationConfig is the Config type for NodeTypeJsonSerialization.
// Each Node Type should have its own designated Config type,
// which should implement NodeAdaptor and NodeBuilder.
// NOTE: we didn't define any fields for this type,
// because this node is simple, we doesn't need to extract any SPECIFIC piece of info
// from frontend Node. In other cases we would need to do it, such as LLM's model configs.
type SerializationConfig struct {
InputTypes map[string]*vo.TypeInfo
// you can define ANY number of fields here,
// as long as these fields are SERIALIZABLE and EXPORTED.
// to store specific info extracted from frontend node.
// e.g.
// - LLM model configs
// - conditional expressions
// - fixed input fields such as MaxBatchSize
}
type JsonSerializer struct {
config *SerializationConfig
}
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
if cfg == nil {
return nil, fmt.Errorf("config required")
}
if cfg.InputTypes == nil {
return nil, fmt.Errorf("InputTypes is required for serialization")
// Adapt provides conversion from Node to NodeSchema.
// NOTE: in this specific case, we don't need AdaptOption.
func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeJsonSerialization,
Name: n.Data.Meta.Title,
Configs: s, // remember to set the Node's Config Type to NodeSchema as well
}
return &JsonSerializer{
config: cfg,
}, nil
// this sets input fields' type and mapping info
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
// this set output fields' type info
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Build returns the runtime node implementation for JSON serialization.
// The serializer is stateless, so nothing is read from the NodeSchema.
func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (
	any, error) {
	serializer := &Serializer{}
	return serializer, nil
}
// Serializer is the actual node implementation.
// It is stateless: serialization needs nothing beyond the invocation input.
type Serializer struct {
	// here can hold ANY data required for node execution
}
// Invoke implements the InvokableNode interface.
func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
// Directly use the input map for serialization
if input == nil {
return nil, fmt.Errorf("input data for serialization cannot be nil")

View File

@ -23,44 +23,34 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNewJsonSerialize(t *testing.T) {
ctx := context.Background()
// Test with nil config
_, err := NewJsonSerializer(ctx, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "config required")
// Test with missing InputTypes config
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "InputTypes is required")
// Test with valid config
validConfig := &SerializationConfig{
s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
InputTypes: map[string]*vo.TypeInfo{
"testKey": {Type: "string"},
},
}
processor, err := NewJsonSerializer(ctx, validConfig)
})
assert.NoError(t, err)
assert.NotNil(t, processor)
assert.NotNil(t, s)
}
func TestJsonSerialize_Invoke(t *testing.T) {
ctx := context.Background()
config := &SerializationConfig{
processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
InputTypes: map[string]*vo.TypeInfo{
"stringKey": {Type: "string"},
"intKey": {Type: "integer"},
"boolKey": {Type: "boolean"},
"objKey": {Type: "object"},
},
}
processor, err := NewJsonSerializer(ctx, config)
})
assert.NoError(t, err)
// Test cases
@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := processor.Invoke(ctx, tt.input)
result, err := processor.(*Serializer).Invoke(ctx, tt.input)
if tt.expectErr {
assert.Error(t, err)

View File

@ -0,0 +1,57 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package knowledge
import (
"fmt"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
)
// convertParsingType maps the frontend parsing-type string ("fast" or
// "accurate") to the corresponding knowledge.ParseMode, rejecting any
// unknown value with an error.
func convertParsingType(p string) (knowledge.ParseMode, error) {
	if p == "fast" {
		return knowledge.FastParseMode, nil
	}
	if p == "accurate" {
		return knowledge.AccurateParseMode, nil
	}
	return "", fmt.Errorf("invalid parsingType: %s", p)
}
// convertChunkType maps the frontend chunk-type string ("custom" or
// "default") to the corresponding knowledge.ChunkType, rejecting any
// unknown value with an error.
func convertChunkType(p string) (knowledge.ChunkType, error) {
	if p == "custom" {
		return knowledge.ChunkTypeCustom, nil
	}
	if p == "default" {
		return knowledge.ChunkTypeDefault, nil
	}
	return "", fmt.Errorf("invalid ChunkType: %s", p)
}
// convertRetrievalSearchType maps the frontend strategy code to a
// knowledge.SearchType (0 = semantic, 1 = hybrid, 20 = full text).
// Unknown codes yield an error.
func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
	searchTypes := map[int64]knowledge.SearchType{
		0:  knowledge.SearchTypeSemantic,
		1:  knowledge.SearchTypeHybrid,
		20: knowledge.SearchTypeFullText,
	}
	st, ok := searchTypes[s]
	if !ok {
		return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
	}
	return st, nil
}

View File

@ -21,27 +21,45 @@ import (
"errors"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type DeleterConfig struct {
KnowledgeID int64
KnowledgeDeleter knowledge.KnowledgeOperator
}
type DeleterConfig struct{}
type KnowledgeDeleter struct {
config *DeleterConfig
}
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
if cfg.KnowledgeDeleter == nil {
return nil, errors.New("knowledge deleter is required")
func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeKnowledgeDeleter,
Name: n.Data.Meta.Title,
Configs: d,
}
return &KnowledgeDeleter{
config: cfg,
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
// Build returns the runtime knowledge-deleter node, resolving the
// knowledge operator from the cross-domain registry at build time.
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	deleter := &Deleter{knowledgeDeleter: knowledge.GetKnowledgeOperator()}
	return deleter, nil
}
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
// Deleter is the runtime implementation of the knowledge-deleter node.
type Deleter struct {
	// knowledgeDeleter performs the actual document deletion in the knowledge domain.
	knowledgeDeleter knowledge.KnowledgeOperator
}
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
documentID, ok := input["documentID"].(string)
if !ok {
return nil, errors.New("documentID is required and must be a string")
@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma
DocumentID: documentID,
}
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
response, err := k.knowledgeDeleter.Delete(ctx, req)
if err != nil {
return nil, err
}

View File

@ -24,7 +24,14 @@ import (
"path/filepath"
"strings"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
)
@ -32,30 +39,88 @@ type IndexerConfig struct {
KnowledgeID int64
ParsingStrategy *knowledge.ParsingStrategy
ChunkingStrategy *knowledge.ChunkingStrategy
KnowledgeIndexer knowledge.KnowledgeOperator
}
type KnowledgeIndexer struct {
config *IndexerConfig
// Adapt converts the frontend Node definition into a NodeSchema for the
// knowledge indexer node. It extracts the target knowledge base ID plus
// the parsing and chunking strategies from the node's inputs, storing
// them on the config so Build can construct the runtime Indexer.
func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeKnowledgeIndexer,
		Name:    n.Data.Meta.Title,
		Configs: i,
	}

	inputs := n.Data.Inputs

	// Guard against malformed frontend payloads: the original unchecked
	// index and type assertion would panic here instead of returning an error.
	if len(inputs.DatasetParam) == 0 {
		return nil, fmt.Errorf("dataset param is required")
	}
	datasetListInfoParam := inputs.DatasetParam[0]
	datasetIDs, ok := datasetListInfoParam.Input.Value.Content.([]any)
	if !ok {
		return nil, fmt.Errorf("dataset ids should be a list, got %T", datasetListInfoParam.Input.Value.Content)
	}
	if len(datasetIDs) == 0 {
		return nil, fmt.Errorf("dataset ids is required")
	}

	// Only the first dataset id is used as the indexing target.
	knowledgeID, err := cast.ToInt64E(datasetIDs[0])
	if err != nil {
		return nil, err
	}
	i.KnowledgeID = knowledgeID

	ps := inputs.StrategyParam.ParsingStrategy
	parseMode, err := convertParsingType(ps.ParsingType)
	if err != nil {
		return nil, err
	}
	i.ParsingStrategy = &knowledge.ParsingStrategy{
		ParseMode:    parseMode,
		ImageOCR:     ps.ImageOcr,
		ExtractImage: ps.ImageExtraction,
		ExtractTable: ps.TableExtraction,
	}

	cs := inputs.StrategyParam.ChunkStrategy
	chunkType, err := convertChunkType(cs.ChunkType)
	if err != nil {
		return nil, err
	}
	i.ChunkingStrategy = &knowledge.ChunkingStrategy{
		ChunkType: chunkType,
		Separator: cs.Separator,
		ChunkSize: cs.MaxToken,
		// Overlap arrives as a fraction of MaxToken; persist it in tokens.
		Overlap: int64(cs.Overlap * float64(cs.MaxToken)),
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
if cfg.ParsingStrategy == nil {
func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if i.ParsingStrategy == nil {
return nil, errors.New("parsing strategy is required")
}
if cfg.ChunkingStrategy == nil {
if i.ChunkingStrategy == nil {
return nil, errors.New("chunking strategy is required")
}
if cfg.KnowledgeIndexer == nil {
return nil, errors.New("knowledge indexer is required")
}
return &KnowledgeIndexer{
config: cfg,
return &Indexer{
knowledgeID: i.KnowledgeID,
parsingStrategy: i.ParsingStrategy,
chunkingStrategy: i.ChunkingStrategy,
knowledgeIndexer: knowledge.GetKnowledgeOperator(),
}, nil
}
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
// Indexer is the runtime implementation of the knowledge-indexer node.
type Indexer struct {
	knowledgeID      int64                       // target knowledge base to store documents into
	parsingStrategy  *knowledge.ParsingStrategy  // how uploaded documents are parsed
	chunkingStrategy *knowledge.ChunkingStrategy // how parsed text is split into chunks
	knowledgeIndexer knowledge.KnowledgeOperator // cross-domain operator performing the store
}
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
fileURL, ok := input["knowledge"].(string)
if !ok {
return nil, errors.New("knowledge is required")
@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map
}
req := &knowledge.CreateDocumentRequest{
KnowledgeID: k.config.KnowledgeID,
ParsingStrategy: k.config.ParsingStrategy,
ChunkingStrategy: k.config.ChunkingStrategy,
KnowledgeID: k.knowledgeID,
ParsingStrategy: k.parsingStrategy,
ChunkingStrategy: k.chunkingStrategy,
FileURL: fileURL,
FileName: fileName,
FileExtension: ext,
}
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
response, err := k.knowledgeIndexer.Store(ctx, req)
if err != nil {
return nil, err
}

View File

@ -20,7 +20,14 @@ import (
"context"
"errors"
"github.com/spf13/cast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
@ -29,37 +36,136 @@ const outputList = "outputList"
type RetrieveConfig struct {
KnowledgeIDs []int64
RetrievalStrategy *knowledge.RetrievalStrategy
Retriever knowledge.KnowledgeOperator
}
type KnowledgeRetrieve struct {
config *RetrieveConfig
// Adapt converts the frontend Node definition into a NodeSchema for the
// knowledge retriever node. It extracts the target knowledge base IDs and
// the retrieval strategy (topK, rerank, rewrite, personal-only, NL2SQL,
// min score, search type) from the node's dataset parameters, storing them
// on the config so Build can construct the runtime Retrieve.
func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeKnowledgeRetriever,
		Name:    n.Data.Meta.Title,
		Configs: r,
	}

	inputs := n.Data.Inputs

	// Guard against malformed frontend payloads: the original unchecked
	// index and type assertion would panic here instead of returning an error.
	if len(inputs.DatasetParam) == 0 {
		return nil, errors.New("dataset param is required")
	}
	datasetListInfoParam := inputs.DatasetParam[0]
	datasetIDs, ok := datasetListInfoParam.Input.Value.Content.([]any)
	if !ok {
		return nil, errors.New("dataset ids should be a list")
	}

	knowledgeIDs := make([]int64, 0, len(datasetIDs))
	for _, id := range datasetIDs {
		k, err := cast.ToInt64E(id)
		if err != nil {
			return nil, err
		}
		knowledgeIDs = append(knowledgeIDs, k)
	}
	r.KnowledgeIDs = knowledgeIDs

	retrievalStrategy := &knowledge.RetrievalStrategy{}

	// getDesignatedParamContent looks up a named dataset parameter's raw content.
	var getDesignatedParamContent = func(name string) (any, bool) {
		for _, param := range inputs.DatasetParam {
			if param.Name == name {
				return param.Input.Value.Content, true
			}
		}
		return nil, false
	}

	if content, ok := getDesignatedParamContent("topK"); ok {
		topK, err := cast.ToInt64E(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.TopK = &topK
	}

	if content, ok := getDesignatedParamContent("useRerank"); ok {
		useRerank, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableRerank = useRerank
	}

	if content, ok := getDesignatedParamContent("useRewrite"); ok {
		useRewrite, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableQueryRewrite = useRewrite
	}

	if content, ok := getDesignatedParamContent("isPersonalOnly"); ok {
		isPersonalOnly, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.IsPersonalOnly = isPersonalOnly
	}

	if content, ok := getDesignatedParamContent("useNl2sql"); ok {
		useNl2sql, err := cast.ToBoolE(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.EnableNL2SQL = useNl2sql
	}

	if content, ok := getDesignatedParamContent("minScore"); ok {
		minScore, err := cast.ToFloat64E(content)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.MinScore = &minScore
	}

	if content, ok := getDesignatedParamContent("strategy"); ok {
		strategy, err := cast.ToInt64E(content)
		if err != nil {
			return nil, err
		}
		searchType, err := convertRetrievalSearchType(strategy)
		if err != nil {
			return nil, err
		}
		retrievalStrategy.SearchType = searchType
	}

	r.RetrievalStrategy = retrievalStrategy

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	return ns, nil
}
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
if cfg == nil {
return nil, errors.New("cfg is required")
func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if len(r.KnowledgeIDs) == 0 {
return nil, errors.New("knowledge ids are required")
}
if cfg.Retriever == nil {
return nil, errors.New("retriever is required")
}
if len(cfg.KnowledgeIDs) == 0 {
return nil, errors.New("knowledgeI ids is required")
}
if cfg.RetrievalStrategy == nil {
if r.RetrievalStrategy == nil {
return nil, errors.New("retrieval strategy is required")
}
return &KnowledgeRetrieve{
config: cfg,
return &Retrieve{
knowledgeIDs: r.KnowledgeIDs,
retrievalStrategy: r.RetrievalStrategy,
retriever: knowledge.GetKnowledgeOperator(),
}, nil
}
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
// Retrieve is the runtime implementation of the knowledge-retriever node.
type Retrieve struct {
	knowledgeIDs      []int64                      // knowledge bases to search
	retrievalStrategy *knowledge.RetrievalStrategy // search tuning extracted by Adapt
	retriever         knowledge.KnowledgeOperator  // cross-domain operator performing retrieval
}
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
query, ok := input["Query"].(string)
if !ok {
return nil, errors.New("capital query key is required")
@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any)
req := &knowledge.RetrieveRequest{
Query: query,
KnowledgeIDs: kr.config.KnowledgeIDs,
RetrievalStrategy: kr.config.RetrievalStrategy,
KnowledgeIDs: kr.knowledgeIDs,
RetrievalStrategy: kr.retrievalStrategy,
}
response, err := kr.config.Retriever.Retrieve(ctx, req)
response, err := kr.retriever.Retrieve(ctx, req)
if err != nil {
return nil, err
}

View File

@ -34,13 +34,20 @@ import (
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@ -143,126 +150,408 @@ const (
)
type RetrievalStrategy struct {
RetrievalStrategy *crossknowledge.RetrievalStrategy
RetrievalStrategy *knowledge.RetrievalStrategy
NoReCallReplyMode NoReCallReplyMode
NoReCallReplyCustomizePrompt string
}
type KnowledgeRecallConfig struct {
ChatModel model.BaseChatModel
Retriever crossknowledge.KnowledgeOperator
Retriever knowledge.KnowledgeOperator
RetrievalStrategy *RetrievalStrategy
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
}
type Config struct {
ChatModel ModelWithInfo
Tools []tool.BaseTool
SystemPrompt string
UserPrompt string
OutputFormat Format
InputFields map[string]*vo.TypeInfo
OutputFields map[string]*vo.TypeInfo
ToolsReturnDirectly map[string]bool
KnowledgeRecallConfig *KnowledgeRecallConfig
FullSources map[string]*nodes.SourceInfo
SystemPrompt string
UserPrompt string
OutputFormat Format
LLMParams *crossmodel.LLMParams
FCParam *vo.FCParam
BackupLLMParams *crossmodel.LLMParams
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
outputFields map[string]*vo.TypeInfo
canStream bool
requireCheckpoint bool
fullSources map[string]*nodes.SourceInfo
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLLM,
Name: n.Data.Meta.Title,
Configs: c,
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
param := n.Data.Inputs.LLMParam
if param == nil {
return nil, fmt.Errorf("llm node's llmParam is nil")
}
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return nil, err
}
convertedLLMParam, err := llmParamsToLLMParam(llmParam)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
c.LLMParams = convertedLLMParam
c.SystemPrompt = convertedLLMParam.SystemPrompt
c.UserPrompt = convertedLLMParam.Prompt
var resFormat Format
switch convertedLLMParam.ResponseFormat {
case crossmodel.ResponseFormatText:
resFormat = FormatText
case crossmodel.ResponseFormatMarkdown:
resFormat = FormatMarkdown
case crossmodel.ResponseFormatJSON:
resFormat = FormatJSON
default:
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
c.OutputFormat = resFormat
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return r, nil
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
if resFormat == FormatJSON {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
}
}
}
if resFormat == FormatJSON {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: false,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: true,
}
}
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
c.FCParam = n.Data.Inputs.FCParam
}
if se := n.Data.Inputs.SettingOnError; se != nil {
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
var backupLLMParam vo.SimpleLLMParam
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
return nil, err
}
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
if err != nil {
return nil, err
}
c.BackupLLMParams = backupModel
}
}
return ns, nil
}
// llmParamsToLLMParam converts the canvas-level list of name/value LLM
// parameters (as sent by the frontend) into the cross-domain LLMParams
// struct.
//
// Numeric parameters arrive as strings and are parsed here. Unrecognized
// parameter names are rejected with an error, while a handful of
// frontend-only knobs ("chatHistoryRound", "generationDiversity",
// "frequencyPenalty", "presencePenalty") are deliberately ignored.
//
// NOTE(review): "modleName" (sic) appears to be the key the canvas payload
// actually uses — presumably a historical typo in the wire format. Do not
// "correct" it here without a coordinated frontend change.
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
	p := &crossmodel.LLMParams{}

	// asString/asBool guard the content type assertions so that a malformed
	// payload yields a descriptive error instead of a runtime panic.
	asString := func(name string, v any) (string, error) {
		s, ok := v.(string)
		if !ok {
			return "", fmt.Errorf("LLMParam %q: expected string content, got %T", name, v)
		}
		return s, nil
	}
	asBool := func(name string, v any) (bool, error) {
		b, ok := v.(bool)
		if !ok {
			return false, fmt.Errorf("LLMParam %q: expected bool content, got %T", name, v)
		}
		return b, nil
	}

	for _, param := range params {
		switch param.Name {
		case "temperature":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			floatVal, err := strconv.ParseFloat(strVal, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing temperature %q: %w", strVal, err)
			}
			p.Temperature = &floatVal
		case "maxTokens":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			intVal, err := strconv.Atoi(strVal)
			if err != nil {
				return nil, fmt.Errorf("parsing maxTokens %q: %w", strVal, err)
			}
			p.MaxTokens = intVal
		case "responseFormat":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			int64Val, err := strconv.ParseInt(strVal, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing responseFormat %q: %w", strVal, err)
			}
			p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
		case "modleName": // sic: key as sent by the canvas payload
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.ModelName = strVal
		case "modelType":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			int64Val, err := strconv.ParseInt(strVal, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing modelType %q: %w", strVal, err)
			}
			p.ModelType = int64Val
		case "prompt":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.Prompt = strVal
		case "enableChatHistory":
			boolVal, err := asBool(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.EnableChatHistory = boolVal
		case "systemPrompt":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.SystemPrompt = strVal
		case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
			// Frontend-only settings; intentionally ignored here.
		case "topP":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			floatVal, err := strconv.ParseFloat(strVal, 64)
			if err != nil {
				return nil, fmt.Errorf("parsing topP %q: %w", strVal, err)
			}
			p.TopP = &floatVal
		default:
			return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
		}
	}
	return p, nil
}
// simpleLLMParamsToLLMParams maps the backup-model settings
// (vo.SimpleLLMParam) onto the cross-domain LLMParams struct.
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
	return &crossmodel.LLMParams{
		ModelName:      params.ModelName,
		ModelType:      params.ModelType,
		Temperature:    &params.Temperature,
		MaxTokens:      params.MaxTokens,
		TopP:           &params.TopP,
		ResponseFormat: params.ResponseFormat,
		SystemPrompt:   params.SystemPrompt,
	}, nil
}
// getReasoningContent returns the reasoning text attached to the message
// (empty when the model produced none).
func getReasoningContent(message *schema.Message) string {
	return message.ReasoningContent
}
type Options struct {
nested []nodes.NestedWorkflowOption
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
var (
err error
chatModel, fallbackM model.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo ModelWithInfo
tools []tool.BaseTool
toolsReturnDirectly map[string]bool
knowledgeRecallConfig *KnowledgeRecallConfig
)
type Option func(o *Options)
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
return func(o *Options) {
o.nested = append(o.nested, nested...)
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
return func(o *Options) {
o.toolWorkflowSW = sw
exceptionConf := ns.ExceptionConfigs
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
backupModelParams := c.BackupLLMParams
if backupModelParams != nil {
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
}
}
type llmState = map[string]any
if fallbackM == nil {
modelWithInfo = NewModel(chatModel, info)
} else {
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
const agentModelName = "agent_model"
fcParams := c.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
tools = append(tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if toolsReturnDirectly == nil {
toolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
toolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
PluginEntity: plugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
tools = append(tools, inInvokableTools...)
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := knowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
knowledgeRecallConfig = &KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := toRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
RetrievalStrategy: &knowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
&knowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
}
}
func New(ctx context.Context, cfg *Config) (*LLM, error) {
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
return llmState{}
}))
var (
hasReasoning bool
canStream = true
)
var hasReasoning bool
format := cfg.OutputFormat
format := c.OutputFormat
if format == FormatJSON {
if len(cfg.OutputFields) == 1 {
for _, v := range cfg.OutputFields {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
format = FormatText
break
}
}
} else if len(cfg.OutputFields) == 2 {
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
for k, v := range cfg.OutputFields {
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
format = FormatText
break
@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
}
}
userPrompt := cfg.UserPrompt
userPrompt := c.UserPrompt
switch format {
case FormatJSON:
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
if err != nil {
return nil, err
}
@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
case FormatText:
}
if cfg.KnowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
if knowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
if err != nil {
return nil, err
}
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
inputs := maps.Clone(cfg.InputFields)
inputs := maps.Clone(ns.InputTypes)
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString,
}
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
template := newPrompts(sp, up, cfg.ChatModel)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else {
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
template := newPrompts(sp, up, cfg.ChatModel)
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template)
_ = g.AddEdge(compose.START, templateNodeKey)
}
if len(cfg.Tools) > 0 {
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
if len(tools) > 0 {
m, ok := modelWithInfo.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
reactConfig := react.AgentConfig{
ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName,
}
if len(cfg.ToolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
for k := range cfg.ToolsReturnDirectly {
if len(toolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
for k := range toolsReturnDirectly {
reactConfig.ToolReturnDirectly[k] = struct{}{}
}
}
@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else {
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
}
_ = g.AddEdge(templateNodeKey, llmNodeKey)
if format == FormatJSON {
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
return jsonParse(ctx, msg.Content, cfg.OutputFields)
return jsonParse(ctx, msg.Content, ns.OutputTypes)
}
convertNode := compose.InvokableLambda(iConvert)
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
canStream = false
} else {
var outputKey string
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
panic("impossible")
}
for k, v := range cfg.OutputFields {
for k, v := range ns.OutputTypes {
if v.Type != vo.DataTypeString {
panic("impossible")
}
@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(outputConvertNodeKey, compose.END)
requireCheckpoint := false
if len(cfg.Tools) > 0 {
if len(tools) > 0 {
requireCheckpoint = true
}
var opts []compose.GraphCompileOption
var compileOpts []compose.GraphCompileOption
if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
}
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
r, err := g.Compile(ctx, opts...)
r, err := g.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
llm := &LLM{
r: r,
outputFormat: format,
canStream: canStream,
requireCheckpoint: requireCheckpoint,
fullSources: cfg.FullSources,
fullSources: ns.FullSources,
}
return llm, nil
}
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
// RequireCheckpoint reports whether this LLM node needs a checkpoint store:
// that is the case when workflow or plugin tools are configured.
func (c *Config) RequireCheckpoint() bool {
	fc := c.FCParam
	if fc == nil {
		return false
	}
	return fc.WorkflowFCParam != nil || fc.PluginFCParam != nil
}
// FieldStreamType decides whether the output field addressed by path can be
// produced as a stream. A field is streamable only when the workflow requires
// streaming, the path is top-level, and the node's output consists solely of
// string fields — at most one content key plus (optionally) the reasoning key.
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() || len(path) != 1 {
		return schema2.FieldNotStream, nil
	}

	outTypes := ns.OutputTypes
	if len(outTypes) != 1 && len(outTypes) != 2 {
		return schema2.FieldNotStream, nil
	}

	// Identify the single non-reasoning string output key; bail out if any
	// field is non-string or there is more than one content key.
	contentKey := ""
	for k, t := range outTypes {
		if t.Type != vo.DataTypeString {
			return schema2.FieldNotStream, nil
		}
		if k == ReasoningOutputKey {
			continue
		}
		if contentKey != "" {
			return schema2.FieldNotStream, nil
		}
		contentKey = k
	}

	if p := path[0]; p == ReasoningOutputKey || p == contentKey {
		return schema2.FieldIsStream, nil
	}
	return schema2.FieldNotStream, nil
}
// toRetrievalSearchType maps the canvas search-mode code to a knowledge
// SearchType: 0 = semantic, 1 = hybrid, 20 = full-text.
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
	searchTypes := map[int64]knowledge.SearchType{
		0:  knowledge.SearchTypeSemantic,
		1:  knowledge.SearchTypeHybrid,
		20: knowledge.SearchTypeFullText,
	}
	t, ok := searchTypes[s]
	if !ok {
		return "", fmt.Errorf("invalid retrieval search type %v", s)
	}
	return t, nil
}
// LLM is the runnable workflow LLM node: a compiled inner graph going from
// prompt template through the chat model (or react agent, when tools are
// configured) to output conversion.
type LLM struct {
	r                 compose.Runnable[map[string]any, map[string]any] // compiled inner graph
	outputFormat      Format                                           // text / markdown / json
	requireCheckpoint bool                                             // set when tools are configured (see Build)
	fullSources       map[string]*schema2.SourceInfo                   // upstream source info consumed during prompt rendering
}
// Context-cache key templates (formatted with the node key) used by jsonParse
// to stash the raw model output and a parse warning when structured-output
// parsing fails.
const (
	rawOutputKey = "llm_raw_output_%s"
	warningKey   = "llm_warning_%s"
)
// jsonParse extracts the JSON object embedded in the raw model output and
// converts it to the node's declared output types.
//
// When the text cannot be unmarshalled and we are running inside a node
// execution context, the raw output and a warning are stashed in the context
// cache for the caller to surface, and an empty map is returned instead of
// an error.
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
	data = nodes.ExtractJSONString(data)

	var parsed map[string]any
	if uerr := sonic.UnmarshalString(data, &parsed); uerr != nil {
		exeCtx := execute.GetExeCtx(ctx)
		if exeCtx == nil {
			return nil, uerr
		}
		logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", uerr, data)
		ctxcache.Store(ctx, fmt.Sprintf(rawOutputKey, exeCtx.NodeCtx.NodeKey), data)
		ctxcache.Store(ctx, fmt.Sprintf(warningKey, exeCtx.NodeCtx.NodeKey),
			vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, uerr))
		return map[string]any{}, nil
	}

	converted, warnings, cerr := nodes.ConvertInputs(ctx, parsed, schema_)
	if cerr != nil {
		return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, cerr)
	}
	if warnings != nil {
		logs.CtxWarnf(ctx, "convert inputs warnings: %v", *warnings)
	}
	return converted, nil
}
// llmOptions carries LLM-node-specific execution options.
type llmOptions struct {
	toolWorkflowSW *schema.StreamWriter[*entity.Message] // sink for messages produced by tool workflows
}

// WithToolWorkflowMessageWriter routes messages emitted by tool
// (sub-)workflows into the given stream writer.
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
	return nodes.WrapImplSpecificOptFn(func(opts *llmOptions) {
		opts.toolWorkflowSW = sw
	})
}
// llmState is the per-run local state of the LLM node's inner graph
// (initialized empty; see the WithGenLocalState handler).
type llmState = map[string]any

// agentModelName names the chat-model node inside the react agent graph.
const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx)
if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent
@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
}
llmOpts := &Options{}
for _, opt := range opts {
opt(llmOpts)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
nestedOpts := &nodes.NestedWorkflowOptions{}
for _, opt := range llmOpts.nested {
opt(nestedOpts)
}
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
composeOpts = append(composeOpts, options.GetOptsForNested()...)
if resumingEvent != nil {
var (
@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
}
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt)
@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
return compose.NewInterruptAndRerunErr(ie)
}
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
return out, nil
}
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
modelPredictionIDs := strings.Split(input.Content, ",")
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
return strconv.Itoa(int(e.ID)), e.ID
})
recallKnowledgeIDs := make([]int64, 0)
@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
return make(map[string]any), nil
}
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
Query: userPrompt,
KnowledgeIDs: recallKnowledgeIDs,
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,

View File

@ -26,6 +26,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
}
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*nodes.SourceInfo,
sources map[string]*schema2.SourceInfo,
supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 {
@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
}
sk := fmt.Sprintf(sourceKey, nodeKey)
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
if !ok {
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package loop
package _break
import (
"context"
@ -22,21 +22,36 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Break struct {
parentIntermediateStore variable.Store
}
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
type Config struct{}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
return &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeBreak,
Name: n.Data.Meta.Title,
Configs: c,
}, nil
}
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
return &Break{
parentIntermediateStore: store,
parentIntermediateStore: &nodes.ParentIntermediateStore{},
}, nil
}
const BreakKey = "$break"
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) {
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
if err != nil {
return nil, err

View File

@ -0,0 +1,47 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package _continue
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// Continue is the executable "continue" loop node; its Invoke simply passes
// input through unchanged.
type Continue struct{}

// Config is the (empty) canvas configuration for the continue node.
type Config struct{}
// Adapt converts the canvas node n into a NodeSchema for the continue node.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeContinue,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}
	return ns, nil
}
// Build returns the executable Continue node; no configuration is needed.
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	return &Continue{}, nil
}
// Invoke echoes its input unchanged. NOTE(review): the actual
// skip-to-next-iteration behavior is presumably enforced by the enclosing
// loop node rather than here — confirm against the loop implementation.
func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
	return in, nil
}

View File

@ -27,53 +27,150 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type Loop struct {
config *Config
outputs map[string]*vo.FieldSource
outputVars map[string]string
inner compose.Runnable[map[string]any, map[string]any]
nodeKey vo.NodeKey
loopType Type
inputArrays []string
intermediateVars map[string]*vo.TypeInfo
}
type Config struct {
LoopNodeKey vo.NodeKey
LoopType Type
InputArrays []string
IntermediateVars map[string]*vo.TypeInfo
Outputs []*vo.FieldInfo
Inner compose.Runnable[map[string]any, map[string]any]
}
type Type string
const (
ByArray Type = "by_array"
ByIteration Type = "by_iteration"
Infinite Type = "infinite"
)
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
if conf == nil {
return nil, errors.New("config is nil")
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
if n.Parent() != nil {
return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
}
if conf.LoopType == ByArray {
if len(conf.InputArrays) == 0 {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLoop,
Name: n.Data.Meta.Title,
Configs: c,
}
loopType, err := toLoopType(n.Data.Inputs.LoopType)
if err != nil {
return nil, err
}
c.LoopType = loopType
intermediateVars := make(map[string]*vo.TypeInfo)
for _, param := range n.Data.Inputs.VariableParameters {
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
if err != nil {
return nil, err
}
intermediateVars[param.Name] = tInfo
ns.SetInputType(param.Name, tInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
c.IntermediateVars = intermediateVars
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
return nil, err
}
for _, fieldInfo := range ns.OutputSources {
if fieldInfo.Source.Ref != nil {
if len(fieldInfo.Source.Ref.FromPath) == 1 {
if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
}
}
}
}
loopCount := n.Data.Inputs.LoopCount
if loopCount != nil {
typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
if err != nil {
return nil, err
}
ns.SetInputType(Count, typeInfo)
sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
if err != nil {
return nil, err
}
ns.AddInputSource(sources...)
}
for key, tInfo := range ns.InputTypes {
if tInfo.Type != vo.DataTypeArray {
continue
}
if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
continue
}
c.InputArrays = append(c.InputArrays, key)
}
return ns, nil
}
// toLoopType maps the canvas LoopType to the internal loop Type, rejecting
// unknown values.
func toLoopType(l vo.LoopType) (Type, error) {
	switch l {
	case vo.LoopTypeArray:
		return ByArray, nil
	case vo.LoopTypeCount:
		return ByIteration, nil
	case vo.LoopTypeInfinite:
		return Infinite, nil
	}
	return "", fmt.Errorf("unsupported loop type: %s", l)
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
if c.LoopType == ByArray {
if len(c.InputArrays) == 0 {
return nil, errors.New("input arrays is empty when loop type is ByArray")
}
}
loop := &Loop{
config: conf,
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
options := schema.GetBuildOptions(opts...)
if options.Inner == nil {
return nil, errors.New("inner workflow is required for Loop Node")
}
for _, info := range conf.Outputs {
loop := &Loop{
outputs: make(map[string]*vo.FieldSource),
outputVars: make(map[string]string),
inputArrays: c.InputArrays,
nodeKey: ns.Key,
intermediateVars: c.IntermediateVars,
inner: options.Inner,
loopType: c.LoopType,
}
for _, info := range ns.OutputSources {
if len(info.Path) != 1 {
return nil, fmt.Errorf("invalid output path: %s", info.Path)
}
@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
}
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
}
@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
return loop, nil
}
// Type enumerates the supported loop execution modes.
type Type string

const (
	ByArray     Type = "by_array"     // iterate over the node's input arrays
	ByIteration Type = "by_iteration" // iterate a configured number of times
	Infinite    Type = "infinite"     // no iteration bound; presumably exits via break — TODO confirm
)

const (
	// Count is the input key holding the iteration count for ByIteration loops.
	Count = "loopCount"
)
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
out map[string]any, err error) {
maxIter, err := l.getMaxIter(in)
if err != nil {
return nil, err
}
arrays := make(map[string][]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
arrays := make(map[string][]any, len(l.inputArrays))
for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
arrays[arrayKey] = a.([]any)
}
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
var (
existingCState *nodes.NestedWorkflowState
@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
)
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
var e error
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
if e != nil {
return e
}
@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
for k := range existingCState.IntermediateVars {
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
}
intermediateVars[BreakKey] = &hasBreak
intermediateVars[_break.BreakKey] = &hasBreak
} else {
output = make(map[string]any, len(l.outputs))
for k := range l.outputs {
output[k] = make([]any, 0)
}
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
for varKey := range l.config.IntermediateVars {
intermediateVars = make(map[string]*any, len(l.intermediateVars))
for varKey := range l.intermediateVars {
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
if !ok {
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
intermediateVars[varKey] = &v
}
intermediateVars[BreakKey] = &hasBreak
intermediateVars[_break.BreakKey] = &hasBreak
}
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)
getIthInput := func(i int) (map[string]any, map[string]any, error) {
input := make(map[string]any)
@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
input[k] = v
}
input[string(l.config.LoopNodeKey)+"#index"] = int64(i)
input[string(l.nodeKey)+"#index"] = int64(i)
items := make(map[string]any)
for arrayKey := range arrays {
ele := arrays[arrayKey][i]
items[arrayKey] = ele
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey
currentKey := string(l.nodeKey) + "#" + arrayKey
// Recursively expand map[string]any elements
var expand func(prefix string, val interface{})
@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
}
}
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
if err != nil {
info, ok := compose.ExtractInterruptInfo(err)
if !ok {
@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
iEvent := &entity.InterruptEvent{
NodeKey: l.config.LoopNodeKey,
NodeKey: l.nodeKey,
NodeType: entity.NodeTypeLoop,
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
}
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
return e
}
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
return setter.SetInterruptEvent(l.nodeKey, iEvent)
})
if err != nil {
return nil, err
}
fmt.Println("save interruptEvent in state within loop: ", iEvent)
fmt.Println("save composite info in state within loop: ", compState)
return nil, compose.InterruptAndRerun
} else {
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
return setter.SaveNestedWorkflowState(l.nodeKey, compState)
})
if err != nil {
return nil, err
@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
}
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
panic("impossible")
panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
}
for outputVarKey, intermediateVarKey := range l.outputVars {
@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter := math.MaxInt
switch l.config.LoopType {
switch l.loopType {
case ByArray:
for _, arrayKey := range l.config.InputArrays {
for _, arrayKey := range l.inputArrays {
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
if !ok {
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) {
maxIter = int(iter.(int64))
case Infinite:
default:
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
}
return maxIter, nil
@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any {
}
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
trimmed := make(map[string]any, len(l.config.InputArrays))
for _, arrayKey := range l.config.InputArrays {
trimmed := make(map[string]any, len(l.inputArrays))
for _, arrayKey := range l.inputArrays {
if v, ok := in[arrayKey]; ok {
trimmed[arrayKey] = v
}

View File

@ -1,90 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NestedWorkflowOptions carries per-call options for nodes that wrap a
// nested workflow (e.g. loop / batch style composite nodes).
type NestedWorkflowOptions struct {
	optsForNested   []compose.Option              // applied to every nested-workflow invocation
	toResumeIndexes map[int]compose.StateModifier // iteration index -> modifier applied when resuming that iteration
	optsForIndexed  map[int][]compose.Option      // applied only to the invocation at a specific index
}

// NestedWorkflowOption mutates a NestedWorkflowOptions in place.
type NestedWorkflowOption func(*NestedWorkflowOptions)

// WithOptsForNested appends compose options that apply to every
// invocation of the nested workflow.
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
	return func(o *NestedWorkflowOptions) {
		o.optsForNested = append(o.optsForNested, opts...)
	}
}

// GetOptsForNested returns the options applied to every nested invocation.
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
	return c.optsForNested
}

// WithResumeIndex registers a state modifier for iteration i, used when
// resuming a nested workflow that was interrupted at that iteration.
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
	return func(o *NestedWorkflowOptions) {
		if o.toResumeIndexes == nil {
			o.toResumeIndexes = map[int]compose.StateModifier{}
		}
		o.toResumeIndexes[i] = m
	}
}

// GetResumeIndexes returns the registered iteration -> state-modifier map
// (nil when nothing was registered).
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
	return c.toResumeIndexes
}

// WithOptsForIndexed sets compose options that apply only to the nested
// invocation at the given index; a later call for the same index
// replaces the earlier options.
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
	return func(o *NestedWorkflowOptions) {
		if o.optsForIndexed == nil {
			o.optsForIndexed = map[int][]compose.Option{}
		}
		o.optsForIndexed[index] = opts
	}
}

// GetOptsForIndexed returns the options registered for the given index,
// or nil when none were set.
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
	if c.optsForIndexed == nil {
		return nil
	}
	return c.optsForIndexed[index]
}
// NestedWorkflowState is the persisted execution state of a nested
// workflow node, used to resume after an interrupt.
type NestedWorkflowState struct {
	Index2Done          map[int]bool                   `json:"index_2_done,omitempty"`           // iteration -> completed flag
	Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"` // iteration -> pending interrupt
	FullOutput          map[string]any                 `json:"full_output,omitempty"`            // output accumulated so far
	IntermediateVars    map[string]any                 `json:"intermediate_vars,omitempty"`      // intermediate variable snapshot
}

// String renders the state as indented JSON for debugging; the marshal
// error is deliberately ignored (best-effort output).
func (c *NestedWorkflowState) String() string {
	s, _ := sonic.MarshalIndent(c, "", " ")
	return string(s)
}

// NestedWorkflowAware is implemented by workflow state holders that can
// persist nested-workflow state and interrupt events.
type NestedWorkflowAware interface {
	SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
	GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
	InterruptEventStore
}

View File

@ -0,0 +1,194 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
einoschema "github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// InvokableNode is a basic workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It does not accept any options.
// Most nodes implement this, such as NodeTypePlugin.
type InvokableNode interface {
	Invoke(ctx context.Context, input map[string]any) (
		output map[string]any, err error)
}

// InvokableNodeWOpt is a workflow node that can Invoke.
// Invoke accepts non-streaming input and returns non-streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this.
type InvokableNodeWOpt interface {
	Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) (
		map[string]any, error)
}

// StreamableNode is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It does not accept any options.
// Currently NO node implements this.
// A potential example would be a streamable plugin for NodeTypePlugin.
type StreamableNode interface {
	Stream(ctx context.Context, in map[string]any) (
		*einoschema.StreamReader[map[string]any], error)
}

// StreamableNodeWOpt is a workflow node that can Stream.
// Stream accepts non-streaming input and returns streaming output.
// It can accept NodeOption.
// e.g. NodeTypeLLM implements this.
type StreamableNodeWOpt interface {
	Stream(ctx context.Context, in map[string]any, opts ...NodeOption) (
		*einoschema.StreamReader[map[string]any], error)
}

// CollectableNode is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It does not accept any options.
// Currently NO node implements this.
// A potential example would be a new condition node that makes decisions
// based on streaming input.
type CollectableNode interface {
	Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
		map[string]any, error)
}

// CollectableNodeWOpt is a workflow node that can Collect.
// Collect accepts streaming input and returns non-streaming output.
// It accepts NodeOption.
// Currently NO node implements this.
// A potential example would be a new batch node that accepts streaming input,
// processes it, and finally returns a non-stream aggregation of results.
type CollectableNodeWOpt interface {
	Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
		map[string]any, error)
}

// TransformableNode is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It does not accept any options.
// e.g.
// NodeTypeVariableAggregator implements TransformableNode.
type TransformableNode interface {
	Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
		*einoschema.StreamReader[map[string]any], error)
}

// TransformableNodeWOpt is a workflow node that can Transform.
// Transform accepts streaming input and returns streaming output.
// It accepts NodeOption.
// Currently NO node implements this.
// A potential example would be an audio processing node that
// transforms input audio clips, but within the node is a graph
// composed by Eino, and the audio processing node needs to carry
// options for this inner graph.
type TransformableNodeWOpt interface {
	Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
		*einoschema.StreamReader[map[string]any], error)
}
// CallbackInputConverted converts node input to a form better suited for UI.
// The converted input will be displayed on canvas during test runs,
// and will be returned when querying the node's input through OpenAPI.
type CallbackInputConverted interface {
	ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error)
}

// CallbackOutputConverted converts node output to a form better suited for UI.
// The converted output will be displayed on canvas during test runs,
// and will be returned when querying the node's output through OpenAPI.
type CallbackOutputConverted interface {
	ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error)
}

// Initializer lets a node perform one-time setup before execution and
// optionally derive a new context from it.
type Initializer interface {
	Init(ctx context.Context) (context.Context, error)
}
// AdaptOptions are optional inputs to NodeAdaptor.Adapt.
type AdaptOptions struct {
	Canvas *vo.Canvas // full frontend canvas, for adaptors needing surrounding context
}

// AdaptOption mutates AdaptOptions.
type AdaptOption func(*AdaptOptions)

// WithCanvas passes the full frontend canvas to Adapt.
func WithCanvas(canvas *vo.Canvas) AdaptOption {
	return func(opts *AdaptOptions) {
		opts.Canvas = canvas
	}
}
// GetAdaptOptions folds the given AdaptOption list into a fresh AdaptOptions.
func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions {
	options := new(AdaptOptions)
	for _, apply := range opts {
		apply(options)
	}
	return options
}
// NodeAdaptor provides conversion from frontend Node to backend NodeSchema.
type NodeAdaptor interface {
	Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) (
		*schema.NodeSchema, error)
}

// BranchAdaptor provides validation and conversion from frontend port to backend port.
type BranchAdaptor interface {
	// ExpectPorts lists the port names the node is expected to expose.
	ExpectPorts(ctx context.Context, n *vo.Node) []string
}

// Registries mapping a node type to a factory for its adaptor.
// NOTE(review): package-level mutable maps with no locking; assumed to be
// populated once during startup before any concurrent reads — confirm.
var (
	nodeAdaptors   = map[entity.NodeType]func() NodeAdaptor{}
	branchAdaptors = map[entity.NodeType]func() BranchAdaptor{}
)
// RegisterNodeAdaptor registers the NodeAdaptor factory for a node type,
// overwriting any previous registration for that type.
func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) {
	nodeAdaptors[et] = f
}
// GetNodeAdaptor returns a freshly constructed NodeAdaptor for the given
// node type.
//
// NOTE(review): an unregistered type panics (treated as a programmer
// error — registration happens at startup), which makes the boolean
// return always true when the call returns; this is inconsistent with
// GetBranchAdaptor's comma-ok contract. Consider returning (nil, false)
// instead — confirm no caller relies on the panic.
func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) {
	na, ok := nodeAdaptors[et]
	if !ok {
		panic(fmt.Sprintf("node type %s not registered", et))
	}
	return na(), ok
}
// RegisterBranchAdaptor registers the BranchAdaptor factory for a node type,
// overwriting any previous registration for that type.
func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) {
	branchAdaptors[et] = f
}
// GetBranchAdaptor returns a freshly constructed BranchAdaptor for the
// node type; ok reports whether one is registered.
func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) {
	if f, ok := branchAdaptors[et]; ok {
		return f(), true
	}
	return nil, false
}
// StreamGenerator determines the stream type of a given output field
// path of a node within the overall workflow schema.
type StreamGenerator interface {
	FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
		sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
}

View File

@ -0,0 +1,170 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nodes
import (
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NodeOptions is the resolved set of common options for one node execution.
type NodeOptions struct {
	Nested *NestedWorkflowOptions // options for nodes wrapping a nested workflow; nil when unused
}

// NestedWorkflowOptions carries per-call options destined for a nested
// (composite) workflow node such as loop or batch.
type NestedWorkflowOptions struct {
	optsForNested   []compose.Option              // applied to every nested-workflow invocation
	toResumeIndexes map[int]compose.StateModifier // iteration index -> modifier applied when resuming
	optsForIndexed  map[int][]compose.Option      // applied only to the invocation at a specific index
}

// NodeOption is a functional option accepted by node execution methods.
// The constructors in this file set exactly one of the two fields.
type NodeOption struct {
	apply             func(opts *NodeOptions) // mutates the common NodeOptions
	implSpecificOptFn any                     // implementation-specific option func of shape func(*T)
}

// NestedWorkflowOption mutates NestedWorkflowOptions directly.
// NOTE(review): not used by the constructors in this file (they return
// NodeOption); presumably kept for external callers — confirm.
type NestedWorkflowOption func(*NestedWorkflowOptions)
// WithOptsForNested appends compose options that apply to every
// invocation of the nested workflow inside the node.
func WithOptsForNested(opts ...compose.Option) NodeOption {
	return NodeOption{
		apply: func(options *NodeOptions) {
			nested := options.Nested
			if nested == nil {
				nested = &NestedWorkflowOptions{}
				options.Nested = nested
			}
			nested.optsForNested = append(nested.optsForNested, opts...)
		},
	}
}
// GetOptsForNested returns the options applied to every nested
// invocation, or nil when none were set.
func (c *NodeOptions) GetOptsForNested() []compose.Option {
	if nested := c.Nested; nested != nil {
		return nested.optsForNested
	}
	return nil
}
// WithResumeIndex registers a state modifier for iteration i, used when
// resuming a nested workflow that was interrupted at that iteration.
func WithResumeIndex(i int, m compose.StateModifier) NodeOption {
	return NodeOption{
		apply: func(options *NodeOptions) {
			if options.Nested == nil {
				options.Nested = &NestedWorkflowOptions{}
			}
			nested := options.Nested
			if nested.toResumeIndexes == nil {
				nested.toResumeIndexes = make(map[int]compose.StateModifier)
			}
			nested.toResumeIndexes[i] = m
		},
	}
}
// GetResumeIndexes returns the registered iteration -> state-modifier
// map, or nil when none were set.
func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier {
	if nested := c.Nested; nested != nil {
		return nested.toResumeIndexes
	}
	return nil
}
// WithOptsForIndexed sets compose options that apply only to the nested
// invocation at the given index; a later call for the same index
// replaces the earlier options.
func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption {
	return NodeOption{
		apply: func(options *NodeOptions) {
			if options.Nested == nil {
				options.Nested = &NestedWorkflowOptions{}
			}
			nested := options.Nested
			if nested.optsForIndexed == nil {
				nested.optsForIndexed = make(map[int][]compose.Option)
			}
			nested.optsForIndexed[index] = opts
		},
	}
}
// GetOptsForIndexed returns the options registered for the given index,
// or nil when none were set (indexing the nil inner map yields nil).
func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option {
	if nested := c.Nested; nested != nil {
		return nested.optsForIndexed[index]
	}
	return nil
}
// WrapImplSpecificOptFn is the option to wrap an implementation-specific
// option function; it is later unwrapped by GetImplSpecificOptions.
func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption {
	return NodeOption{
		implSpecificOptFn: optFn,
	}
}
// GetCommonOptions extracts the common node Options from the NodeOption
// list, optionally starting from a base Options carrying default values.
func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions {
	if base == nil {
		base = &NodeOptions{}
	}
	for _, opt := range opts {
		if opt.apply != nil {
			opt.apply(base)
		}
	}
	return base
}
// GetImplSpecificOptions extracts the implementation-specific options of
// type func(*T) from the NodeOption list, optionally starting from a
// base with default values. Options wrapping a different type are skipped.
// e.g.
//
//	myOption := &MyOption{
//		Field1: "default_value",
//	}
//
//	myOption = nodes.GetImplSpecificOptions(myOption, opts...)
func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T {
	if base == nil {
		base = new(T)
	}
	for _, opt := range opts {
		// a nil implSpecificOptFn never satisfies the assertion
		if fn, ok := opt.implSpecificOptFn.(func(*T)); ok {
			fn(base)
		}
	}
	return base
}
// NestedWorkflowState is the persisted execution state of a nested
// workflow node, used to resume after an interrupt.
type NestedWorkflowState struct {
	Index2Done          map[int]bool                   `json:"index_2_done,omitempty"`           // iteration -> completed flag
	Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"` // iteration -> pending interrupt
	FullOutput          map[string]any                 `json:"full_output,omitempty"`            // output accumulated so far
	IntermediateVars    map[string]any                 `json:"intermediate_vars,omitempty"`      // intermediate variable snapshot
}

// String renders the state as indented JSON for debugging; the marshal
// error is deliberately ignored (best-effort output).
func (c *NestedWorkflowState) String() string {
	s, _ := sonic.MarshalIndent(c, "", " ")
	return string(s)
}

// NestedWorkflowAware is implemented by workflow state holders that can
// persist nested-workflow state and interrupt events.
type NestedWorkflowAware interface {
	SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
	GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
	InterruptEventStore
}

View File

@ -18,16 +18,21 @@ package plugin
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -35,29 +40,76 @@ type Config struct {
PluginID int64
ToolID int64
PluginVersion string
}
PluginService plugin.Service
// Adapt converts the frontend plugin node definition into a backend
// NodeSchema, extracting the pluginID, apiID (tool id) and pluginVersion
// from the node's API params and storing them on the config.
//
// Bug fixes vs. the previous version: the ParseInt error for pluginID
// was silently discarded, and a missing apiID param reported the wrong
// message ("plugin id param is not found").
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypePlugin,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	inputs := n.Data.Inputs

	// index the API params by name for direct lookup
	apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
		return e.Name, e
	})

	ps, ok := apiParams["pluginID"]
	if !ok {
		return nil, fmt.Errorf("plugin id param is not found")
	}
	pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
	if err != nil { // previously ignored
		return nil, fmt.Errorf("parsing pluginID param: %w", err)
	}
	c.PluginID = pID

	ps, ok = apiParams["apiID"]
	if !ok {
		return nil, fmt.Errorf("api id param is not found")
	}
	tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
	if err != nil {
		return nil, fmt.Errorf("parsing apiID param: %w", err)
	}
	c.ToolID = tID

	ps, ok = apiParams["pluginVersion"]
	if !ok {
		return nil, fmt.Errorf("plugin version param is not found")
	}
	version := ps.Input.Value.Content.(string)
	c.PluginVersion = version

	if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// Build constructs the runtime Plugin node from the adapted config,
// wiring in the cross-domain plugin service.
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	p := &Plugin{
		pluginID:      c.PluginID,
		toolID:        c.ToolID,
		pluginVersion: c.PluginVersion,
		pluginService: plugin.GetPluginService(),
	}
	return p, nil
}
type Plugin struct {
config *Config
}
pluginID int64
toolID int64
pluginVersion string
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
if cfg == nil {
return nil, errors.New("config is nil")
}
if cfg.PluginID == 0 {
return nil, errors.New("plugin id is required")
}
if cfg.ToolID == 0 {
return nil, errors.New("tool id is required")
}
if cfg.PluginService == nil {
return nil, errors.New("tool service is required")
}
return &Plugin{config: cfg}, nil
pluginService plugin.Service
}
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
exeCfg = ctxExeCfg.ExeCfg
}
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.config.PluginID,
PluginVersion: ptr.Of(p.config.PluginVersion),
}, p.config.ToolID, exeCfg)
result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
PluginID: p.pluginID,
PluginVersion: ptr.Of(p.pluginVersion),
}, p.toolID, exeCfg)
if err != nil {
if extra, ok := compose.IsInterruptRerunError(err); ok {
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now

View File

@ -29,9 +29,12 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@ -39,8 +42,21 @@ import (
)
type QuestionAnswer struct {
config *Config
model model.BaseChatModel
nodeMeta entity.NodeTypeMeta
questionTpl string
answerType AnswerType
choiceType ChoiceType
fixedChoices []string
needExtractFromAnswer bool
additionalSystemPromptTpl string
maxAnswerCount int
nodeKey vo.NodeKey
outputFields map[string]*vo.TypeInfo
}
type Config struct {
@ -51,15 +67,249 @@ type Config struct {
FixedChoices []string
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
Model model.BaseChatModel
LLMParams *crossmodel.LLMParams
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
ExtractFromAnswer bool
AdditionalSystemPromptTpl string
MaxAnswerCount int
OutputFields map[string]*vo.TypeInfo
}
NodeKey vo.NodeKey
// Adapt converts the frontend question-answer node definition into a
// backend NodeSchema, mutating the receiver config with the parsed
// question template, LLM params, answer type and choice settings.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeQuestionAnswer,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	qaConf := n.Data.Inputs.QA
	if qaConf == nil {
		return nil, fmt.Errorf("qa config is nil")
	}

	c.QuestionTpl = qaConf.Question

	// LLM params are optional; they are only required later when the
	// node has to extract structured output from the user's answer.
	var llmParams *crossmodel.LLMParams
	if n.Data.Inputs.LLMParam != nil {
		// round-trip through JSON to convert the loosely-typed canvas
		// LLMParam into the typed SimpleLLMParam
		llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
		if err != nil {
			return nil, err
		}
		var qaLLMParams vo.SimpleLLMParam
		err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
		if err != nil {
			return nil, err
		}

		llmParams, err = convertLLMParams(qaLLMParams)
		if err != nil {
			return nil, err
		}
		c.LLMParams = llmParams
	}

	answerType, err := convertAnswerType(qaConf.AnswerType)
	if err != nil {
		return nil, err
	}
	c.AnswerType = answerType

	var choiceType ChoiceType
	if len(qaConf.OptionType) > 0 {
		choiceType, err = convertChoiceType(qaConf.OptionType)
		if err != nil {
			return nil, err
		}
		c.ChoiceType = choiceType
	}

	if answerType == AnswerByChoices {
		switch choiceType {
		case FixedChoices:
			// static options: only their display names are kept
			var options []string
			for _, option := range qaConf.Options {
				options = append(options, option.Name)
			}
			c.FixedChoices = options
		case DynamicChoices:
			// dynamic options: wire the option source as a node input
			inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			ns.AddInputSource(inputSources...)
			inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
			if err != nil {
				return nil, err
			}
			ns.SetInputType(DynamicChoicesKey, inputTypes)
		default:
			return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
		}
	} else if answerType == AnswerDirectly {
		c.ExtractFromAnswer = qaConf.ExtractOutput
		if qaConf.ExtractOutput {
			// structured extraction requires an LLM to parse the answer
			if llmParams == nil {
				return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
			}
			c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
			c.MaxAnswerCount = qaConf.Limit
			if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
				return nil, err
			}
		}
	}

	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
// convertLLMParams maps the frontend SimpleLLMParam onto the
// cross-domain LLMParams structure.
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
	return &crossmodel.LLMParams{
		ModelName:      params.ModelName,
		ModelType:      params.ModelType,
		Temperature:    &params.Temperature,
		MaxTokens:      params.MaxTokens,
		TopP:           &params.TopP,
		ResponseFormat: params.ResponseFormat,
		SystemPrompt:   params.SystemPrompt,
	}, nil
}
// convertAnswerType maps the frontend QAAnswerType onto the internal AnswerType.
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
	if t == vo.QAAnswerTypeOption {
		return AnswerByChoices, nil
	}
	if t == vo.QAAnswerTypeText {
		return AnswerDirectly, nil
	}
	return "", fmt.Errorf("invalid QAAnswerType: %s", t)
}
// convertChoiceType maps the frontend QAOptionType onto the internal ChoiceType.
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
	if t == vo.QAOptionTypeStatic {
		return FixedChoices, nil
	}
	if t == vo.QAOptionTypeDynamic {
		return DynamicChoices, nil
	}
	return "", fmt.Errorf("invalid QAOptionType: %s", t)
}
// Build validates the adapted config and constructs the runtime
// QuestionAnswer node, resolving the concrete chat model when extraction
// from the user's answer is configured.
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
	if c.AnswerType == AnswerDirectly {
		if c.ExtractFromAnswer {
			// structured extraction needs both a model and output fields
			if c.LLMParams == nil {
				return nil, errors.New("model is required when extract from answer")
			}
			if len(ns.OutputTypes) == 0 {
				return nil, errors.New("output fields is required when extract from answer")
			}
		}
	} else if c.AnswerType == AnswerByChoices {
		if c.ChoiceType == FixedChoices {
			if len(c.FixedChoices) == 0 {
				return nil, errors.New("fixed choices is required when extract from answer")
			}
		}
	} else {
		return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
	}

	nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
	if nodeMeta == nil {
		return nil, errors.New("node meta not found for question answer")
	}

	var (
		m   model.BaseChatModel
		err error
	)
	if c.LLMParams != nil {
		// resolve the concrete chat model from the configured params
		m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
		if err != nil {
			return nil, err
		}
	}

	return &QuestionAnswer{
		model:                     m,
		nodeMeta:                  *nodeMeta,
		questionTpl:               c.QuestionTpl,
		answerType:                c.AnswerType,
		choiceType:                c.ChoiceType,
		fixedChoices:              c.FixedChoices,
		needExtractFromAnswer:     c.ExtractFromAnswer,
		additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
		maxAnswerCount:            c.MaxAnswerCount,
		nodeKey:                   ns.Key,
		outputFields:              ns.OutputTypes,
	}, nil
}
// BuildBranch returns the branch-selection function for choice-based QA
// nodes; the second return reports whether this node has branches at all
// (only AnswerByChoices does).
//
// The selector maps the node's OptionIDKey output to a branch index:
// "other" routes to the default branch (-1, true); dynamic choices route
// any concrete option to branch 0; fixed choices convert the alphabetic
// option id to its numeric index.
//
// Bug fix: the option id type assertions were unguarded and would panic
// on a non-string value; they now return an error instead.
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	if c.AnswerType != AnswerByChoices {
		return nil, false
	}

	return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		optionID, ok := nodeOutput[OptionIDKey]
		if !ok {
			return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
		}

		id, ok := optionID.(string)
		if !ok {
			return -1, false, fmt.Errorf("option id is not a string: %v", optionID)
		}

		// "other" always routes to the default branch
		if id == "other" {
			return -1, true, nil
		}

		// dynamic choices expose a single concrete branch
		if c.ChoiceType == DynamicChoices {
			return 0, false, nil
		}

		optionIDInt, ok := AlphabetToInt(id)
		if !ok {
			return -1, false, fmt.Errorf("failed to convert option id from input map: %v", id)
		}

		return optionIDInt, false, nil
	}, true
}
// ExpectPorts lists the branch ports a choice-based QA node must expose:
// one branch per static option plus the default port, or a single branch
// plus the default port for dynamic options. Non-choice nodes expose none.
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
	qa := n.Data.Inputs.QA
	if qa.AnswerType != vo.QAAnswerTypeOption {
		return expects
	}

	switch qa.OptionType {
	case vo.QAOptionTypeStatic:
		for index := range qa.Options {
			expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
		}
		expects = append(expects, schema2.PortDefault)
	case vo.QAOptionTypeDynamic:
		expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
		expects = append(expects, schema2.PortDefault)
	}

	return expects
}
// RequireCheckpoint reports that QA nodes always need checkpointing,
// since they interrupt execution to wait for the user's answer.
func (c *Config) RequireCheckpoint() bool {
	return true
}
type AnswerType string
@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
)
// NewQuestionAnswer validates conf and constructs a QuestionAnswer node.
//
// Validation rules:
//   - AnswerDirectly with ExtractFromAnswer requires a model and at least
//     one output field to extract into.
//   - AnswerByChoices with FixedChoices requires at least one fixed choice.
//
// Any other answer type is rejected.
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
	if conf == nil {
		return nil, errors.New("config is nil")
	}
	switch conf.AnswerType {
	case AnswerDirectly:
		if conf.ExtractFromAnswer {
			if conf.Model == nil {
				return nil, errors.New("model is required when extract from answer")
			}
			if len(conf.OutputFields) == 0 {
				return nil, errors.New("output fields is required when extract from answer")
			}
		}
	case AnswerByChoices:
		// This requirement applies to the choices mode, not extraction;
		// the message previously said "extract from answer" by mistake.
		if conf.ChoiceType == FixedChoices && len(conf.FixedChoices) == 0 {
			return nil, errors.New("fixed choices is required when answer by choices")
		}
	default:
		return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
	}

	nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
	if nodeMeta == nil {
		return nil, errors.New("node meta not found for question answer")
	}

	return &QuestionAnswer{
		config:   conf,
		nodeMeta: *nodeMeta,
	}, nil
}
type Question struct {
Question string
Choices []string
@ -182,10 +397,10 @@ type message struct {
ID string `json:"id,omitempty"`
}
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
// Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
var (
questions []*Question
answers []string
@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
out[QuestionsKey] = questions
out[AnswersKey] = answers
switch q.config.AnswerType {
switch q.answerType {
case AnswerDirectly:
if isFirst { // first execution, ask the question
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
}
if q.config.ExtractFromAnswer {
if q.needExtractFromAnswer {
return q.extractFromAnswer(ctx, in, questions, answers)
}
@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
}
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
var formattedChoices []string
switch q.config.ChoiceType {
switch q.choiceType {
case FixedChoices:
for _, choice := range q.config.FixedChoices {
for _, choice := range q.fixedChoices {
formattedChoice, err := nodes.TemplateRender(choice, in)
if err != nil {
return nil, err
@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
formattedChoices = append(formattedChoices, c)
}
default:
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
}
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
default:
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
}
}
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
fieldInfo := "FieldInfo"
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
if err != nil {
return nil, err
}
@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
var requiredFields []string
for fName, tInfo := range q.config.OutputFields {
for fName, tInfo := range q.outputFields {
if tInfo.Required {
requiredFields = append(requiredFields, fName)
}
}
var formattedAdditionalPrompt string
if len(q.config.AdditionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
if len(q.additionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
if err != nil {
return nil, err
}
@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
messages = append(messages, schema.UserMessage(answer))
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return nil, err
}
@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
if ok {
nextQuestionStr, ok := nextQuestion.(string)
if ok && len(nextQuestionStr) > 0 {
if len(answers) >= q.config.MaxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
if len(answers) >= q.maxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
}
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
return nil, fmt.Errorf("field %s not found", fieldInfo)
}
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
schema.UserMessage(answer),
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return -1, err
}
@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
event := &entity.InterruptEvent{
ID: eventID,
NodeKey: q.config.NodeKey,
NodeKey: q.nodeKey,
NodeType: entity.NodeTypeQuestionAnswer,
NodeTitle: q.nodeMeta.Name,
NodeIcon: q.nodeMeta.IconURL,
@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
}
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
setter.AddQuestion(q.config.NodeKey, &Question{
setter.AddQuestion(q.nodeKey, &Question{
Question: newQuestion,
Choices: choices,
})
@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
return ""
}
// AlphabetToInt converts a single-letter option id ("a"/"A" .. "z"/"Z")
// to its zero-based index (0 .. 25). It reports false for anything that
// is not exactly one ASCII letter. The result is int64 to match the
// branch-index type used by the branch-building functions.
//
// Note: this span previously contained both the old (int, bool) and new
// (int64, bool) versions of the function interleaved; only the int64
// version is kept, consistent with its int64-returning callers.
func AlphabetToInt(str string) (int64, bool) {
	if len(str) != 1 {
		return 0, false
	}
	char := rune(str[0])
	char = unicode.ToUpper(char)
	if char >= 'A' && char <= 'Z' {
		return int64(char - 'A'), true
	}
	return 0, false
}
@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
for i := 0; i < len(oldQuestions); i++ {
oldQuestion := oldQuestions[i]
oldAnswer := oldAnswers[i]
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
questionMsg := &message{
Type: "question",
ContentType: contentType,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
}
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
questionMsg.Content = optionContent{
Options: conv(oldQuestion.Choices),
Question: oldQuestion.Question,
@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Type: "answer",
ContentType: contentType,
Content: oldAnswer,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
}
history = append(history, questionMsg, answerMsg)
}
if newQuestion != nil {
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
history = append(history, &message{
Type: "question",
ContentType: "option",
@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Options: conv(choices),
Question: *newQuestion,
},
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
} else {
history = append(history, &message{
Type: "question",
ContentType: "text",
Content: *newQuestion,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
}
}

View File

@ -27,8 +27,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@ -37,19 +39,27 @@ import (
)
type Config struct {
OutputTypes map[string]*vo.TypeInfo
NodeKey vo.NodeKey
OutputSchema string
}
type InputReceiver struct {
outputTypes map[string]*vo.TypeInfo
interruptData string
nodeKey vo.NodeKey
nodeMeta entity.NodeTypeMeta
// Adapt converts the canvas node n into a NodeSchema for the input
// receiver, capturing the node's declared output schema on the config so
// it is available later at build time.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	// Remember the form schema declared on the canvas node.
	c.OutputSchema = n.Data.Inputs.OutputSchema

	nodeSchema := new(schema.NodeSchema)
	nodeSchema.Key = vo.NodeKey(n.ID)
	nodeSchema.Type = entity.NodeTypeInputReceiver
	nodeSchema.Name = n.Data.Meta.Title
	nodeSchema.Configs = c

	err := convert.SetOutputTypesForNodeSchema(n, nodeSchema)
	if err != nil {
		return nil, err
	}
	return nodeSchema, nil
}
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
if nodeMeta == nil {
return nil, errors.New("node meta not found for input receiver")
@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
interruptData := map[string]string{
"content_type": "form_schema",
"content": cfg.OutputSchema,
"content": c.OutputSchema,
}
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
}
return &InputReceiver{
outputTypes: cfg.OutputTypes,
outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution
nodeMeta: *nodeMeta,
nodeKey: cfg.NodeKey,
nodeKey: ns.Key,
interruptData: interruptDataStr,
}, nil
}
// RequireCheckpoint reports that the input receiver always needs a
// checkpoint: it interrupts the workflow to wait for externally submitted
// data and must be resumable once that data arrives.
func (c *Config) RequireCheckpoint() bool {
	return true
}
// InputReceiver is the runtime node that pauses the workflow and waits
// for externally supplied input matching its declared output types.
type InputReceiver struct {
	outputTypes   map[string]*vo.TypeInfo // output types the received data is converted against
	interruptData string                  // pre-marshaled interrupt payload ("content_type": "form_schema" plus the schema content)
	nodeKey       vo.NodeKey              // key of this node within the workflow
	nodeMeta      entity.NodeTypeMeta     // static node-type metadata used when raising interrupt events
}
const (
ReceivedDataKey = "$received_data"
receiverWarningKey = "receiver_warning_%d_%s"

View File

@ -0,0 +1,190 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package selector
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
)
// selectorCallbackField describes one operand of a selector condition as
// reported to callbacks: a display key for where the value came from,
// its declared data type, and the resolved runtime value.
type selectorCallbackField struct {
	Key   string      `json:"key"`
	Type  vo.DataType `json:"type"`
	Value any         `json:"value"`
}

// selectorCondition is a single left/operator/right comparison; Right is
// nil when the condition has no right-hand operand.
type selectorCondition struct {
	Left     selectorCallbackField  `json:"left"`
	Operator vo.OperatorType        `json:"operator"`
	Right    *selectorCallbackField `json:"right,omitempty"`
}

// selectorBranch groups the conditions of one selector branch together
// with the logic (AND/OR) that joins them.
type selectorBranch struct {
	Conditions []*selectorCondition `json:"conditions"`
	Logic      vo.LogicType         `json:"logic"`
	Name       string               `json:"name"`
}
// ToCallbackInput rebuilds a canvas-shaped view of the selector's input
// for callbacks: one selectorBranch per clause, populated with the
// resolved left/right operand values taken from the runtime input map.
//
// It walks the node's input sources. A 2-segment path
// (branchIndex, left|right) belongs to a single-condition branch, while a
// 3-segment path (branchIndex, clauseIndex, left|right) belongs to a
// multi-condition branch.
func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
	count := len(s.clauses)
	output := make([]*selectorBranch, count)
	for _, source := range s.ns.InputSources {
		targetPath := source.Path
		if len(targetPath) == 2 {
			// Single-condition branch: path is (branchIndex, left/right).
			indexStr := targetPath[0]
			index, err := strconv.Atoi(indexStr)
			if err != nil {
				return nil, err
			}
			branch := output[index]
			if branch == nil {
				// Lazily create the branch with its single condition slot.
				output[index] = &selectorBranch{
					Conditions: []*selectorCondition{
						{
							Operator: s.clauses[index].Single.ToCanvasOperatorType(),
						},
					},
					Logic: ClauseRelationAND.ToVOLogicType(),
				}
			}
			if targetPath[1] == LeftKey {
				leftV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take left value of %s", targetPath)
				}
				if source.Source.Ref.VariableType != nil {
					if *source.Source.Ref.VariableType == vo.ParentIntermediate {
						// Value referenced through the parent node: display
						// key is "<parent name>.<path>".
						parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
						if !ok {
							return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
						}
						parentNode := s.ws.GetNode(parentNodeKey)
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					} else {
						// Other variable kinds carry no display key.
						output[index].Conditions[0].Left = selectorCallbackField{
							Key:   "",
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
							Value: leftV,
						}
					}
				} else {
					// Plain node reference: display key is "<node name>.<path>".
					output[index].Conditions[0].Left = selectorCallbackField{
						Key:   s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
						Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
						Value: leftV,
					}
				}
			} else if targetPath[1] == RightKey {
				rightV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take right value of %s", targetPath)
				}
				output[index].Conditions[0].Right = &selectorCallbackField{
					Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
					Value: rightV,
				}
			}
		} else if len(targetPath) == 3 {
			// Multi-condition branch: path is (branchIndex, clauseIndex, left/right).
			indexStr := targetPath[0]
			index, err := strconv.Atoi(indexStr)
			if err != nil {
				return nil, err
			}
			multi := s.clauses[index].Multi
			branch := output[index]
			if branch == nil {
				// Lazily create the branch with one condition slot per clause.
				output[index] = &selectorBranch{
					Conditions: make([]*selectorCondition, len(multi.Clauses)),
					Logic:      multi.Relation.ToVOLogicType(),
				}
			}
			clauseIndexStr := targetPath[1]
			clauseIndex, err := strconv.Atoi(clauseIndexStr)
			if err != nil {
				return nil, err
			}
			clause := multi.Clauses[clauseIndex]
			if output[index].Conditions[clauseIndex] == nil {
				output[index].Conditions[clauseIndex] = &selectorCondition{
					Operator: clause.ToCanvasOperatorType(),
				}
			}
			if targetPath[2] == LeftKey {
				leftV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take left value of %s", targetPath)
				}
				if source.Source.Ref.VariableType != nil {
					if *source.Source.Ref.VariableType == vo.ParentIntermediate {
						// Value referenced through the parent node: display
						// key is "<parent name>.<path>".
						parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
						if !ok {
							return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
						}
						parentNode := s.ws.GetNode(parentNodeKey)
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					} else {
						// Other variable kinds carry no display key.
						output[index].Conditions[clauseIndex].Left = selectorCallbackField{
							Key:   "",
							Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
							Value: leftV,
						}
					}
				} else {
					// Plain node reference: display key is "<node name>.<path>".
					output[index].Conditions[clauseIndex].Left = selectorCallbackField{
						Key:   s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
						Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
						Value: leftV,
					}
				}
			} else if targetPath[2] == RightKey {
				rightV, ok := nodes.TakeMapValue(in, targetPath)
				if !ok {
					return nil, fmt.Errorf("failed to take right value of %s", targetPath)
				}
				output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
					Type:  s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
					Value: rightV,
				}
			}
		}
	}
	return map[string]any{"branches": output}, nil
}

View File

@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
panic(fmt.Sprintf("unknown operator: %+v", o))
}
}
// ToSelectorOperator maps a canvas operator type to the selector's
// internal Operator. The Contain/NotContain canvas operators are
// overloaded: when the left operand is an object they test key
// membership, otherwise element/substring containment, so leftType
// decides which internal operator applies. Any other unknown operator
// type yields an error.
func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) {
	// Containment operators are the only type-dependent cases.
	switch o {
	case vo.Contain:
		if leftType.Type == vo.DataTypeObject {
			return OperatorContainKey, nil
		}
		return OperatorContain, nil
	case vo.NotContain:
		if leftType.Type == vo.DataTypeObject {
			return OperatorNotContainKey, nil
		}
		return OperatorNotContain, nil
	}

	// Everything else translates one-to-one.
	direct := map[vo.OperatorType]Operator{
		vo.Equal:                  OperatorEqual,
		vo.NotEqual:               OperatorNotEqual,
		vo.LengthGreaterThan:      OperatorLengthGreater,
		vo.LengthGreaterThanEqual: OperatorLengthGreaterOrEqual,
		vo.LengthLessThan:         OperatorLengthLesser,
		vo.LengthLessThanEqual:    OperatorLengthLesserOrEqual,
		vo.Empty:                  OperatorEmpty,
		vo.NotEmpty:               OperatorNotEmpty,
		vo.True:                   OperatorIsTrue,
		vo.False:                  OperatorIsFalse,
		vo.GreaterThan:            OperatorGreater,
		vo.GreaterThanEqual:       OperatorGreaterOrEqual,
		vo.LessThan:               OperatorLesser,
		vo.LessThanEqual:          OperatorLesserOrEqual,
	}
	if op, ok := direct[o]; ok {
		return op, nil
	}
	return "", fmt.Errorf("unsupported operator type: %d", o)
}

View File

@ -17,9 +17,16 @@
package selector
import (
"context"
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type ClauseRelation string
@ -29,10 +36,6 @@ const (
ClauseRelationOR ClauseRelation = "or"
)
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
type OneClauseSchema struct {
Single *Operator `json:"single,omitempty"`
Multi *MultiClauseSchema `json:"multi,omitempty"`
@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType {
panic(fmt.Sprintf("unknown clause relation: %s", c))
}
// Adapt converts the canvas selector node n into a NodeSchema. For each
// branch it derives the operand types and field sources, registers them
// on the schema, and records one OneClauseSchema (Single for a
// one-condition branch, Multi otherwise) on the config's Clauses.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	clauses := make([]*OneClauseSchema, 0)
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Name:    n.Data.Meta.Title,
		Type:    entity.NodeTypeSelector,
		Configs: c,
	}
	for i, branchCond := range n.Data.Inputs.Branches {
		// Each branch's operands live under an object keyed by the branch index.
		inputType := &vo.TypeInfo{
			Type:       vo.DataTypeObject,
			Properties: map[string]*vo.TypeInfo{},
		}
		if len(branchCond.Condition.Conditions) == 1 { // single condition
			cond := branchCond.Condition.Conditions[0]
			left := cond.Left
			if left == nil {
				return nil, fmt.Errorf("operator left is nil")
			}
			leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
			if err != nil {
				return nil, err
			}
			// Field path for the left operand is ("<i>", left).
			leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			inputType.Properties[LeftKey] = leftType
			ns.AddInputSource(leftSources...)
			op, err := ToSelectorOperator(cond.Operator, leftType)
			if err != nil {
				return nil, err
			}
			// Right operand is optional (unary operators have none).
			if cond.Right != nil {
				rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
				if err != nil {
					return nil, err
				}
				rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent())
				if err != nil {
					return nil, err
				}
				inputType.Properties[RightKey] = rightType
				ns.AddInputSource(rightSources...)
			}
			ns.SetInputType(fmt.Sprintf("%d", i), inputType)
			clauses = append(clauses, &OneClauseSchema{
				Single: &op,
			})
			continue
		}
		// Multi-condition branch: conditions joined by AND/OR.
		var relation ClauseRelation
		logic := branchCond.Condition.Logic
		if logic == vo.OR {
			relation = ClauseRelationOR
		} else if logic == vo.AND {
			relation = ClauseRelationAND
		}
		var ops []*Operator
		for j, cond := range branchCond.Condition.Conditions {
			left := cond.Left
			if left == nil {
				return nil, fmt.Errorf("operator left is nil")
			}
			leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
			if err != nil {
				return nil, err
			}
			// Field path for the left operand is ("<i>", "<j>", left).
			leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent())
			if err != nil {
				return nil, err
			}
			// Each clause gets its own nested object keyed by the clause index.
			inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{
				Type: vo.DataTypeObject,
				Properties: map[string]*vo.TypeInfo{
					LeftKey: leftType,
				},
			}
			ns.AddInputSource(leftSources...)
			op, err := ToSelectorOperator(cond.Operator, leftType)
			if err != nil {
				return nil, err
			}
			ops = append(ops, &op)
			// Right operand is optional (unary operators have none).
			if cond.Right != nil {
				rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
				if err != nil {
					return nil, err
				}
				rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent())
				if err != nil {
					return nil, err
				}
				inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType
				ns.AddInputSource(rightSources...)
			}
		}
		ns.SetInputType(fmt.Sprintf("%d", i), inputType)
		clauses = append(clauses, &OneClauseSchema{
			Multi: &MultiClauseSchema{
				Clauses:  ops,
				Relation: relation,
			},
		})
	}
	c.Clauses = clauses
	return ns, nil
}

View File

@ -23,23 +23,32 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Selector struct {
config *Config
clauses []*OneClauseSchema
ns *schema.NodeSchema
ws *schema.WorkflowSchema
}
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
if config == nil {
return nil, fmt.Errorf("config is nil")
type Config struct {
Clauses []*OneClauseSchema `json:"clauses"`
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
ws := schema.GetBuildOptions(opts...).WS
if ws == nil {
return nil, fmt.Errorf("workflow schema is required")
}
if len(config.Clauses) == 0 {
if len(c.Clauses) == 0 {
return nil, fmt.Errorf("config clauses are empty")
}
for _, clause := range config.Clauses {
for _, clause := range c.Clauses {
if clause.Single == nil && clause.Multi == nil {
return nil, fmt.Errorf("single clause and multi clause are both nil")
}
@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) {
}
return &Selector{
config: config,
clauses: c.Clauses,
ns: ns,
ws: ws,
}, nil
}
// BuildBranch returns the branch-selection function for the selector
// node; the selector always branches, so the second result is true.
//
// Invoke writes the chosen clause index into nodeOutput[SelectKey]:
// values 0..len(Clauses)-1 pick the matching condition branch, while
// exactly len(Clauses) means no condition matched and the default
// ("else") branch is taken (signaled by isDefault=true).
func (c *Config) BuildBranch(_ context.Context) (
	func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
	return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
		// Guard the type assertion so a malformed output yields an error
		// instead of a panic.
		choice, ok := nodeOutput[SelectKey].(int64)
		if !ok {
			return -1, false, fmt.Errorf("selector output %q missing or not int64: %v", SelectKey, nodeOutput[SelectKey])
		}
		defaultChoice := int64(len(c.Clauses))
		// Valid range is [0, len(Clauses)]; the previous bound of
		// len(Clauses)+1 let one out-of-range value through as a branch index.
		if choice < 0 || choice > defaultChoice {
			return -1, false, fmt.Errorf("selector choice out of range: %d", choice)
		}
		if choice == defaultChoice { // no condition matched: default branch
			return -1, true, nil
		}
		return choice, false, nil
	}, true
}
// ExpectPorts lists the port names the selector node exposes on the
// canvas: "false" for the default branch, "true" for the first
// condition, and "true_<i>" for every subsequent condition i.
func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string {
	branchCount := len(n.Data.Inputs.Branches)
	ports := make([]string, 0, branchCount+1)
	ports = append(ports, "false") // default branch is always present
	for i := 0; i < branchCount; i++ {
		if i == 0 {
			ports = append(ports, "true") // first condition has no suffix
			continue
		}
		ports = append(ports, "true_"+strconv.Itoa(i))
	}
	return ports
}
type Operants struct {
Left any
Right any
@ -76,14 +117,14 @@ const (
SelectKey = "selected"
)
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.SelectorInputConverter(input)
func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) {
in, err := s.selectorInputConverter(input)
if err != nil {
return nil, err
}
predicates := make([]Predicate, 0, len(s.config.Clauses))
for i, oneConf := range s.config.Clauses {
predicates := make([]Predicate, 0, len(s.clauses))
for i, oneConf := range s.clauses {
if oneConf.Single != nil {
left := in[i].Left
right := in[i].Right
@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri
}
if isTrue {
return map[string]any{SelectKey: i}, nil
return map[string]any{SelectKey: int64(i)}, nil
}
}
return map[string]any{SelectKey: len(in)}, nil // default choice
return map[string]any{SelectKey: int64(len(in))}, nil // default choice
}
func (s *Selector) GetType() string {
return "Selector"
}
func (s *Selector) ConditionCount() int {
return len(s.config.Clauses)
}
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.config.Clauses
func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) {
conf := s.clauses
for i, oneConf := range conf {
if oneConf.Single != nil {
@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er
}
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
count := len(s.config.Clauses)
out := output[SelectKey].(int)
count := int64(len(s.clauses))
out := output[SelectKey].(int64)
if out == count {
cOutput := map[string]any{"result": "pass to else branch"}
return &nodes.StructuredCallbackOutput{

View File

@ -22,57 +22,27 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
// KeyIsFinished is a sentinel map key; the \x1F (unit separator) bytes
// make collisions with real user keys practically impossible.
// NOTE(review): appears to mark that a streamed key has finished
// emitting — confirm at the usage sites.
var KeyIsFinished = "\x1FKey is finished\x1F"

// Mode distinguishes streaming from non-streaming execution.
type Mode string

const (
	Streaming    Mode = "streaming"
	NonStreaming Mode = "non-streaming"
)

// FieldStreamType classifies how an input field source behaves with
// respect to streaming.
type FieldStreamType string

const (
	FieldIsStream    FieldStreamType = "yes"     // absolutely a stream
	FieldNotStream   FieldStreamType = "no"      // absolutely not a stream
	FieldMaybeStream FieldStreamType = "maybe"   // maybe a stream, requires request-time resolution
	FieldSkipped     FieldStreamType = "skipped" // the field source's node is skipped
)
// SourceInfo contains the stream type for an input field source of a node.
type SourceInfo struct {
	// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
	IsIntermediate bool
	// FieldType is the stream type of the field. May require request-time resolution in addition to compile-time.
	FieldType FieldStreamType
	// FromNodeKey is the node key that produces this field source. Empty if the field is a static value or variable.
	FromNodeKey vo.NodeKey
	// FromPath is the path of this field source within the source node. Empty if the field is a static value or variable.
	FromPath compose.FieldPath
	// TypeInfo is the type information of this field source.
	TypeInfo *vo.TypeInfo
	// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
	SubSources map[string]*SourceInfo
}
type DynamicStreamContainer interface {
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error)
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error)
}
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
resolved := make(map[string]*SourceInfo, len(sources))
func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) {
resolved := make(map[string]*schema.SourceInfo, len(sources))
nodeKey2Skipped := make(map[vo.NodeKey]bool)
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
resolvedNode := &SourceInfo{
var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error)
resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) {
resolvedNode := &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: sInfo.FieldType,
FromNodeKey: sInfo.FromNodeKey,
@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
}
if len(sInfo.SubSources) > 0 {
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))
resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources))
for k, subInfo := range sInfo.SubSources {
resolvedSub, err := resolver(k, subInfo)
@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
}
if skipped {
resolvedNode.FieldType = FieldSkipped
resolvedNode.FieldType = schema.FieldSkipped
return resolvedNode, nil
}
if sInfo.FieldType == FieldMaybeStream {
if sInfo.FieldType == schema.FieldMaybeStream {
if len(sInfo.SubSources) > 0 {
panic("a maybe stream field should not have sub sources")
}
var streamType FieldStreamType
var streamType schema.FieldStreamType
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
var e error
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
return nil, err
}
return &SourceInfo{
return &schema.SourceInfo{
IsIntermediate: sInfo.IsIntermediate,
FieldType: streamType,
FromNodeKey: sInfo.FromNodeKey,
@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface {
NodeExecuted(key vo.NodeKey) bool
}
func (s *SourceInfo) Skipped() bool {
if !s.IsIntermediate {
return s.FieldType == FieldSkipped
func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath,
sc *schema.WorkflowSchema) (schema.FieldStreamType, error) {
sg, ok := s.Configs.(StreamGenerator)
if !ok {
return schema.FieldNotStream, nil
}
for _, sub := range s.SubSources {
if !sub.Skipped() {
return false
}
}
return true
}
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
if !s.IsIntermediate {
return s.FromNodeKey == nodeKey
}
for _, sub := range s.SubSources {
if sub.FromNode(nodeKey) {
return true
}
}
return false
return sg.FieldStreamType(path, s, sc)
}

View File

@ -18,7 +18,6 @@ package subworkflow
import (
"context"
"errors"
"fmt"
"strconv"
@ -29,35 +28,56 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type Config struct {
Runner compose.Runnable[map[string]any, map[string]any]
WorkflowID int64
WorkflowVersion string
}
// FieldStreamType reports whether the sub-workflow node's output field at
// path can stream. Streaming is only possible for the single top-level
// "output" field, and only when the outer workflow requires streaming, the
// inner workflow requires streaming, and the inner exit node terminates by
// streaming its answer (not by returning variables).
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() {
		return schema2.FieldNotStream, nil
	}

	innerWF := ns.SubWorkflowSchema
	if !innerWF.RequireStreaming() {
		return schema2.FieldNotStream, nil
	}

	innerExit := innerWF.GetNode(entity.ExitNodeKey)
	if innerExit == nil {
		// was an unguarded dereference: a malformed inner workflow without an
		// exit node would panic here
		return schema2.FieldNotStream, fmt.Errorf("sub-workflow has no exit node")
	}

	exitCfg, ok := innerExit.Configs.(*exit.Config)
	if !ok {
		// was an unchecked type assertion, which would panic on a bad schema
		return schema2.FieldNotStream, fmt.Errorf("exit node has unexpected config type %T", innerExit.Configs)
	}

	if exitCfg.TerminatePlan == vo.ReturnVariables {
		return schema2.FieldNotStream, nil
	}

	if !innerExit.StreamConfigs.RequireStreamingInput {
		return schema2.FieldNotStream, nil
	}

	// was `len(path) > 1 || ...`, which panicked on an empty path when
	// evaluating path[0]; `!= 1` rejects both empty and nested paths
	if len(path) != 1 || path[0] != "output" {
		return schema2.FieldNotStream, fmt.Errorf(
			"streaming answering sub-workflow node can only have out field 'output'")
	}

	return schema2.FieldIsStream, nil
}
// SubWorkflow is the node wrapper around a compiled inner workflow.
// NOTE(review): cfg and Runner overlap — cfg.Runner and Runner can both hold
// the compiled runner; confirm which field is authoritative and whether cfg
// is still needed.
type SubWorkflow struct {
	cfg    *Config
	Runner compose.Runnable[map[string]any, map[string]any]
}
// NewSubWorkflow validates cfg and wraps its compiled runner in a
// SubWorkflow node. Returns an error when cfg or its Runner is missing.
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
	if cfg == nil {
		return nil, errors.New("config is nil")
	}
	if cfg.Runner == nil {
		return nil, errors.New("runnable is nil")
	}
	// also populate Runner: Invoke/Stream read s.Runner directly, so leaving
	// it nil would cause a nil-pointer call at execution time
	return &SubWorkflow{cfg: cfg, Runner: cfg.Runner}, nil
}
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
out, err := s.Runner.Invoke(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod
return out, nil
}
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) {
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
if err != nil {
return nil, err
}
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
out, err := s.Runner.Stream(ctx, in, nestedOpts...)
if err != nil {
interruptInfo, ok := compose.ExtractInterruptInfo(err)
if !ok {
@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod
return out, nil
}
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
options := &nodes.NestedWorkflowOptions{}
for _, opt := range opts {
opt(options)
}
func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) {
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
nestedOpts := options.GetOptsForNested()

View File

@ -30,6 +30,7 @@ import (
"github.com/bytedance/sonic/ast"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -156,7 +157,7 @@ func removeSlice(s string) string {
type renderOptions struct {
type2CustomRenderer map[reflect.Type]func(any) (string, error)
reservedKey map[string]struct{}
reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping
nilRenderer func() (string, error)
}
@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
}
}
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
return false, false
}
@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
for _, subPath := range tp.SubPathsBeforeSlice {
@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
return false, true
}
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
matchingSource = subSource
}
if !matchingSource.IsIntermediate {
return matchingSource.FieldType == FieldSkipped, false
return matchingSource.FieldType == schema.FieldSkipped, false
}
var checkSourceSkipped func(sInfo *SourceInfo) bool
checkSourceSkipped = func(sInfo *SourceInfo) bool {
var checkSourceSkipped func(sInfo *schema.SourceInfo) bool
checkSourceSkipped = func(sInfo *schema.SourceInfo) bool {
if !sInfo.IsIntermediate {
return sInfo.FieldType == FieldSkipped
return sInfo.FieldType == schema.FieldSkipped
}
for _, subSource := range sInfo.SubSources {
if !checkSourceSkipped(subSource) {
@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
return currentType
}
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
mi, err := sonic.Marshal(input)
if err != nil {
return "", err

View File

@ -22,7 +22,11 @@ import (
"reflect"
"strings"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@ -34,42 +38,92 @@ const (
)
type Config struct {
Type Type `json:"type"`
Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"`
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
Type Type `json:"type"`
Tpl string `json:"tpl"`
ConcatChar string `json:"concatChar"`
Separators []string `json:"separator"`
}
type TextProcessor struct {
config *Config
}
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
if cfg == nil {
return nil, fmt.Errorf("config requried")
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
ns := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeTextProcessor,
Name: n.Data.Meta.Title,
Configs: c,
}
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
if n.Data.Inputs.Method == vo.Concat {
c.Type = ConcatText
params := n.Data.Inputs.ConcatParams
for _, param := range params {
if param.Name == "concatResult" {
c.Tpl = param.Input.Value.Content.(string)
} else if param.Name == "arrayItemConcatChar" {
c.ConcatChar = param.Input.Value.Content.(string)
}
}
} else if n.Data.Inputs.Method == vo.Split {
c.Type = SplitText
params := n.Data.Inputs.SplitParams
separators := make([]string, 0, len(params))
for _, param := range params {
if param.Name == "delimiters" {
delimiters := param.Input.Value.Content.([]any)
for _, d := range delimiters {
separators = append(separators, d.(string))
}
}
}
c.Separators = separators
} else {
return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method)
}
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
if c.Type == ConcatText && len(c.Tpl) == 0 {
return nil, fmt.Errorf("config tpl requried")
}
return &TextProcessor{
config: cfg,
typ: c.Type,
tpl: c.Tpl,
concatChar: c.ConcatChar,
separators: c.Separators,
fullSources: ns.FullSources,
}, nil
}
// TextProcessor implements the text-processing node. Depending on typ it
// either renders a concat template into one string or splits an input string
// by a list of separators (see Invoke).
type TextProcessor struct {
	typ         Type                          // ConcatText or SplitText
	tpl         string                        // template rendered in concat mode
	concatChar  string                        // joiner for array values rendered in concat mode
	separators  []string                      // delimiters applied in order in split mode
	fullSources map[string]*schema.SourceInfo // per-input source info passed to template rendering
}
const OutputKey = "output"
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
switch t.config.Type {
switch t.typ {
case ConcatText:
arrayRenderer := func(i any) (string, error) {
vs := i.([]any)
return join(vs, t.config.ConcatChar)
return join(vs, t.concatChar)
}
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
result, err := nodes.Render(ctx, t.tpl, input, t.fullSources,
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
if err != nil {
return nil, err
@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
if !ok {
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
}
values := strings.Split(valueString, t.config.Separators[0])
values := strings.Split(valueString, t.separators[0])
// Iterate over each delimiter
for _, sep := range t.config.Separators[1:] {
for _, sep := range t.separators[1:] {
var tempParts []string
for _, part := range values {
tempParts = append(tempParts, strings.Split(part, sep)...)
@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
return map[string]any{OutputKey: anyValues}, nil
default:
return nil, fmt.Errorf("not support type %s", t.config.Type)
return nil, fmt.Errorf("not support type %s", t.typ)
}
}

View File

@ -21,6 +21,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
func TestNewTextProcessorNodeGenerator(t *testing.T) {
@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
Type: SplitText,
Separators: []string{",", "|", "."},
}
p, err := NewTextProcessor(ctx, cfg)
p, err := cfg.Build(ctx, &schema2.NodeSchema{})
assert.NoError(t, err)
result, err := p.Invoke(ctx, map[string]any{
result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{
"String": "a,b|c.d,e|f|g",
})
@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
ConcatChar: `\t`,
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
}
p, err := NewTextProcessor(context.Background(), cfg)
p, err := cfg.Build(context.Background(), &schema2.NodeSchema{})
result, err := p.Invoke(ctx, in)
result, err := p.(*TextProcessor).Invoke(ctx, in)
assert.NoError(t, err)
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
})

View File

@ -32,8 +32,11 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@ -48,24 +51,147 @@ const (
type Config struct {
MergeStrategy MergeStrategy
GroupLen map[string]int
FullSources map[string]*nodes.SourceInfo
NodeKey vo.NodeKey
InputSources []*vo.FieldInfo
GroupOrder []string // the order the groups are declared in frontend canvas
}
type VariableAggregator struct {
config *Config
// Adapt converts the frontend canvas node n into a NodeSchema for the
// variable-aggregator node. It records one object-typed input per merge
// group (each member addressed by its index rendered as a string key),
// wires up the input sources, and fills the config's GroupLen and
// GroupOrder.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAggregator,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	// only first-not-null merging is produced by this adaptor
	c.MergeStrategy = FirstNotNullValue

	inputs := n.Data.Inputs
	groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
	for i := range inputs.VariableAggregator.MergeGroups {
		group := inputs.VariableAggregator.MergeGroups[i]
		// each group becomes one object-typed input on the node schema
		tInfo := &vo.TypeInfo{
			Type:       vo.DataTypeObject,
			Properties: make(map[string]*vo.TypeInfo),
		}
		ns.SetInputType(group.Name, tInfo)
		for ii, v := range group.Variables {
			// group members are keyed by their index rendered as a string
			name := strconv.Itoa(ii)
			valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
			if err != nil {
				return nil, err
			}
			tInfo.Properties[name] = valueTypeInfo
			sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
			if err != nil {
				return nil, err
			}
			ns.AddInputSource(sources...)
		}
		length := len(group.Variables)
		groupToLen[group.Name] = length
	}

	// preserve the canvas declaration order of the groups so downstream
	// output assembly is stable
	groupOrder := make([]string, 0, len(groupToLen))
	for i := range inputs.VariableAggregator.MergeGroups {
		group := inputs.VariableAggregator.MergeGroups[i]
		groupOrder = append(groupOrder, group.Name)
	}

	c.GroupLen = groupToLen
	c.GroupOrder = groupOrder

	if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}

	return ns, nil
}
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
if cfg == nil {
return nil, errors.New("config is required")
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if c.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
}
if cfg.MergeStrategy != FirstNotNullValue {
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
return &VariableAggregator{
groupLen: c.GroupLen,
fullSources: ns.FullSources,
nodeKey: ns.Key,
}, nil
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
}
return &VariableAggregator{config: cfg}, nil
if len(path) == 2 { // asking about a specific index within a group
for _, fInfo := range ns.InputSources {
if len(fInfo.Path) == len(path) {
equal := true
for i := range fInfo.Path {
if fInfo.Path[i] != path[i] {
equal = false
break
}
}
if equal {
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
return schema2.FieldNotStream, nil // variables or static values
}
fromNodeKey := fInfo.Source.Ref.FromNodeKey
fromNode := sc.GetNode(fromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
}
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
}
}
}
} else if len(path) == 1 { // asking about the entire group
var streamCount, notStreamCount int
for _, fInfo := range ns.InputSources {
if fInfo.Path[0] == path[0] { // belong to the group
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
if fromNode == nil {
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
}
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
if err != nil {
return schema2.FieldNotStream, err
}
if subStreamType == schema2.FieldMaybeStream {
return schema2.FieldMaybeStream, nil
} else if subStreamType == schema2.FieldIsStream {
streamCount++
} else {
notStreamCount++
}
}
}
}
if streamCount > 0 && notStreamCount == 0 {
return schema2.FieldIsStream, nil
}
if streamCount == 0 && notStreamCount > 0 {
return schema2.FieldNotStream, nil
}
return schema2.FieldMaybeStream, nil
}
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
}
// VariableAggregator selects, per group, the first non-null value among the
// group's ordered inputs (see Invoke/Transform).
type VariableAggregator struct {
	groupLen    map[string]int                 // group name -> number of variables in the group
	fullSources map[string]*schema2.SourceInfo // per-group source info, resolved in Init
	nodeKey     vo.NodeKey                     // this node's key, used when saving dynamic choices to state
	groupOrder  []string                       // the order the groups are declared in frontend canvas
	// NOTE(review): groupOrder does not appear to be populated by the visible
	// Build — verify it is set before Invoke/Transform iterate it.
}
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
result := make(map[string]any)
groupToChoice := make(map[string]int)
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
for i := 0; i < length; i++ {
if value, ok := in[group][i]; ok {
if value != nil {
@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
_ *schema.StreamReader[map[string]any], err error) {
inStream := streamInputConverter(input)
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolvesSources from ctx cache.")
}
@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
defer func() {
if err == nil {
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
for group, choice := range groupToChoice {
if choice != -1 {
item := groupToItems[group][choice]
if _, ok := item.(stream); ok {
groupChoiceToStreamType[group] = nodes.FieldIsStream
groupChoiceToStreamType[group] = schema2.FieldIsStream
}
}
}
groupChoices := make([]any, 0, len(v.config.GroupOrder))
for _, group := range v.config.GroupOrder {
groupChoices := make([]any, 0, len(v.groupOrder))
for _, group := range v.groupOrder {
choice := groupToChoice[group]
if choice == -1 {
groupChoices = append(groupChoices, nil)
@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
// - if an element is not stream, actually receive from the stream to check if it's non-nil
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
for group, length := range v.config.GroupLen {
for group, length := range v.groupLen {
groupToItems[group] = make([]any, length)
groupToCurrentIndex[group] = math.MaxInt
for i := 0; i < length; i++ {
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
if fType == nodes.FieldSkipped {
if fType == schema2.FieldSkipped {
groupToItems[group][i] = skipped{}
continue
}
if fType == nodes.FieldIsStream {
if fType == schema2.FieldIsStream {
groupToItems[group][i] = stream{}
if ci, _ := groupToCurrentIndex[group]; i < ci {
groupToCurrentIndex[group] = i
@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
allDone := func() bool {
for group := range v.config.GroupLen {
for group := range v.groupLen {
_, ok := groupToChoice[group]
if !ok {
return false
@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
alreadyDone := allDone()
if alreadyDone { // all groups have made their choices, no need to actually read input streams
result := make(map[string]any, len(v.config.GroupLen))
result := make(map[string]any, len(v.groupLen))
allSkip := true
for group := range groupToChoice {
choice := groupToChoice[group]
@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
if allSkip { // no need to convert input streams for the output, because all groups are skipped
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
return schema.StreamReaderFromArray([]map[string]any{result}), nil
@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
}
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
return nil
})
@ -416,26 +542,12 @@ type vaCallbackInput struct {
Variables []any `json:"variables"`
}
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
ctx = ctxcache.Init(ctx)
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
if err != nil {
return nil, err
}
// need this info for callbacks.OnStart, so we put it in cache within Init()
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
return ctx, nil
}
// streamMarkerType is a distinct string type so the marker can be recognized
// (and specially concatenated) when processing callback chunks.
type streamMarkerType string

// streamMarker replaces stream values in callback input/output snapshots,
// since the streams themselves are not read, persisted, or shown to users.
const streamMarker streamMarkerType = "<Stream Data...>"
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
if !ok {
panic("unable to get resolved_sources from ctx cache")
}
@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
merged := make([]vaCallbackInput, 0, len(in))
groupLen := v.config.GroupLen
groupLen := v.groupLen
for groupName, vars := range in {
orderedVars := make([]any, groupLen[groupName])
for index := range vars {
orderedVars[index] = vars[index]
if len(resolvedSources) > 0 {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
// replace the streams with streamMarker,
// because we won't read, save to execution history, or display these streams to user
orderedVars[index] = streamMarker
@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
}
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
if !ok {
panic("unable to get dynamic stream types from ctx cache")
}
@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
newOut := maps.Clone(output)
for k := range output {
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
newOut[k] = streamMarker
}
}
@ -594,3 +706,15 @@ func init() {
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
}
// Init resolves the node's declared sources into concrete stream types and
// caches the result on ctx for later retrieval by Transform/ToCallbackInput.
// NOTE(review): this version no longer calls ctxcache.Init(ctx) before
// storing — confirm the ctx cache is initialized upstream, otherwise the
// Store below may have no effect.
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
	resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
	if err != nil {
		return nil, err
	}

	// need this info for callbacks.OnStart, so we put it in cache within Init()
	ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
	return ctx, nil
}

View File

@ -25,29 +25,75 @@ import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type VariableAssigner struct {
config *Config
pairs []*Pair
handler *variable.Handler
}
type Config struct {
Pairs []*Pair
Handler *variable.Handler
Pairs []*Pair
}
type Pair struct {
Left vo.Reference
Right compose.FieldPath
// Adapt converts the frontend canvas node n into a NodeSchema for the
// variable-assigner node. Each input parameter pair is validated (the
// left-hand side must be a variable reference, and not a read-only global
// system variable) and turned into a Pair mapping the destination variable
// to the input field path that supplies its value.
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAssigner,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}

	var pairs = make([]*Pair, 0, len(n.Data.Inputs.InputParameters))
	for i, param := range n.Data.Inputs.InputParameters {
		if param.Left == nil || param.Input == nil {
			return nil, fmt.Errorf("variable assigner node's param left or input is nil")
		}
		// the synthetic "left_%d" path is only used to resolve the left-hand
		// reference; it never appears in the node's real input mapping
		leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
		if err != nil {
			return nil, err
		}
		if leftSources[0].Source.Ref == nil {
			return nil, fmt.Errorf("variable assigner node's param left source ref is nil")
		}
		if leftSources[0].Source.Ref.VariableType == nil {
			return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil")
		}
		// system globals are read-only and may not be assignment targets
		if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem {
			return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem")
		}
		// the input value is mounted at the destination variable's own path
		inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent())
		if err != nil {
			return nil, err
		}
		ns.AddInputSource(inputSource...)
		pair := &Pair{
			Left:  *leftSources[0].Source.Ref,
			Right: inputSource[0].Path,
		}
		pairs = append(pairs, pair)
	}
	c.Pairs = pairs
	return ns, nil
}
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
for _, pair := range conf.Pairs {
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
for _, pair := range c.Pairs {
if pair.Left.VariableType == nil {
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
}
@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er
}
return &VariableAssigner{
config: conf,
pairs: c.Pairs,
handler: variable.GetVariableHandler(),
}, nil
}
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.config.Pairs {
// Pair binds an assignment destination (a variable reference) to the input
// field path whose runtime value will be written to it.
type Pair struct {
	Left  vo.Reference      // destination variable reference
	Right compose.FieldPath // source field path within the node's input
}
func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
for _, pair := range v.pairs {
right, ok := nodes.TakeMapValue(in, pair.Right)
if !ok {
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
ConnectorUID: exeCfg.ConnectorUID,
}))
}
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
}

View File

@ -20,25 +20,93 @@ import (
"context"
"fmt"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
)
type InLoop struct {
config *Config
intermediateVarStore variable.Store
// InLoopConfig configures the loop-scoped variable-assigner node.
type InLoopConfig struct {
	Pairs []*Pair // destination/source pairs, built by Adapt from canvas params
}
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
// Adapt converts the canvas node n into a NodeSchema for the loop-scoped
// variable assigner. The node must live inside a loop (have a parent), and
// every parameter pair must assign into the parent loop's intermediate
// variables (vo.ParentIntermediate).
//
// The receiver is named cfg (previously i) so the loop index below no longer
// shadows it — the old code relied on block scoping to pick the right `i`,
// which is fragile and inconsistent with the sibling configs' `c` receivers.
func (cfg *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
	if n.Parent() == nil {
		return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID)
	}

	ns := &schema.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeVariableAssignerWithinLoop,
		Name:    n.Data.Meta.Title,
		Configs: cfg,
	}

	var pairs []*Pair
	for i, param := range n.Data.Inputs.InputParameters {
		if param.Left == nil || param.Right == nil {
			return nil, fmt.Errorf("loop set variable node's param left or right is nil")
		}

		// the synthetic "left_%d" path is only used to resolve the left-hand
		// reference; it never appears in the node's real input mapping
		leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
		if err != nil {
			return nil, err
		}
		if len(leftSources) != 1 {
			return nil, fmt.Errorf("loop set variable node's param left is not a single source")
		}
		if leftSources[0].Source.Ref == nil {
			return nil, fmt.Errorf("loop set variable node's param left's ref is nil")
		}
		if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate {
			return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate")
		}

		// the right-hand value is mounted at the destination variable's path
		rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent())
		if err != nil {
			return nil, err
		}
		ns.AddInputSource(rightSources...)
		if len(rightSources) != 1 {
			return nil, fmt.Errorf("loop set variable node's param right is not a single source")
		}

		pairs = append(pairs, &Pair{
			Left:  *leftSources[0].Source.Ref,
			Right: rightSources[0].Path,
		})
	}
	cfg.Pairs = pairs

	return ns, nil
}
// Build assembles the runnable InLoop node from the adapted config. The
// intermediate variable store always resolves against the parent loop.
func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
	// removed stray `config: conf,`: `conf` is undefined here and InLoop
	// declares no `config` field, so the literal could not compile
	return &InLoop{
		pairs:                i.Pairs,
		intermediateVarStore: &nodes.ParentIntermediateStore{},
	}, nil
}
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.config.Pairs {
// InLoop writes values into the parent loop's intermediate variables when
// invoked.
type InLoop struct {
	pairs                []*Pair        // assignments to perform on Invoke
	intermediateVarStore variable.Store // store backing the parent loop's intermediate variables
}
func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
for _, pair := range v.pairs {
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
}

View File

@ -37,36 +37,34 @@ func TestVariableAssigner(t *testing.T) {
arrVar := any([]any{1, "2"})
va := &InLoop{
config: &Config{
Pairs: []*Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"int_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"int_var_t"},
pairs: []*Pair{
{
Left: vo.Reference{
FromPath: compose.FieldPath{"int_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"str_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"str_var_t"},
Right: compose.FieldPath{"int_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"str_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"obj_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"obj_var_t"},
Right: compose.FieldPath{"str_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"obj_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"arr_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"arr_var_t"},
Right: compose.FieldPath{"obj_var_t"},
},
{
Left: vo.Reference{
FromPath: compose.FieldPath{"arr_var_s"},
VariableType: ptr.Of(vo.ParentIntermediate),
},
Right: compose.FieldPath{"arr_var_t"},
},
},
intermediateVarStore: &nodes.ParentIntermediateStore{},
@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) {
"arr_var_s": &arrVar,
}, nil)
_, err := va.Assign(ctx, map[string]any{
_, err := va.Invoke(ctx, map[string]any{
"int_var_t": 2,
"str_var_t": "str2",
"obj_var_t": map[string]any{

View File

@ -0,0 +1,196 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// Port type constants classify a connection's outgoing port:
// PortDefault is the fall-through port, PortBranchError is the exception
// path, and PortBranchFormat ("branch_%d") names indexed condition branches.
const (
	PortDefault      = "default"
	PortBranchError  = "branch_error"
	PortBranchFormat = "branch_%d"
)
// BranchSchema defines the schema for workflow branches: for one source node
// it records which successor nodes each outgoing port leads to.
type BranchSchema struct {
	// From is the key of the node this branch originates from.
	From vo.NodeKey `json:"from_node"`
	// DefaultMapping holds successor node keys reached through the "default" port.
	DefaultMapping map[string]bool `json:"default_mapping,omitempty"`
	// ExceptionMapping holds successor node keys reached through the "branch_error" port.
	ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"`
	// Mappings maps a condition-branch index (the N in "branch_N") to that
	// branch's successor node keys.
	Mappings map[int64]map[string]bool `json:"mappings,omitempty"`
}
// BuildBranches builds branch schemas from connections.
// Only connections carrying a non-empty FromPort participate; port-less
// connections are skipped. Returns nil when no connection has a port.
// Recognized ports are "default", "branch_error" and "branch_N" (N >= 0);
// any other port yields an error.
func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
	var branchMap map[vo.NodeKey]*BranchSchema
	for _, conn := range connections {
		if conn.FromPort == nil || len(*conn.FromPort) == 0 {
			continue
		}

		port := *conn.FromPort
		sourceNodeKey := conn.FromNode

		// Allocate lazily so a workflow without ports returns a nil map.
		if branchMap == nil {
			branchMap = map[vo.NodeKey]*BranchSchema{}
		}

		// Get or create branch schema for source node
		branch, exists := branchMap[sourceNodeKey]
		if !exists {
			branch = &BranchSchema{
				From: sourceNodeKey,
			}
			branchMap[sourceNodeKey] = branch
		}

		// Classify port type and add to appropriate mapping
		switch {
		case port == PortDefault:
			if branch.DefaultMapping == nil {
				branch.DefaultMapping = map[string]bool{}
			}
			branch.DefaultMapping[string(conn.ToNode)] = true
		case port == PortBranchError:
			if branch.ExceptionMapping == nil {
				branch.ExceptionMapping = map[string]bool{}
			}
			branch.ExceptionMapping[string(conn.ToNode)] = true
		default:
			var branchNum int64
			_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
			// Sscanf alone accepts trailing garbage (e.g. "branch_1x" parses
			// as 1 without error), so round-trip through the format string to
			// reject any port that is not exactly "branch_<N>".
			if err != nil || branchNum < 0 || fmt.Sprintf(PortBranchFormat, branchNum) != port {
				return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
			}
			if branch.Mappings == nil {
				branch.Mappings = map[int64]map[string]bool{}
			}
			if _, exists := branch.Mappings[branchNum]; !exists {
				branch.Mappings[branchNum] = make(map[string]bool)
			}
			branch.Mappings[branchNum][string(conn.ToNode)] = true
		}
	}

	return branchMap, nil
}
// OnlyException reports whether this branch has an exception port and a
// default port but no indexed condition branches.
// NOTE(review): DefaultMapping is also required to be non-empty here —
// presumably an exception-only branch always carries a default path; confirm
// against the branch builders.
func (bs *BranchSchema) OnlyException() bool {
	return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0
}
// GetExceptionBranch builds a branch that routes to the exception successors
// when the node output carries isSuccess == false, and to the default
// successors otherwise.
func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
	choose := func(_ context.Context, out map[string]any) (map[string]bool, error) {
		if v, present := out["isSuccess"]; present && v != nil && !v.(bool) {
			return bs.ExceptionMapping, nil
		}
		return bs.DefaultMapping, nil
	}

	// The branch must declare every node it can possibly route to.
	destinations := make(map[string]bool, len(bs.ExceptionMapping)+len(bs.DefaultMapping))
	for _, mapping := range []map[string]bool{bs.ExceptionMapping, bs.DefaultMapping} {
		for key := range mapping {
			destinations[key] = true
		}
	}
	return compose.NewGraphMultiBranch(choose, destinations)
}
// GetFullBranch builds the runtime branch for this schema, using the node's
// BranchBuilder-provided extractor to map node output to a branch index.
// Two shapes are handled here:
//   - no exception port: route purely by extracted index (or default);
//   - exception port plus index mappings: isSuccess == false takes precedence
//     over index extraction.
func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
	extractor, hasBranch := bb.BuildBranch(ctx)
	if !hasBranch {
		return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
	}

	if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch
		condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
			index, isDefault, err := extractor(ctx, in)
			if err != nil {
				return nil, err
			}
			if isDefault {
				return bs.DefaultMapping, nil
			}
			if _, ok := bs.Mappings[index]; !ok {
				return nil, fmt.Errorf("chosen index= %d, out of range", index)
			}
			return bs.Mappings[index], nil
		}

		// Combine DefaultMapping and normal mappings into a new map
		endNodes := make(map[string]bool)
		for node := range bs.DefaultMapping {
			endNodes[node] = true
		}
		for _, ms := range bs.Mappings {
			for node := range ms {
				endNodes[node] = true
			}
		}
		return compose.NewGraphMultiBranch(condition, endNodes), nil
	}

	condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
		// Exception takes precedence: a failed node routes to the exception
		// successors before any index is extracted.
		isSuccess, ok := in["isSuccess"]
		if ok && isSuccess != nil && !isSuccess.(bool) {
			return bs.ExceptionMapping, nil
		}
		index, isDefault, err := extractor(ctx, in)
		if err != nil {
			return nil, err
		}
		if isDefault {
			return bs.DefaultMapping, nil
		}
		// Previously a missing index silently returned a nil mapping here,
		// unlike the non-exception path above; report it as an error so the
		// two paths behave consistently.
		if _, ok := bs.Mappings[index]; !ok {
			return nil, fmt.Errorf("chosen index= %d, out of range", index)
		}
		return bs.Mappings[index], nil
	}

	// Combine ALL mappings into a new map
	endNodes := make(map[string]bool)
	for node := range bs.ExceptionMapping {
		endNodes[node] = true
	}
	for node := range bs.DefaultMapping {
		endNodes[node] = true
	}
	for _, ms := range bs.Mappings {
		for node := range ms {
			endNodes[node] = true
		}
	}
	return compose.NewGraphMultiBranch(condition, endNodes), nil
}

View File

@ -0,0 +1,73 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"context"
"github.com/cloudwego/eino/compose"
)
// BuildOptions carries the optional inputs available when building a node.
type BuildOptions struct {
	// WS is the schema of the workflow the node belongs to.
	WS *WorkflowSchema
	// Inner is a compiled inner workflow.
	// NOTE(review): inferred from WithInnerWorkflow's name that this is the
	// wrapped workflow of a composite node — confirm with callers.
	Inner compose.Runnable[map[string]any, map[string]any]
}
// GetBuildOptions applies the given options to a fresh BuildOptions and
// returns it.
func GetBuildOptions(opts ...BuildOption) *BuildOptions {
	options := new(BuildOptions)
	for _, apply := range opts {
		apply(options)
	}
	return options
}
type BuildOption func(options *BuildOptions)
// WithWorkflowSchema sets the workflow schema the node being built belongs to.
func WithWorkflowSchema(ws *WorkflowSchema) BuildOption {
	return func(o *BuildOptions) {
		o.WS = ws
	}
}
// WithInnerWorkflow sets the compiled inner workflow for the node being built.
func WithInnerWorkflow(inner compose.Runnable[map[string]any, map[string]any]) BuildOption {
	return func(o *BuildOptions) {
		o.Inner = inner
	}
}
// NodeBuilder takes a NodeSchema and several BuildOption to build an executable node instance.
// The result 'executable' MUST implement at least one of the execute interfaces:
// - nodes.InvokableNode
// - nodes.StreamableNode
// - nodes.CollectableNode
// - nodes.TransformableNode
// - nodes.InvokableNodeWOpt
// - nodes.StreamableNodeWOpt
// - nodes.CollectableNodeWOpt
// - nodes.TransformableNodeWOpt
// NOTE: the 'normal' version does not take NodeOption, while the 'WOpt' versions take NodeOption.
// NOTE: a node should either implement the 'normal' versions, or the 'WOpt' versions, not mix them up.
type NodeBuilder interface {
	// Build instantiates the executable node described by ns.
	Build(ctx context.Context, ns *NodeSchema, opts ...BuildOption) (
		executable any, err error)
}
// BranchBuilder builds the extractor function that maps node output to port index.
// hasBranch is false when the node type does not branch at all.
type BranchBuilder interface {
	BuildBranch(ctx context.Context) (extractor func(ctx context.Context,
		nodeOutput map[string]any) (int64, bool /*if is default branch*/, error), hasBranch bool)
}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// NodeSchema is the universal description and configuration for a workflow Node.
// It should contain EVERYTHING a node needs to instantiate.
type NodeSchema struct {
	// Key is the node key within the Eino graph.
	// A node may need this information during execution,
	// e.g.
	// - using this Key to query workflow State for data belonging to current node.
	Key vo.NodeKey `json:"key"`

	// Name is the name for this node as specified on Canvas.
	// A node may show this name on Canvas as part of this node's input/output.
	Name string `json:"name"`

	// Type is the NodeType for the node.
	Type entity.NodeType `json:"type"`

	// Configs are node specific configurations, with actual struct type defined by each Node Type.
	// Will not hold information relating to field mappings, nor as node's static values.
	// In a word, these Configs are INTERNAL to node's implementation, NOT related to workflow orchestration.
	// Actual type of these Configs should implement two interfaces:
	// - NodeAdaptor: to provide conversion from vo.Node to NodeSchema
	// - NodeBuilder: to provide instantiation from NodeSchema to actual node instance.
	Configs any `json:"configs,omitempty"`

	// InputTypes are type information about the node's input fields.
	InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
	// InputSources are field mapping information about the node's input fields.
	InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
	// OutputTypes are type information about the node's output fields.
	OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
	// OutputSources are field mapping information about the node's output fields.
	// NOTE: only applicable to composite nodes such as NodeTypeBatch or NodeTypeLoop.
	OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"`

	// ExceptionConfigs are about exception handling strategy of the node.
	ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"`
	// StreamConfigs are streaming characteristics of the node.
	StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`

	// SubWorkflowBasic is basic information of the sub workflow if this node is NodeTypeSubWorkflow.
	SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
	// SubWorkflowSchema is WorkflowSchema of the sub workflow if this node is NodeTypeSubWorkflow.
	SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`

	// FullSources contains more complete information about a node's input fields' mapping sources,
	// such as whether a field's source is a 'streaming field',
	// or whether the field is an object that contains sub-fields with real mappings.
	// Used for those nodes that need to process streaming input.
	// Set InputSourceAware = true in NodeMeta to enable.
	FullSources map[string]*SourceInfo

	// Lambda directly sets the node to be an Eino Lambda.
	// NOTE: not serializable, used ONLY for internal test.
	Lambda *compose.Lambda
}
// RequireCheckpoint is implemented by node Configs to declare whether the
// node needs workflow state checkpointing; WorkflowSchema.Init consults it
// when deciding whether the whole workflow requires a checkpoint store.
type RequireCheckpoint interface {
	RequireCheckpoint() bool
}
// ExceptionConfig is the exception handling strategy of a node.
type ExceptionConfig struct {
	TimeoutMS   int64                `json:"timeout_ms,omitempty"`   // timeout in milliseconds, 0 means no timeout
	MaxRetry    int64                `json:"max_retry,omitempty"`    // max retry times, 0 means no retry
	ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
	DataOnErr   string               `json:"data_on_err,omitempty"`  // data to return on error, effective when ProcessType==Default
}
// StreamConfig describes the streaming characteristics of a node.
type StreamConfig struct {
	// CanGeneratesStream is whether this node has the ability to produce genuine streaming output.
	// Does not include nodes that only pass a stream down as they receive it.
	CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
	// RequireStreamingInput is whether this node prioritizes streaming input over non-streaming input.
	// Does not include nodes that can accept both and have no preference.
	// NOTE(review): the json tag "can_process_stream" does not match the field
	// name; keep the tag as-is — changing it would break serialized schemas.
	RequireStreamingInput bool `json:"can_process_stream,omitempty"`
}
// SetConfigKV sets a single key-value pair on the node's Configs, lazily
// allocating a map[string]any when Configs is nil. It must only be used
// while Configs is nil or already a map[string]any.
func (s *NodeSchema) SetConfigKV(key string, value any) {
	if s.Configs == nil {
		s.Configs = make(map[string]any)
	}
	m, ok := s.Configs.(map[string]any)
	if !ok {
		// The previous unchecked assertion produced a cryptic runtime error;
		// fail with a message that names the misuse instead.
		panic("NodeSchema.SetConfigKV: Configs is not a map[string]any")
	}
	m[key] = value
}
// SetInputType records the type information for input field key, lazily
// allocating the InputTypes map on first use.
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
	if s.InputTypes == nil {
		s.InputTypes = map[string]*vo.TypeInfo{}
	}
	s.InputTypes[key] = t
}
// AddInputSource appends field mapping information for the node's input fields.
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
	s.InputSources = append(s.InputSources, info...)
}
// SetOutputType records the type information for output field key, lazily
// allocating the OutputTypes map on first use.
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
	if s.OutputTypes == nil {
		s.OutputTypes = map[string]*vo.TypeInfo{}
	}
	s.OutputTypes[key] = t
}
// AddOutputSource appends field mapping information for the node's output fields.
// NOTE: output sources only apply to composite nodes (see OutputSources).
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
	s.OutputSources = append(s.OutputSources, info...)
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package schema
import (
"github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
)
// FieldStreamType classifies whether a field source delivers its value as a stream.
type FieldStreamType string

const (
	FieldIsStream    FieldStreamType = "yes"     // absolutely a stream
	FieldNotStream   FieldStreamType = "no"      // absolutely not a stream
	FieldMaybeStream FieldStreamType = "maybe"   // maybe a stream, requires request-time resolution
	FieldSkipped     FieldStreamType = "skipped" // the field source's node is skipped
)
type FieldSkipStatus string
// SourceInfo contains stream type for an input field source of a node.
type SourceInfo struct {
	// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
	IsIntermediate bool
	// FieldType is the stream type of the field. May require request-time resolution in addition to compile-time.
	FieldType FieldStreamType
	// FromNodeKey is the node key that produces this field source. Empty if the field is a static value or variable.
	FromNodeKey vo.NodeKey
	// FromPath is the path of this field source within the source node. Empty if the field is a static value or variable.
	FromPath compose.FieldPath
	// TypeInfo is the type information of this field source.
	TypeInfo *vo.TypeInfo
	// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
	SubSources map[string]*SourceInfo
}
// Skipped reports whether this source yields no value: a leaf source is
// skipped when its stream type is FieldSkipped; an intermediate source is
// skipped only when every sub-source is skipped (vacuously true when it has
// no sub-sources).
func (s *SourceInfo) Skipped() bool {
	if s.IsIntermediate {
		for _, child := range s.SubSources {
			if !child.Skipped() {
				return false
			}
		}
		return true
	}
	return s.FieldType == FieldSkipped
}
// FromNode reports whether this source (or, for an intermediate source, any
// of its sub-sources) originates from the node identified by nodeKey.
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
	if s.IsIntermediate {
		for _, child := range s.SubSources {
			if child.FromNode(nodeKey) {
				return true
			}
		}
		return false
	}
	return s.FromNodeKey == nodeKey
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
package compose
package schema
import (
"fmt"
@ -29,9 +29,10 @@ import (
)
type WorkflowSchema struct {
Nodes []*NodeSchema `json:"nodes"`
Connections []*Connection `json:"connections"`
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
Nodes []*NodeSchema `json:"nodes"`
Connections []*Connection `json:"connections"`
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
Branches map[vo.NodeKey]*BranchSchema `json:"branches,omitempty"`
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
@ -71,9 +72,19 @@ func (w *WorkflowSchema) Init() {
w.doGetCompositeNodes()
for _, node := range w.Nodes {
if node.requireCheckpoint() {
w.requireCheckPoint = true
break
if node.Type == entity.NodeTypeSubWorkflow {
node.SubWorkflowSchema.Init()
if node.SubWorkflowSchema.requireCheckPoint {
w.requireCheckPoint = true
break
}
}
if rc, ok := node.Configs.(RequireCheckpoint); ok {
if rc.RequireCheckpoint() {
w.requireCheckPoint = true
break
}
}
}
@ -97,6 +108,22 @@ func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
return w.compositeNodes
}
func (w *WorkflowSchema) GetBranch(key vo.NodeKey) *BranchSchema {
if w.Branches == nil {
return nil
}
return w.Branches[key]
}
func (w *WorkflowSchema) RequireCheckpoint() bool {
return w.requireCheckPoint
}
func (w *WorkflowSchema) RequireStreaming() bool {
return w.requireStreaming
}
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
if w.Hierarchy == nil {
return nil
@ -125,7 +152,7 @@ func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
return cNodes
}
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
func IsInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return true
}
@ -144,7 +171,7 @@ func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.Node
return myParents == theirParents
}
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
func IsBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return false
}
@ -154,7 +181,7 @@ func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeK
return myParentExists && !theirParentExists
}
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
func IsParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
if n == nil {
return false
}
@ -230,21 +257,14 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
consumers := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type)
if meta != nil {
sps := meta.ExecutableMeta.StreamingParadigms
if _, ok := sps[entity.Stream]; ok {
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
producers[node.Key] = true
}
}
if sps[entity.Transform] || sps[entity.Collect] {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
consumers[node.Key] = true
}
}
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
producers[node.Key] = true
}
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
consumers[node.Key] = true
}
}
if len(producers) == 0 || len(consumers) == 0 {
@ -290,7 +310,7 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
return false
}
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
func (w *WorkflowSchema) FanInMergeConfigs() map[string]compose.FanInMergeConfig {
// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
// then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
@ -301,21 +321,15 @@ func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig
fanInNodes := make(map[vo.NodeKey]bool)
for _, node := range w.Nodes {
meta := entity.NodeMetaByNodeType(node.Type)
if meta != nil {
sps := meta.ExecutableMeta.StreamingParadigms
if sps[entity.Transform] || sps[entity.Collect] {
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
var predecessor *vo.NodeKey
for _, source := range node.InputSources {
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if predecessor != nil {
fanInNodes[node.Key] = true
break
}
predecessor = &source.Source.Ref.FromNodeKey
}
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
var predecessor *vo.NodeKey
for _, source := range node.InputSources {
if source.Source.Ref != nil && len(source.Source.Ref.FromNodeKey) > 0 {
if predecessor != nil {
fanInNodes[node.Key] = true
break
}
predecessor = &source.Source.Ref.FromNodeKey
}
}
}

View File

@ -37,8 +37,11 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
@ -73,7 +76,7 @@ func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmd
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel)
}
func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
// Initialize result maps
nodeMetaMap := make(map[string][]*entity.NodeTypeMeta)
@ -82,7 +85,7 @@ func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]b
if meta.Disabled {
return false
}
nodeType := meta.Type
nodeType := meta.Key
if nodeTypes == nil || len(nodeTypes) == 0 {
return true // No filter, include all
}
@ -192,10 +195,10 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if startNode != nil && endNode != nil {
break
}
if node.Type == vo.BlockTypeBotStart {
if node.Type == entity.NodeTypeEntry.IDStr() {
startNode = node
}
if node.Type == vo.BlockTypeBotEnd {
if node.Type == entity.NodeTypeExit.IDStr() {
endNode = node
}
}
@ -207,7 +210,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if err != nil {
return nil, err
}
nInfo, err := adaptor.VariableToNamedTypeInfo(v)
nInfo, err := convert.VariableToNamedTypeInfo(v)
if err != nil {
return nil, err
}
@ -220,7 +223,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
if endNode != nil {
outputs, err = slices.TransformWithErrorCheck(endNode.Data.Inputs.InputParameters, func(a *vo.Param) (*vo.NamedTypeInfo, error) {
return adaptor.BlockInputToNamedTypeInfo(a.Name, a.Input)
return convert.BlockInputToNamedTypeInfo(a.Name, a.Input)
})
if err != nil {
logs.Warn(fmt.Sprintf("transform end node inputs to named info failed, err=%v", err))
@ -316,6 +319,34 @@ func (i *impl) GetWorkflowReference(ctx context.Context, id int64) (map[int64]*v
return ret, nil
}
// workflowIdentity identifies a workflow by ID and version, as referenced
// from a sub-workflow node on the canvas.
type workflowIdentity struct {
	ID      string `json:"id"`
	Version string `json:"version"`
}
// getAllSubWorkflowIdentities walks the canvas node tree (including nested
// Blocks) and collects the ID/version of every sub-workflow node.
func getAllSubWorkflowIdentities(c *vo.Canvas) []*workflowIdentity {
	found := make([]*workflowIdentity, 0)

	var walk func(ns []*vo.Node)
	walk = func(ns []*vo.Node) {
		for _, node := range ns {
			if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
				found = append(found, &workflowIdentity{
					ID:      node.Data.Inputs.WorkflowID,
					Version: node.Data.Inputs.WorkflowVersion,
				})
			}
			if len(node.Blocks) > 0 {
				walk(node.Blocks)
			}
		}
	}
	walk(c.Nodes)

	return found
}
func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.ValidateTreeConfig) ([]*cloudworkflow.ValidateTreeInfo, error) {
wfValidateInfos := make([]*cloudworkflow.ValidateTreeInfo, 0)
issues, err := validateWorkflowTree(ctx, validateConfig)
@ -337,7 +368,7 @@ func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.Val
fmt.Errorf("failed to unmarshal canvas schema: %w", err))
}
subWorkflowIdentities := c.GetAllSubWorkflowIdentities()
subWorkflowIdentities := getAllSubWorkflowIdentities(c)
if len(subWorkflowIdentities) > 0 {
var ids []int64
@ -421,25 +452,21 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
}
for _, n := range canvas.Nodes {
if n.Type == vo.BlockTypeBotSubWorkflow {
nodeSchema := &compose.NodeSchema{
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
nodeSchema := &schema.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeSubWorkflow,
Name: n.Data.Meta.Title,
}
err := adaptor.SetInputsForNodeSchema(n, nodeSchema)
if err != nil {
return nil, err
}
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
err := convert.SetInputsForNodeSchema(n, nodeSchema)
if err != nil {
return nil, err
}
prop := &vo.NodeProperty{
Type: string(blockType),
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(),
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(),
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(),
Type: nodeSchema.Type.IDStr(),
IsEnableUserQuery: isEnableUserQuery(nodeSchema),
IsEnableChatHistory: isEnableChatHistory(nodeSchema),
IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
}
nodePropertyMap[string(nodeSchema.Key)] = prop
wid, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
@ -478,20 +505,16 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
prop.SubWorkflow = ret
} else {
nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n)
nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n, canvas)
if err != nil {
return nil, err
}
for _, nodeSchema := range nodeSchemas {
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
if err != nil {
return nil, err
}
nodePropertyMap[string(nodeSchema.Key)] = &vo.NodeProperty{
Type: string(blockType),
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(),
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(),
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(),
Type: nodeSchema.Type.IDStr(),
IsEnableUserQuery: isEnableUserQuery(nodeSchema),
IsEnableChatHistory: isEnableChatHistory(nodeSchema),
IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
}
}
@ -500,6 +523,60 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
return nodePropertyMap, nil
}
// isEnableUserQuery reports whether the schema is an Entry node that exposes
// the user's query: i.e. it has an output source whose single-segment path is
// "BOT_USER_INPUT" or "USER_INPUT".
func isEnableUserQuery(s *schema.NodeSchema) bool {
	if s == nil || s.Type != entity.NodeTypeEntry {
		return false
	}
	if len(s.OutputSources) == 0 {
		return false
	}
	for _, src := range s.OutputSources {
		path := src.Path
		if len(path) != 1 {
			continue
		}
		if name := path[0]; name == "BOT_USER_INPUT" || name == "USER_INPUT" {
			return true
		}
	}
	return false
}
// isEnableChatHistory reports whether the node keeps conversation history.
// Only LLM and intent-detector nodes carry this flag, inside their LLM params.
func isEnableChatHistory(s *schema.NodeSchema) bool {
	if s == nil {
		return false
	}
	switch s.Type {
	case entity.NodeTypeLLM:
		// Guard the assertion: a nil or unexpected Configs previously
		// panicked here instead of reporting "no chat history".
		if c, ok := s.Configs.(*llm.Config); ok {
			return c.LLMParams.EnableChatHistory
		}
		return false
	case entity.NodeTypeIntentDetector:
		if c, ok := s.Configs.(*intentdetector.Config); ok {
			return c.LLMParams.EnableChatHistory
		}
		return false
	default:
		return false
	}
}
// isRefGlobalVariable reports whether any of the node's input or output field
// mappings references a global variable.
func isRefGlobalVariable(s *schema.NodeSchema) bool {
	// Guard a nil schema like the sibling helpers (isEnableUserQuery,
	// isEnableChatHistory) do; accessing s.InputSources below would
	// otherwise dereference a nil pointer.
	if s == nil {
		return false
	}
	for _, source := range s.InputSources {
		if source.IsRefGlobalVariable() {
			return true
		}
	}
	for _, source := range s.OutputSources {
		if source.IsRefGlobalVariable() {
			return true
		}
	}
	return false
}
func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) {
var canvas vo.Canvas
if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil {
@ -510,7 +587,7 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
var getRefFn func([]*vo.Node) error
getRefFn = func(nodes []*vo.Node) error {
for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
referredID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
if err != nil {
return vo.WrapError(errno.ErrSchemaConversionFail, err)
@ -521,19 +598,21 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
ReferType: vo.ReferTypeSubWorkflow,
ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{}
} else if node.Type == vo.BlockTypeBotLLM {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64)
if err != nil {
return vo.WrapError(errno.ErrSchemaConversionFail, err)
} else if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.LLM != nil {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64)
if err != nil {
return vo.WrapError(errno.ErrSchemaConversionFail, err)
}
wfRefs[entity.WorkflowReferenceKey{
ReferredID: referredID,
ReferringID: referringID,
ReferType: vo.ReferTypeTool,
ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{}
}
wfRefs[entity.WorkflowReferenceKey{
ReferredID: referredID,
ReferringID: referringID,
ReferType: vo.ReferTypeTool,
ReferringBizType: vo.ReferringBizTypeWorkflow,
}] = struct{}{}
}
}
} else if len(node.Blocks) > 0 {
@ -832,7 +911,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
validateAndBuildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
var (
v *vo.DraftInfo
wfID int64
@ -883,7 +962,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
}
if node.Type == vo.BlockTypeBotLLM {
if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var (
@ -1086,7 +1165,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
var buildWorkflowReference func(nodes []*vo.Node, wf *copiedWorkflow) error
buildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
for _, node := range nodes {
if node.Type == vo.BlockTypeBotSubWorkflow {
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
var (
v *vo.DraftInfo
wfID int64
@ -1121,7 +1200,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
}
}
if node.Type == vo.BlockTypeBotLLM {
if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var (
@ -1323,8 +1402,8 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
var collectDependence func(nodes []*vo.Node) error
collectDependence = func(nodes []*vo.Node) error {
for _, node := range nodes {
switch node.Type {
case vo.BlockTypeBotAPI:
switch entity.IDStrToNodeType(node.Type) {
case entity.NodeTypePlugin:
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
return e.Name, e
})
@ -1347,7 +1426,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
ds.PluginIDs = append(ds.PluginIDs, pID)
}
case vo.BlockTypeBotDatasetWrite, vo.BlockTypeBotDataset:
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
for _, id := range datasetIDs {
@ -1357,7 +1436,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
ds.KnowledgeIDs = append(ds.KnowledgeIDs, k)
}
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
dsList := node.Data.Inputs.DatabaseInfoList
if len(dsList) == 0 {
return fmt.Errorf("database info is requird")
@ -1369,7 +1448,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
ds.DatabaseIDs = append(ds.DatabaseIDs, dsID)
}
case vo.BlockTypeBotLLM:
case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
@ -1396,7 +1475,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
}
case vo.BlockTypeBotSubWorkflow:
case entity.NodeTypeSubWorkflow:
wfID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
if err != nil {
return err
@ -1567,8 +1646,8 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
)
for _, node := range nodes {
switch node.Type {
case vo.BlockTypeBotSubWorkflow:
switch entity.IDStrToNodeType(node.Type) {
case entity.NodeTypeSubWorkflow:
if !hasWorkflowRelated {
continue
}
@ -1580,7 +1659,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
node.Data.Inputs.WorkflowVersion = wf.Version
}
case vo.BlockTypeBotAPI:
case entity.NodeTypePlugin:
if !hasPluginRelated {
continue
}
@ -1623,7 +1702,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
apiIDParam.Input.Value.Content = strconv.FormatInt(refApiID, 10)
}
case vo.BlockTypeBotLLM:
case entity.NodeTypeLLM:
if hasWorkflowRelated && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
@ -1669,7 +1748,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
}
}
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
if !hasKnowledgeRelated {
continue
}
@ -1685,7 +1764,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
}
}
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
if !hasDatabaseRelated {
continue
}
@ -1713,3 +1792,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
}
return nil
}
// RegisterAllNodeAdaptors registers every workflow node adaptor by
// delegating to adaptor.RegisterAllNodeAdaptors. It is invoked during
// process/test initialization (see TestMain in the workflow tests) so that
// node adaptors are available before any workflow canvas is processed.
func RegisterAllNodeAdaptors() {
	adaptor.RegisterAllNodeAdaptors()
}

View File

@ -24,11 +24,9 @@ import (
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@ -199,158 +197,3 @@ func isIncremental(prev version, next version) bool {
return next.Patch > prev.Patch
}
// replaceRelatedWorkflowOrPluginInWorkflowNodes rewrites, in place, the
// workflow and plugin references held by the given canvas nodes:
//   - sub-workflow nodes: WorkflowID/WorkflowVersion are remapped through relatedWorkflows;
//   - plugin (API) nodes: the "pluginID"/"pluginVersion" params are remapped through relatedPlugins;
//   - LLM nodes: workflow and plugin entries of the function-call params are remapped.
//
// It recurses into node.Blocks so nested nodes are covered as well. Node
// types other than the three above are left untouched. An error is returned
// when a referenced ID cannot be parsed or an expected plugin parameter is
// missing or malformed.
func replaceRelatedWorkflowOrPluginInWorkflowNodes(nodes []*vo.Node, relatedWorkflows map[int64]entity.IDVersionPair, relatedPlugins map[int64]vo.PluginEntity) error {
	for _, node := range nodes {
		// node.Type holds a single value, so the former if/if/if chain is an
		// exclusive dispatch; a switch expresses that directly.
		switch node.Type {
		case vo.BlockTypeBotSubWorkflow:
			workflowID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
			if err != nil {
				return err
			}
			if wf, ok := relatedWorkflows[workflowID]; ok {
				node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
				node.Data.Inputs.WorkflowVersion = wf.Version
			}
		case vo.BlockTypeBotAPI:
			apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
				return e.Name, e
			})
			pluginIDParam, ok := apiParams["pluginID"]
			if !ok {
				return fmt.Errorf("plugin id param is not found")
			}
			// Comma-ok assertion: a non-string Content previously panicked here.
			idStr, ok := pluginIDParam.Input.Value.Content.(string)
			if !ok {
				return fmt.Errorf("plugin id param content is not a string")
			}
			pID, err := strconv.ParseInt(idStr, 10, 64)
			if err != nil {
				return err
			}
			pluginVersionParam, ok := apiParams["pluginVersion"]
			if !ok {
				return fmt.Errorf("plugin version param is not found")
			}
			if refPlugin, ok := relatedPlugins[pID]; ok {
				// NOTE(review): Content is assigned refPlugin.PluginID as-is here,
				// while the LLM branch below formats plugin IDs to strings —
				// confirm downstream consumers accept both representations.
				pluginIDParam.Input.Value.Content = refPlugin.PluginID
				if refPlugin.PluginVersion != nil {
					pluginVersionParam.Input.Value.Content = *refPlugin.PluginVersion
				}
			}
		case vo.BlockTypeBotLLM:
			if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
				for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
					wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
					workflowID, err := strconv.ParseInt(wf.WorkflowID, 10, 64)
					if err != nil {
						return err
					}
					if refWf, ok := relatedWorkflows[workflowID]; ok {
						wf.WorkflowID = strconv.FormatInt(refWf.ID, 10)
						wf.WorkflowVersion = refWf.Version
					}
				}
			}
			if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
				for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
					pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
					pluginID, err := strconv.ParseInt(pl.PluginID, 10, 64)
					if err != nil {
						return err
					}
					if refPlugin, ok := relatedPlugins[pluginID]; ok {
						pl.PluginID = strconv.FormatInt(refPlugin.PluginID, 10)
						if refPlugin.PluginVersion != nil {
							pl.PluginVersion = *refPlugin.PluginVersion
						}
					}
				}
			}
		}
		// Recurse into composite nodes (loop/batch bodies, etc.).
		if len(node.Blocks) > 0 {
			if err := replaceRelatedWorkflowOrPluginInWorkflowNodes(node.Blocks, relatedWorkflows, relatedPlugins); err != nil {
				return err
			}
		}
	}
	return nil
}
// entityNodeTypeToBlockType converts an entity.NodeType to the corresponding
// vo.BlockType. Unrecognized node types yield a wrapped
// errno.ErrSchemaConversionFail error.
//
// NOTE(review): the error text mentions "workflow.NodeTemplateType" while the
// function returns a vo.BlockType — confirm these name the same type family.
func entityNodeTypeToBlockType(nodeType entity.NodeType) (vo.BlockType, error) {
	// A lookup table keeps the one-to-one mapping compact and easy to scan.
	table := map[entity.NodeType]vo.BlockType{
		entity.NodeTypeEntry:                      vo.BlockTypeBotStart,
		entity.NodeTypeExit:                       vo.BlockTypeBotEnd,
		entity.NodeTypeLLM:                        vo.BlockTypeBotLLM,
		entity.NodeTypePlugin:                     vo.BlockTypeBotAPI,
		entity.NodeTypeCodeRunner:                 vo.BlockTypeBotCode,
		entity.NodeTypeKnowledgeRetriever:         vo.BlockTypeBotDataset,
		entity.NodeTypeSelector:                   vo.BlockTypeCondition,
		entity.NodeTypeSubWorkflow:                vo.BlockTypeBotSubWorkflow,
		entity.NodeTypeDatabaseCustomSQL:          vo.BlockTypeDatabase,
		entity.NodeTypeOutputEmitter:              vo.BlockTypeBotMessage,
		entity.NodeTypeTextProcessor:              vo.BlockTypeBotText,
		entity.NodeTypeQuestionAnswer:             vo.BlockTypeQuestion,
		entity.NodeTypeBreak:                      vo.BlockTypeBotBreak,
		entity.NodeTypeVariableAssigner:           vo.BlockTypeBotAssignVariable,
		entity.NodeTypeVariableAssignerWithinLoop: vo.BlockTypeBotLoopSetVariable,
		entity.NodeTypeLoop:                       vo.BlockTypeBotLoop,
		entity.NodeTypeIntentDetector:             vo.BlockTypeBotIntent,
		entity.NodeTypeKnowledgeIndexer:           vo.BlockTypeBotDatasetWrite,
		entity.NodeTypeBatch:                      vo.BlockTypeBotBatch,
		entity.NodeTypeContinue:                   vo.BlockTypeBotContinue,
		entity.NodeTypeInputReceiver:              vo.BlockTypeBotInput,
		entity.NodeTypeDatabaseUpdate:             vo.BlockTypeDatabaseUpdate,
		entity.NodeTypeDatabaseQuery:              vo.BlockTypeDatabaseSelect,
		entity.NodeTypeDatabaseDelete:             vo.BlockTypeDatabaseDelete,
		entity.NodeTypeHTTPRequester:              vo.BlockTypeBotHttp,
		entity.NodeTypeDatabaseInsert:             vo.BlockTypeDatabaseInsert,
		entity.NodeTypeVariableAggregator:         vo.BlockTypeBotVariableMerge,
		entity.NodeTypeJsonSerialization:          vo.BlockTypeJsonSerialization,
		entity.NodeTypeJsonDeserialization:        vo.BlockTypeJsonDeserialization,
		entity.NodeTypeKnowledgeDeleter:           vo.BlockTypeBotDatasetDelete,
	}
	if blockType, ok := table[nodeType]; ok {
		return blockType, nil
	}
	return "", vo.WrapError(errno.ErrSchemaConversionFail,
		fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType))
}

View File

@ -82,7 +82,7 @@ func init() {
code.Register(
ErrMissingRequiredParam,
"Missing required parameters {param}. Please review the API documentation and ensure all mandatory fields are included in your request.",
"Missing required parameters{param}. Please review the API documentation and ensure all mandatory fields are included in your request.",
code.WithAffectStability(false),
)