refactor: how to add a node type in workflow (#558)
This commit is contained in:
parent
5dafd81a3f
commit
bb6ff0026b
|
|
@ -105,6 +105,7 @@ import (
|
|||
|
||||
func TestMain(m *testing.M) {
|
||||
callbacks.AppendGlobalHandlers(service.GetTokenCallbackHandler())
|
||||
service.RegisterAllNodeAdaptors()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
|
|
@ -123,6 +124,7 @@ type wfTestRunner struct {
|
|||
knowledge *knowledgemock.MockKnowledgeOperator
|
||||
database *databasemock.MockDatabaseOperator
|
||||
pluginSrv *pluginmock.MockService
|
||||
internalModel *testutil.UTChatModel
|
||||
ctx context.Context
|
||||
closeFn func()
|
||||
}
|
||||
|
|
@ -243,9 +245,11 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
|||
|
||||
cpStore := checkpoint.NewRedisStore(redisClient)
|
||||
|
||||
utChatModel := &testutil.UTChatModel{}
|
||||
|
||||
mockTos := storageMock.NewMockStorage(ctrl)
|
||||
mockTos.EXPECT().GetObjectUrl(gomock.Any(), gomock.Any(), gomock.Any()).Return("", nil).AnyTimes()
|
||||
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, nil)
|
||||
workflowRepo := service.NewWorkflowRepository(mockIDGen, db, redisClient, mockTos, cpStore, utChatModel)
|
||||
mockey.Mock(appworkflow.GetWorkflowDomainSVC).Return(service.NewWorkflowService(workflowRepo)).Build()
|
||||
mockey.Mock(workflow2.GetRepository).Return(workflowRepo).Build()
|
||||
|
||||
|
|
@ -325,6 +329,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner {
|
|||
tos: mockTos,
|
||||
knowledge: mockKwOperator,
|
||||
database: mockDatabaseOperator,
|
||||
internalModel: utChatModel,
|
||||
ctx: context.Background(),
|
||||
closeFn: f,
|
||||
pluginSrv: mockPluginSrv,
|
||||
|
|
@ -1110,7 +1115,8 @@ func TestValidateTree(t *testing.T) {
|
|||
assert.Equal(t, i.Message, `node "代码_1" not connected`)
|
||||
}
|
||||
if i.NodeError.NodeID == "160892" {
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`)
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1157,7 +1163,8 @@ func TestValidateTree(t *testing.T) {
|
|||
assert.Equal(t, i.Message, `node "代码_1" not connected`)
|
||||
}
|
||||
if i.NodeError.NodeID == "160892" {
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`, `node "意图识别"'s port "default" not connected;`)
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "branch_1" not connected`)
|
||||
assert.Contains(t, i.Message, `node "意图识别"'s port "default" not connected`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2950,8 +2957,8 @@ func TestLLMWithSkills(t *testing.T) {
|
|||
r := newWfTestRunner(t)
|
||||
defer r.closeFn()
|
||||
|
||||
utChatModel := &testutil.UTChatModel{
|
||||
InvokeResultProvider: func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
utChatModel := r.internalModel
|
||||
utChatModel.InvokeResultProvider = func(index int, in []*schema.Message) (*schema.Message, error) {
|
||||
if index == 0 {
|
||||
assert.Equal(t, 1, len(in))
|
||||
assert.Contains(t, in[0].Content, "7512369185624686592", "你是一个知识库意图识别AI Agent", "北京有哪些著名的景点")
|
||||
|
|
@ -2983,8 +2990,8 @@ func TestLLMWithSkills(t *testing.T) {
|
|||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected index: %d", index)
|
||||
},
|
||||
}
|
||||
|
||||
r.modelManage.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(utChatModel, nil, nil).AnyTimes()
|
||||
|
||||
r.knowledge.EXPECT().ListKnowledgeDetail(gomock.Any(), gomock.Any()).Return(&knowledge.ListKnowledgeDetailResponse{
|
||||
|
|
@ -3767,10 +3774,10 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
|
|||
var validateCv func(ns []*vo.Node)
|
||||
validateCv = func(ns []*vo.Node) {
|
||||
for _, n := range ns {
|
||||
if n.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
assert.Equal(t, n.Data.Inputs.WorkflowVersion, version)
|
||||
}
|
||||
if n.Type == vo.BlockTypeBotAPI {
|
||||
if n.Type == entity.NodeTypePlugin.IDStr() {
|
||||
for _, apiParam := range n.Data.Inputs.APIParams {
|
||||
// In the application, the workflow plugin node When the plugin version is equal to 0, the plugin is a plugin created in the application
|
||||
if apiParam.Name == "pluginVersion" {
|
||||
|
|
@ -3779,7 +3786,7 @@ func TestReleaseApplicationWorkflows(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
if n.Type == vo.BlockTypeBotLLM {
|
||||
if n.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if n.Data.Inputs.FCParam != nil && n.Data.Inputs.FCParam.PluginFCParam != nil {
|
||||
// In the application, the workflow llm node When the plugin version is equal to 0, the plugin is a plugin created in the application
|
||||
for _, p := range n.Data.Inputs.FCParam.PluginFCParam.PluginList {
|
||||
|
|
@ -4063,8 +4070,8 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
var validateSubWorkflowIDs func(nodes []*vo.Node)
|
||||
validateSubWorkflowIDs = func(nodes []*vo.Node) {
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotAPI:
|
||||
switch entity.IDStrToNodeType(node.Type) {
|
||||
case entity.NodeTypePlugin:
|
||||
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
|
||||
return e.Name, e
|
||||
})
|
||||
|
|
@ -4082,7 +4089,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
assert.Equal(t, "100100", pID)
|
||||
}
|
||||
|
||||
case vo.BlockTypeBotSubWorkflow:
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
|
||||
wfId, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
|
||||
assert.NoError(t, err)
|
||||
|
|
@ -4096,7 +4103,7 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
err = sonic.UnmarshalString(subWf.Canvas, subworkflowCanvas)
|
||||
assert.NoError(t, err)
|
||||
validateSubWorkflowIDs(subworkflowCanvas.Nodes)
|
||||
case vo.BlockTypeBotLLM:
|
||||
case entity.NodeTypeLLM:
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
assert.True(t, copiedIDMap[w.WorkflowID])
|
||||
|
|
@ -4116,13 +4123,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
assert.Equal(t, "100100", k.ID)
|
||||
}
|
||||
}
|
||||
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
|
||||
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
|
||||
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
|
||||
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
for idx := range knowledgeIDs {
|
||||
assert.Equal(t, "100100", knowledgeIDs[idx].(string))
|
||||
}
|
||||
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
|
||||
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
|
||||
for _, d := range node.Data.Inputs.DatabaseInfoList {
|
||||
assert.Equal(t, "100100", d.DatabaseInfoID)
|
||||
}
|
||||
|
|
@ -4208,10 +4215,10 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
var validateSubWorkflowIDs func(nodes []*vo.Node)
|
||||
validateSubWorkflowIDs = func(nodes []*vo.Node) {
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotSubWorkflow:
|
||||
switch entity.IDStrToNodeType(node.Type) {
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
|
||||
case vo.BlockTypeBotLLM:
|
||||
case entity.NodeTypeLLM:
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
assert.True(t, copiedIDMap[w.WorkflowID])
|
||||
|
|
@ -4229,13 +4236,13 @@ func TestCopyWorkflowAppToLibrary(t *testing.T) {
|
|||
assert.Equal(t, "100100", k.ID)
|
||||
}
|
||||
}
|
||||
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
|
||||
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
|
||||
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
|
||||
knowledgeIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
for idx := range knowledgeIDs {
|
||||
assert.Equal(t, "100100", knowledgeIDs[idx].(string))
|
||||
}
|
||||
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
|
||||
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
|
||||
for _, d := range node.Data.Inputs.DatabaseInfoList {
|
||||
assert.Equal(t, "100100", d.DatabaseInfoID)
|
||||
}
|
||||
|
|
@ -4356,7 +4363,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
|||
err = sonic.Unmarshal(data, mainCanvas)
|
||||
assert.NoError(t, err)
|
||||
for _, node := range mainCanvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
if node.Data.Inputs.WorkflowID == "7516826260387921920" {
|
||||
node.Data.Inputs.WorkflowID = c1IdStr
|
||||
}
|
||||
|
|
@ -4372,7 +4379,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
|||
err = sonic.Unmarshal(cc1Data, cc1Canvas)
|
||||
assert.NoError(t, err)
|
||||
for _, node := range cc1Canvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
if node.Data.Inputs.WorkflowID == "7516826283318181888" {
|
||||
node.Data.Inputs.WorkflowID = c2IdStr
|
||||
}
|
||||
|
|
@ -4423,7 +4430,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
for _, node := range newMainCanvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
|
||||
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
|
||||
}
|
||||
|
|
@ -4437,7 +4444,7 @@ func TestMoveWorkflowAppToLibrary(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
for _, node := range cc1Canvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
assert.True(t, newSubWorkflowID[node.Data.Inputs.WorkflowID])
|
||||
assert.Equal(t, "v0.0.1", node.Data.Inputs.WorkflowVersion)
|
||||
}
|
||||
|
|
@ -4508,10 +4515,10 @@ func TestDuplicateWorkflowsByAppID(t *testing.T) {
|
|||
var validateSubWorkflowIDs func(nodes []*vo.Node)
|
||||
validateSubWorkflowIDs = func(nodes []*vo.Node) {
|
||||
for _, node := range nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
assert.True(t, copiedIDMap[node.Data.Inputs.WorkflowID])
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotLLM {
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
assert.True(t, copiedIDMap[w.WorkflowID])
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/coze-dev/coze-studio/backend/application/internal"
|
||||
|
||||
wfdatabase "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/database"
|
||||
wfknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/knowledge"
|
||||
wfmodel "github.com/coze-dev/coze-studio/backend/crossdomain/workflow/model"
|
||||
|
|
@ -78,6 +79,9 @@ func InitService(ctx context.Context, components *ServiceComponents) (*Applicati
|
|||
if !ok {
|
||||
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
service.RegisterAllNodeAdaptors()
|
||||
|
||||
workflowRepo := service.NewWorkflowRepository(components.IDGen, components.DB, components.Cache,
|
||||
components.Tos, components.CPStore, bcm)
|
||||
workflow.SetRepository(workflowRepo)
|
||||
|
|
|
|||
|
|
@ -93,8 +93,8 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
|
|||
|
||||
toQueryTypes := make(map[entity.NodeType]bool)
|
||||
for _, t := range req.NodeTypes {
|
||||
entityType, err := nodeType2EntityNodeType(t)
|
||||
if err != nil {
|
||||
entityType := entity.IDStrToNodeType(t)
|
||||
if len(entityType) == 0 {
|
||||
logs.Warnf("get node type %v failed, err:=%v", t, err)
|
||||
continue
|
||||
}
|
||||
|
|
@ -116,23 +116,19 @@ func (w *ApplicationService) GetNodeTemplateList(ctx context.Context, req *workf
|
|||
Name: category,
|
||||
}
|
||||
for _, nodeMeta := range nodeMetaList {
|
||||
tplType, err := entityNodeTypeToAPINodeTemplateType(nodeMeta.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tpl := &workflow.NodeTemplate{
|
||||
ID: fmt.Sprintf("%d", nodeMeta.ID),
|
||||
Type: tplType,
|
||||
Type: workflow.NodeTemplateType(nodeMeta.ID),
|
||||
Name: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSName, nodeMeta.Name),
|
||||
Desc: ternary.IFElse(i18n.GetLocale(ctx) == i18n.LocaleEN, nodeMeta.EnUSDescription, nodeMeta.Desc),
|
||||
IconURL: nodeMeta.IconURL,
|
||||
SupportBatch: ternary.IFElse(nodeMeta.SupportBatch, workflow.SupportBatch_SUPPORT, workflow.SupportBatch_NOT_SUPPORT),
|
||||
NodeType: fmt.Sprintf("%d", tplType),
|
||||
NodeType: fmt.Sprintf("%d", nodeMeta.ID),
|
||||
Color: nodeMeta.Color,
|
||||
}
|
||||
|
||||
resp.Data.TemplateList = append(resp.Data.TemplateList, tpl)
|
||||
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", tplType))
|
||||
categoryMap[category].NodeTypeList = append(categoryMap[category].NodeTypeList, fmt.Sprintf("%d", nodeMeta.ID))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -178,7 +174,7 @@ func (w *ApplicationService) CreateWorkflow(ctx context.Context, req *workflow.C
|
|||
IconURI: req.IconURI,
|
||||
AppID: parseInt64(req.ProjectID),
|
||||
Mode: ternary.IFElse(req.IsSetFlowMode(), req.GetFlowMode(), workflow.WorkflowMode_Workflow),
|
||||
InitCanvasSchema: entity.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)),
|
||||
InitCanvasSchema: vo.GetDefaultInitCanvasJsonSchema(i18n.GetLocale(ctx)),
|
||||
}
|
||||
|
||||
id, err := GetWorkflowDomainSVC().Create(ctx, wf)
|
||||
|
|
@ -1041,7 +1037,8 @@ func (w *ApplicationService) CopyWorkflowFromLibraryToApp(ctx context.Context, w
|
|||
return wf.ID, nil
|
||||
}
|
||||
|
||||
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
|
||||
func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, workflowID int64, spaceID, /*not used for now*/
|
||||
appID int64) (_ int64, _ []*vo.ValidateIssue, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
|
|
@ -1127,15 +1124,10 @@ func (w *ApplicationService) MoveWorkflowFromAppToLibrary(ctx context.Context, w
|
|||
}
|
||||
|
||||
func convertNodeExecution(nodeExe *entity.NodeExecution) (*workflow.NodeResult, error) {
|
||||
nType, err := entityNodeTypeToAPINodeTemplateType(nodeExe.NodeType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nr := &workflow.NodeResult{
|
||||
NodeId: nodeExe.NodeID,
|
||||
NodeName: nodeExe.NodeName,
|
||||
NodeType: nType.String(),
|
||||
NodeType: entity.NodeMetaByNodeType(nodeExe.NodeType).GetDisplayKey(),
|
||||
NodeStatus: workflow.NodeExeStatus(nodeExe.Status),
|
||||
ErrorInfo: ptr.FromOrDefault(nodeExe.ErrorInfo, ""),
|
||||
Input: ptr.FromOrDefault(nodeExe.Input, ""),
|
||||
|
|
@ -1316,13 +1308,6 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
|||
return nil, schema.ErrNoValue
|
||||
}
|
||||
|
||||
var nodeType workflow.NodeTemplateType
|
||||
nodeType, err = entityNodeTypeToAPINodeTemplateType(msg.NodeType)
|
||||
if err != nil {
|
||||
logs.Errorf("convert node type %v failed, err:=%v", msg.NodeType, err)
|
||||
nodeType = workflow.NodeTemplateType(0)
|
||||
}
|
||||
|
||||
res = &workflow.OpenAPIStreamRunFlowResponse{
|
||||
ID: strconv.Itoa(messageID),
|
||||
Event: string(MessageEvent),
|
||||
|
|
@ -1330,7 +1315,7 @@ func convertStreamRunEvent(workflowID int64) func(msg *entity.Message) (res *wor
|
|||
Content: ptr.Of(msg.Content),
|
||||
ContentType: ptr.Of("text"),
|
||||
NodeIsFinish: ptr.Of(msg.Last),
|
||||
NodeType: ptr.Of(nodeType.String()),
|
||||
NodeType: ptr.Of(entity.NodeMetaByNodeType(msg.NodeType).GetDisplayKey()),
|
||||
NodeID: ptr.Of(msg.NodeID),
|
||||
}
|
||||
|
||||
|
|
@ -3344,178 +3329,6 @@ func toWorkflowParameter(nType *vo.NamedTypeInfo) (*workflow.Parameter, error) {
|
|||
return wp, nil
|
||||
}
|
||||
|
||||
func nodeType2EntityNodeType(t string) (entity.NodeType, error) {
|
||||
i, err := strconv.Atoi(t)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid node type string '%s': %w", t, err)
|
||||
}
|
||||
|
||||
switch i {
|
||||
case 1:
|
||||
return entity.NodeTypeEntry, nil
|
||||
case 2:
|
||||
return entity.NodeTypeExit, nil
|
||||
case 3:
|
||||
return entity.NodeTypeLLM, nil
|
||||
case 4:
|
||||
return entity.NodeTypePlugin, nil
|
||||
case 5:
|
||||
return entity.NodeTypeCodeRunner, nil
|
||||
case 6:
|
||||
return entity.NodeTypeKnowledgeRetriever, nil
|
||||
case 8:
|
||||
return entity.NodeTypeSelector, nil
|
||||
case 9:
|
||||
return entity.NodeTypeSubWorkflow, nil
|
||||
case 12:
|
||||
return entity.NodeTypeDatabaseCustomSQL, nil
|
||||
case 13:
|
||||
return entity.NodeTypeOutputEmitter, nil
|
||||
case 15:
|
||||
return entity.NodeTypeTextProcessor, nil
|
||||
case 18:
|
||||
return entity.NodeTypeQuestionAnswer, nil
|
||||
case 19:
|
||||
return entity.NodeTypeBreak, nil
|
||||
case 20:
|
||||
return entity.NodeTypeVariableAssignerWithinLoop, nil
|
||||
case 21:
|
||||
return entity.NodeTypeLoop, nil
|
||||
case 22:
|
||||
return entity.NodeTypeIntentDetector, nil
|
||||
case 27:
|
||||
return entity.NodeTypeKnowledgeIndexer, nil
|
||||
case 28:
|
||||
return entity.NodeTypeBatch, nil
|
||||
case 29:
|
||||
return entity.NodeTypeContinue, nil
|
||||
case 30:
|
||||
return entity.NodeTypeInputReceiver, nil
|
||||
case 32:
|
||||
return entity.NodeTypeVariableAggregator, nil
|
||||
case 37:
|
||||
return entity.NodeTypeMessageList, nil
|
||||
case 38:
|
||||
return entity.NodeTypeClearMessage, nil
|
||||
case 39:
|
||||
return entity.NodeTypeCreateConversation, nil
|
||||
case 40:
|
||||
return entity.NodeTypeVariableAssigner, nil
|
||||
case 42:
|
||||
return entity.NodeTypeDatabaseUpdate, nil
|
||||
case 43:
|
||||
return entity.NodeTypeDatabaseQuery, nil
|
||||
case 44:
|
||||
return entity.NodeTypeDatabaseDelete, nil
|
||||
case 45:
|
||||
return entity.NodeTypeHTTPRequester, nil
|
||||
case 46:
|
||||
return entity.NodeTypeDatabaseInsert, nil
|
||||
case 58:
|
||||
return entity.NodeTypeJsonSerialization, nil
|
||||
case 59:
|
||||
return entity.NodeTypeJsonDeserialization, nil
|
||||
case 60:
|
||||
return entity.NodeTypeKnowledgeDeleter, nil
|
||||
default:
|
||||
// Handle all unknown or unsupported types here
|
||||
return "", fmt.Errorf("unsupported or unknown node type ID: %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// entityNodeTypeToAPINodeTemplateType converts an entity.NodeType to the corresponding workflow.NodeTemplateType.
|
||||
func entityNodeTypeToAPINodeTemplateType(nodeType entity.NodeType) (workflow.NodeTemplateType, error) {
|
||||
switch nodeType {
|
||||
case entity.NodeTypeEntry:
|
||||
return workflow.NodeTemplateType_Start, nil
|
||||
case entity.NodeTypeExit:
|
||||
return workflow.NodeTemplateType_End, nil
|
||||
case entity.NodeTypeLLM:
|
||||
return workflow.NodeTemplateType_LLM, nil
|
||||
case entity.NodeTypePlugin:
|
||||
// Maps to Api type in the API model
|
||||
return workflow.NodeTemplateType_Api, nil
|
||||
case entity.NodeTypeCodeRunner:
|
||||
return workflow.NodeTemplateType_Code, nil
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
// Maps to Dataset type in the API model
|
||||
return workflow.NodeTemplateType_Dataset, nil
|
||||
case entity.NodeTypeSelector:
|
||||
// Maps to If type in the API model
|
||||
return workflow.NodeTemplateType_If, nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
return workflow.NodeTemplateType_SubWorkflow, nil
|
||||
case entity.NodeTypeDatabaseCustomSQL:
|
||||
// Maps to the generic Database type in the API model
|
||||
return workflow.NodeTemplateType_Database, nil
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
// Maps to Message type in the API model
|
||||
return workflow.NodeTemplateType_Message, nil
|
||||
case entity.NodeTypeTextProcessor:
|
||||
return workflow.NodeTemplateType_Text, nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
return workflow.NodeTemplateType_Question, nil
|
||||
case entity.NodeTypeBreak:
|
||||
return workflow.NodeTemplateType_Break, nil
|
||||
case entity.NodeTypeVariableAssigner:
|
||||
return workflow.NodeTemplateType_AssignVariable, nil
|
||||
case entity.NodeTypeVariableAssignerWithinLoop:
|
||||
return workflow.NodeTemplateType_LoopSetVariable, nil
|
||||
case entity.NodeTypeLoop:
|
||||
return workflow.NodeTemplateType_Loop, nil
|
||||
case entity.NodeTypeIntentDetector:
|
||||
return workflow.NodeTemplateType_Intent, nil
|
||||
case entity.NodeTypeKnowledgeIndexer:
|
||||
// Maps to DatasetWrite type in the API model
|
||||
return workflow.NodeTemplateType_DatasetWrite, nil
|
||||
case entity.NodeTypeBatch:
|
||||
return workflow.NodeTemplateType_Batch, nil
|
||||
case entity.NodeTypeContinue:
|
||||
return workflow.NodeTemplateType_Continue, nil
|
||||
case entity.NodeTypeInputReceiver:
|
||||
return workflow.NodeTemplateType_Input, nil
|
||||
case entity.NodeTypeMessageList:
|
||||
return workflow.NodeTemplateType(37), nil
|
||||
case entity.NodeTypeVariableAggregator:
|
||||
return workflow.NodeTemplateType(32), nil
|
||||
case entity.NodeTypeClearMessage:
|
||||
return workflow.NodeTemplateType(38), nil
|
||||
case entity.NodeTypeCreateConversation:
|
||||
return workflow.NodeTemplateType(39), nil
|
||||
// Note: entity.NodeTypeVariableAggregator (ID 32) has no direct mapping in NodeTemplateType
|
||||
// Note: entity.NodeTypeMessageList (ID 37) has no direct mapping in NodeTemplateType
|
||||
// Note: entity.NodeTypeClearMessage (ID 38) has no direct mapping in NodeTemplateType
|
||||
// Note: entity.NodeTypeCreateConversation (ID 39) has no direct mapping in NodeTemplateType
|
||||
case entity.NodeTypeDatabaseUpdate:
|
||||
return workflow.NodeTemplateType_DatabaseUpdate, nil
|
||||
case entity.NodeTypeDatabaseQuery:
|
||||
// Maps to DatabasesELECT (ID 43) in the API model (note potential typo)
|
||||
return workflow.NodeTemplateType_DatabasesELECT, nil
|
||||
case entity.NodeTypeDatabaseDelete:
|
||||
return workflow.NodeTemplateType_DatabaseDelete, nil
|
||||
|
||||
// Note: entity.NodeTypeHTTPRequester (ID 45) has no direct mapping in NodeTemplateType
|
||||
case entity.NodeTypeHTTPRequester:
|
||||
return workflow.NodeTemplateType(45), nil
|
||||
|
||||
case entity.NodeTypeDatabaseInsert:
|
||||
// Maps to DatabaseInsert (ID 41) in the API model, despite entity ID being 46.
|
||||
// return workflow.NodeTemplateType_DatabaseInsert, nil
|
||||
return workflow.NodeTemplateType(46), nil
|
||||
case entity.NodeTypeJsonSerialization:
|
||||
return workflow.NodeTemplateType(58), nil
|
||||
case entity.NodeTypeJsonDeserialization:
|
||||
return workflow.NodeTemplateType_JsonDeserialization, nil
|
||||
case entity.NodeTypeKnowledgeDeleter:
|
||||
return workflow.NodeTemplateType_DatasetDelete, nil
|
||||
case entity.NodeTypeLambda:
|
||||
return 0, nil
|
||||
default:
|
||||
// Handle entity types that don't have a corresponding NodeTemplateType
|
||||
return workflow.NodeTemplateType(0), fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType)
|
||||
}
|
||||
}
|
||||
|
||||
func i64PtrToStringPtr(i *int64) *string {
|
||||
if i == nil {
|
||||
return nil
|
||||
|
|
@ -3761,7 +3574,7 @@ func mergeWorkflowAPIParameters(latestAPIParameters []*workflow.APIParameter, ex
|
|||
func parseWorkflowTerminatePlanType(c *vo.Canvas) (int32, error) {
|
||||
var endNode *vo.Node
|
||||
for _, n := range c.Nodes {
|
||||
if n.Type == vo.BlockTypeBotEnd {
|
||||
if n.Type == entity.NodeTypeExit.IDStr() {
|
||||
endNode = n
|
||||
break
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,49 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
. "github.com/bytedance/mockey"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenRequestString(t *testing.T) {
|
||||
PatchConvey("", t, func() {
|
||||
requestStr, err := genRequestString(&http.Request{
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/test"},
|
||||
}, []byte(`{"a": 1}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"header":{"Content-Type":["application/json"]},"query":null,"path":"/test","body":{"a": 1}}`, requestStr)
|
||||
})
|
||||
|
||||
PatchConvey("", t, func() {
|
||||
var body []byte
|
||||
requestStr, err := genRequestString(&http.Request{
|
||||
URL: &url.URL{Path: "/test"},
|
||||
}, body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"header":null,"query":null,"path":"/test","body":null}`, requestStr)
|
||||
})
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
/*
|
||||
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
|
@ -16,12 +17,39 @@
|
|||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type NodeType string
|
||||
|
||||
func (nt NodeType) IDStr() string {
|
||||
m := NodeMetaByNodeType(nt)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d", m.ID)
|
||||
}
|
||||
|
||||
func IDStrToNodeType(s string) NodeType {
|
||||
id, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, m := range NodeTypeMetas {
|
||||
if m.ID == id {
|
||||
return m.Key
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type NodeTypeMeta struct {
|
||||
ID int64 `json:"id"`
|
||||
ID int64
|
||||
Key NodeType
|
||||
DisplayKey string
|
||||
Name string `json:"name"`
|
||||
Type NodeType `json:"type"`
|
||||
Category string `json:"category"`
|
||||
Color string `json:"color"`
|
||||
Desc string `json:"desc"`
|
||||
|
|
@ -34,40 +62,40 @@ type NodeTypeMeta struct {
|
|||
ExecutableMeta
|
||||
}
|
||||
|
||||
func (ntm *NodeTypeMeta) GetDisplayKey() string {
|
||||
if len(ntm.DisplayKey) > 0 {
|
||||
return ntm.DisplayKey
|
||||
}
|
||||
|
||||
return string(ntm.Key)
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
EnUSName string `json:"en_us_name"`
|
||||
}
|
||||
|
||||
type StreamingParadigm string
|
||||
|
||||
const (
|
||||
Invoke StreamingParadigm = "invoke"
|
||||
Stream StreamingParadigm = "stream"
|
||||
Collect StreamingParadigm = "collect"
|
||||
Transform StreamingParadigm = "transform"
|
||||
)
|
||||
|
||||
type ExecutableMeta struct {
|
||||
IsComposite bool `json:"is_composite,omitempty"`
|
||||
DefaultTimeoutMS int64 `json:"default_timeout_ms,omitempty"` // default timeout in milliseconds, 0 means no timeout
|
||||
PreFillZero bool `json:"pre_fill_zero,omitempty"`
|
||||
PostFillNil bool `json:"post_fill_nil,omitempty"`
|
||||
CallbackEnabled bool `json:"callback_enabled,omitempty"` // is false, Eino framework will inject callbacks for this node
|
||||
MayUseChatModel bool `json:"may_use_chat_model,omitempty"`
|
||||
InputSourceAware bool `json:"input_source_aware,omitempty"` // whether this node needs to know the runtime status of its input sources
|
||||
StreamingParadigms map[StreamingParadigm]bool `json:"streaming_paradigms,omitempty"`
|
||||
StreamSourceEOFAware bool `json:"needs_stream_source_eof,omitempty"` // whether this node needs to be aware stream sources' SourceEOF error
|
||||
/*
|
||||
IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming.
|
||||
This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output')
|
||||
from those that are merely capable of streaming internally (defined by StreamingParadigms),
|
||||
whose output is consumed by other nodes.
|
||||
In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
|
||||
When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
|
||||
*/
|
||||
|
||||
// IncrementalOutput indicates that the node's output is intended for progressive, user-facing streaming.
|
||||
// This distinguishes nodes that actually stream text to the user (e.g., 'Exit', 'Output')
|
||||
//from those that are merely capable of streaming internally (defined by StreamingParadigms),
|
||||
// In essence, nodes with IncrementalOutput are a subset of those defined in StreamingParadigms.
|
||||
// When set to true, stream chunks from the node are persisted in real-time and can be fetched by get_process.
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
|
||||
// UseCtxCache indicates that the node would require a newly initialized ctx cache for each invocation.
|
||||
// example use cases:
|
||||
// - write warnings to the ctx cache during Invoke, and read from the ctx within Callback output converter
|
||||
UseCtxCache bool `json:"use_ctx_cache"`
|
||||
}
|
||||
|
||||
type PluginNodeMeta struct {
|
||||
|
|
@ -125,9 +153,700 @@ const (
|
|||
NodeTypeSubWorkflow NodeType = "SubWorkflow"
|
||||
NodeTypeJsonSerialization NodeType = "JsonSerialization"
|
||||
NodeTypeJsonDeserialization NodeType = "JsonDeserialization"
|
||||
NodeTypeComment NodeType = "Comment"
|
||||
)
|
||||
|
||||
const (
|
||||
EntryNodeKey = "100001"
|
||||
ExitNodeKey = "900001"
|
||||
)
|
||||
|
||||
var Categories = []Category{
|
||||
{
|
||||
Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
|
||||
Name: "",
|
||||
EnUSName: "",
|
||||
},
|
||||
{
|
||||
Key: "logic",
|
||||
Name: "业务逻辑",
|
||||
EnUSName: "Logic",
|
||||
},
|
||||
{
|
||||
Key: "input&output",
|
||||
Name: "输入&输出",
|
||||
EnUSName: "Input&Output",
|
||||
},
|
||||
{
|
||||
Key: "database",
|
||||
Name: "数据库",
|
||||
EnUSName: "Database",
|
||||
},
|
||||
{
|
||||
Key: "data",
|
||||
Name: "知识库&数据",
|
||||
EnUSName: "Data",
|
||||
},
|
||||
{
|
||||
Key: "image",
|
||||
Name: "图像处理",
|
||||
EnUSName: "Image",
|
||||
},
|
||||
{
|
||||
Key: "audio&video",
|
||||
Name: "音视频处理",
|
||||
EnUSName: "Audio&Video",
|
||||
},
|
||||
{
|
||||
Key: "utilities",
|
||||
Name: "组件",
|
||||
EnUSName: "Utilities",
|
||||
},
|
||||
{
|
||||
Key: "conversation_management",
|
||||
Name: "会话管理",
|
||||
EnUSName: "Conversation management",
|
||||
},
|
||||
{
|
||||
Key: "conversation_history",
|
||||
Name: "会话历史",
|
||||
EnUSName: "Conversation history",
|
||||
},
|
||||
{
|
||||
Key: "message",
|
||||
Name: "消息",
|
||||
EnUSName: "Message",
|
||||
},
|
||||
}
|
||||
|
||||
// NodeTypeMetas holds the metadata for all available node types.
|
||||
// It is initialized with built-in node types and potentially extended by loading from external sources.
|
||||
var NodeTypeMetas = map[NodeType]*NodeTypeMeta{
|
||||
NodeTypeEntry: {
|
||||
ID: 1,
|
||||
Key: NodeTypeEntry,
|
||||
DisplayKey: "Start",
|
||||
Name: "开始",
|
||||
Category: "input&output",
|
||||
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Start",
|
||||
EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
|
||||
},
|
||||
NodeTypeExit: {
|
||||
ID: 2,
|
||||
Key: NodeTypeExit,
|
||||
DisplayKey: "End",
|
||||
Name: "结束",
|
||||
Category: "input&output",
|
||||
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
InputSourceAware: true,
|
||||
StreamSourceEOFAware: true,
|
||||
IncrementalOutput: true,
|
||||
},
|
||||
EnUSName: "End",
|
||||
EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
|
||||
},
|
||||
NodeTypeLLM: {
|
||||
ID: 3,
|
||||
Key: NodeTypeLLM,
|
||||
DisplayKey: "LLM",
|
||||
Name: "大模型",
|
||||
Category: "",
|
||||
Desc: "调用大语言模型,使用变量和提示词生成回复",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
InputSourceAware: true,
|
||||
MayUseChatModel: true,
|
||||
},
|
||||
EnUSName: "LLM",
|
||||
EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
|
||||
},
|
||||
NodeTypePlugin: {
|
||||
ID: 4,
|
||||
Key: NodeTypePlugin,
|
||||
DisplayKey: "Api",
|
||||
Name: "插件",
|
||||
Category: "",
|
||||
Desc: "通过添加工具访问实时数据和执行外部操作",
|
||||
Color: "#CA61FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Plugin",
|
||||
EnUSDescription: "Used to access external real-time data and perform operations",
|
||||
},
|
||||
NodeTypeCodeRunner: {
|
||||
ID: 5,
|
||||
Key: NodeTypeCodeRunner,
|
||||
DisplayKey: "Code",
|
||||
Name: "代码",
|
||||
Category: "logic",
|
||||
Desc: "编写代码,处理输入变量来生成返回值",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
UseCtxCache: true,
|
||||
},
|
||||
EnUSName: "Code",
|
||||
EnUSDescription: "Write code to process input variables to generate return values.",
|
||||
},
|
||||
NodeTypeKnowledgeRetriever: {
|
||||
ID: 6,
|
||||
Key: NodeTypeKnowledgeRetriever,
|
||||
DisplayKey: "Dataset",
|
||||
Name: "知识库检索",
|
||||
Category: "data",
|
||||
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Knowledge retrieval",
|
||||
EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
|
||||
},
|
||||
NodeTypeSelector: {
|
||||
ID: 8,
|
||||
Key: NodeTypeSelector,
|
||||
DisplayKey: "If",
|
||||
Name: "选择器",
|
||||
Category: "logic",
|
||||
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Condition",
|
||||
EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
|
||||
},
|
||||
NodeTypeSubWorkflow: {
|
||||
ID: 9,
|
||||
Key: NodeTypeSubWorkflow,
|
||||
DisplayKey: "SubWorkflow",
|
||||
Name: "工作流",
|
||||
Category: "",
|
||||
Desc: "集成已发布工作流,可以执行嵌套子任务",
|
||||
Color: "#00B83E",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
|
||||
SupportBatch: true,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Workflow",
|
||||
EnUSDescription: "Add published workflows to execute subtasks",
|
||||
},
|
||||
NodeTypeDatabaseCustomSQL: {
|
||||
ID: 12,
|
||||
Key: NodeTypeDatabaseCustomSQL,
|
||||
DisplayKey: "End",
|
||||
Name: "SQL自定义",
|
||||
Category: "database",
|
||||
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "SQL Customization",
|
||||
EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
|
||||
},
|
||||
NodeTypeOutputEmitter: {
|
||||
ID: 13,
|
||||
Key: NodeTypeOutputEmitter,
|
||||
DisplayKey: "Message",
|
||||
Name: "输出",
|
||||
Category: "input&output",
|
||||
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
InputSourceAware: true,
|
||||
StreamSourceEOFAware: true,
|
||||
IncrementalOutput: true,
|
||||
},
|
||||
EnUSName: "Output",
|
||||
EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
|
||||
},
|
||||
NodeTypeTextProcessor: {
|
||||
ID: 15,
|
||||
Key: NodeTypeTextProcessor,
|
||||
DisplayKey: "Text",
|
||||
Name: "文本处理",
|
||||
Category: "utilities",
|
||||
Desc: "用于处理多个字符串类型变量的格式",
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
InputSourceAware: true,
|
||||
},
|
||||
EnUSName: "Text Processing",
|
||||
EnUSDescription: "The format used for handling multiple string-type variables.",
|
||||
},
|
||||
NodeTypeQuestionAnswer: {
|
||||
ID: 18,
|
||||
Key: NodeTypeQuestionAnswer,
|
||||
DisplayKey: "Question",
|
||||
Name: "问答",
|
||||
Category: "utilities",
|
||||
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
MayUseChatModel: true,
|
||||
},
|
||||
EnUSName: "Question",
|
||||
EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
|
||||
},
|
||||
NodeTypeBreak: {
|
||||
ID: 19,
|
||||
Key: NodeTypeBreak,
|
||||
DisplayKey: "Break",
|
||||
Name: "终止循环",
|
||||
Category: "logic",
|
||||
Desc: "用于立即终止当前所在的循环,跳出循环体",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Break",
|
||||
EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
|
||||
},
|
||||
NodeTypeVariableAssignerWithinLoop: {
|
||||
ID: 20,
|
||||
Key: NodeTypeVariableAssignerWithinLoop,
|
||||
DisplayKey: "LoopSetVariable",
|
||||
Name: "设置变量",
|
||||
Category: "logic",
|
||||
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Set Variable",
|
||||
EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
|
||||
},
|
||||
NodeTypeLoop: {
|
||||
ID: 21,
|
||||
Key: NodeTypeLoop,
|
||||
DisplayKey: "Loop",
|
||||
Name: "循环",
|
||||
Category: "logic",
|
||||
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Loop",
|
||||
EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
|
||||
},
|
||||
NodeTypeIntentDetector: {
|
||||
ID: 22,
|
||||
Key: NodeTypeIntentDetector,
|
||||
DisplayKey: "Intent",
|
||||
Name: "意图识别",
|
||||
Category: "logic",
|
||||
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
MayUseChatModel: true,
|
||||
},
|
||||
EnUSName: "Intent recognition",
|
||||
EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
|
||||
},
|
||||
NodeTypeKnowledgeIndexer: {
|
||||
ID: 27,
|
||||
Key: NodeTypeKnowledgeIndexer,
|
||||
DisplayKey: "DatasetWrite",
|
||||
Name: "知识库写入",
|
||||
Category: "data",
|
||||
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Knowledge writing",
|
||||
EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
|
||||
},
|
||||
NodeTypeBatch: {
|
||||
ID: 28,
|
||||
Key: NodeTypeBatch,
|
||||
DisplayKey: "Batch",
|
||||
Name: "批处理",
|
||||
Category: "logic",
|
||||
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Batch",
|
||||
EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
|
||||
},
|
||||
NodeTypeContinue: {
|
||||
ID: 29,
|
||||
Key: NodeTypeContinue,
|
||||
DisplayKey: "Continue",
|
||||
Name: "继续循环",
|
||||
Category: "logic",
|
||||
Desc: "用于终止当前循环,执行下次循环",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Continue",
|
||||
EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
|
||||
},
|
||||
NodeTypeInputReceiver: {
|
||||
ID: 30,
|
||||
Key: NodeTypeInputReceiver,
|
||||
DisplayKey: "Input",
|
||||
Name: "输入",
|
||||
Category: "input&output",
|
||||
Desc: "支持中间过程的信息输入",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Input",
|
||||
EnUSDescription: "Support intermediate information input",
|
||||
},
|
||||
NodeTypeComment: {
|
||||
ID: 31,
|
||||
Key: "",
|
||||
Name: "注释",
|
||||
Category: "", // Not found in cate_list
|
||||
Desc: "comment_desc", // Placeholder from JSON
|
||||
Color: "",
|
||||
IconURL: "comment_icon", // Placeholder from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
EnUSName: "Comment",
|
||||
},
|
||||
NodeTypeVariableAggregator: {
|
||||
ID: 32,
|
||||
Key: NodeTypeVariableAggregator,
|
||||
Name: "变量聚合",
|
||||
Category: "logic",
|
||||
Desc: "对多个分支的输出进行聚合处理",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
InputSourceAware: true,
|
||||
UseCtxCache: true,
|
||||
},
|
||||
EnUSName: "Variable Merge",
|
||||
EnUSDescription: "Aggregate the outputs of multiple branches.",
|
||||
},
|
||||
NodeTypeMessageList: {
|
||||
ID: 37,
|
||||
Key: NodeTypeMessageList,
|
||||
Name: "查询消息列表",
|
||||
Category: "message",
|
||||
Desc: "用于查询消息列表",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
|
||||
SupportBatch: false,
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Query message list",
|
||||
EnUSDescription: "Used to query the message list",
|
||||
},
|
||||
NodeTypeClearMessage: {
|
||||
ID: 38,
|
||||
Key: NodeTypeClearMessage,
|
||||
Name: "清除上下文",
|
||||
Category: "conversation_history",
|
||||
Desc: "用于清空会话历史,清空后LLM看到的会话历史为空",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
|
||||
SupportBatch: false,
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Clear conversation history",
|
||||
EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
|
||||
},
|
||||
NodeTypeCreateConversation: {
|
||||
ID: 39,
|
||||
Key: NodeTypeCreateConversation,
|
||||
Name: "创建会话",
|
||||
Category: "conversation_management",
|
||||
Desc: "用于创建会话",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
SupportBatch: false,
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Create conversation",
|
||||
EnUSDescription: "This node is used to create a conversation.",
|
||||
},
|
||||
NodeTypeVariableAssigner: {
|
||||
ID: 40,
|
||||
Key: NodeTypeVariableAssigner,
|
||||
DisplayKey: "AssignVariable",
|
||||
Name: "变量赋值",
|
||||
Category: "data",
|
||||
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{},
|
||||
EnUSName: "Variable assign",
|
||||
EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
|
||||
},
|
||||
NodeTypeDatabaseUpdate: {
|
||||
ID: 42,
|
||||
Key: NodeTypeDatabaseUpdate,
|
||||
DisplayKey: "DatabaseUpdate",
|
||||
Name: "更新数据",
|
||||
Category: "database",
|
||||
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
},
|
||||
EnUSName: "Update Data",
|
||||
EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
|
||||
},
|
||||
NodeTypeDatabaseQuery: {
|
||||
ID: 43,
|
||||
Key: NodeTypeDatabaseQuery,
|
||||
DisplayKey: "DatabaseSelect",
|
||||
Name: "查询数据",
|
||||
Category: "database",
|
||||
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
},
|
||||
EnUSName: "Query Data",
|
||||
EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
|
||||
},
|
||||
NodeTypeDatabaseDelete: {
|
||||
ID: 44,
|
||||
Key: NodeTypeDatabaseDelete,
|
||||
DisplayKey: "DatabaseDelete",
|
||||
Name: "删除数据",
|
||||
Category: "database",
|
||||
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
},
|
||||
EnUSName: "Delete Data",
|
||||
EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
|
||||
},
|
||||
NodeTypeHTTPRequester: {
|
||||
ID: 45,
|
||||
Key: NodeTypeHTTPRequester,
|
||||
DisplayKey: "Http",
|
||||
Name: "HTTP 请求",
|
||||
Category: "utilities",
|
||||
Desc: "用于发送API请求,从接口返回数据",
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "HTTP request",
|
||||
EnUSDescription: "It is used to send API requests and return data from the interface.",
|
||||
},
|
||||
NodeTypeDatabaseInsert: {
|
||||
ID: 46,
|
||||
Key: NodeTypeDatabaseInsert,
|
||||
DisplayKey: "DatabaseInsert",
|
||||
Name: "新增数据",
|
||||
Category: "database",
|
||||
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
},
|
||||
EnUSName: "Add Data",
|
||||
EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
|
||||
},
|
||||
NodeTypeJsonSerialization: {
|
||||
// ID is the unique identifier of this node type. Used in various front-end APIs.
|
||||
ID: 58,
|
||||
|
||||
// Key is the unique NodeType of this node. Used in backend code as well as saved in DB.
|
||||
Key: NodeTypeJsonSerialization,
|
||||
|
||||
// DisplayKey is the string used in frontend to identify this node.
|
||||
// Example use cases:
|
||||
// - used during querying test-run results for nodes
|
||||
// - used in returned messages from streaming openAPI Runs.
|
||||
// If empty, will use Key as DisplayKey.
|
||||
DisplayKey: "ToJSON",
|
||||
|
||||
// Name is the node in ZH_CN, will be displayed on Canvas.
|
||||
Name: "JSON 序列化",
|
||||
|
||||
// Category is the category of this node, determines which category this node will be displayed in.
|
||||
Category: "utilities",
|
||||
|
||||
// Desc is the desc in ZH_CN, will be displayed as tooltip on Canvas.
|
||||
Desc: "用于把变量转化为JSON字符串",
|
||||
|
||||
// Color is the color of the upper edge of the node displayed on Canvas.
|
||||
Color: "F2B600",
|
||||
|
||||
// IconURL is the URL of the icon displayed on Canvas.
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
|
||||
|
||||
// SupportBatch indicates whether this node can set batch mode.
|
||||
// NOTE: ultimately it's frontend that decides which node can enable batch mode.
|
||||
SupportBatch: false,
|
||||
|
||||
// ExecutableMeta configures certain common aspects of request-time behaviors for this node.
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
// DefaultTimeoutMS configures the default timeout for this node, in milliseconds. 0 means no timeout.
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
// PreFillZero decides whether to pre-fill zero value for any missing fields in input.
|
||||
PreFillZero: true,
|
||||
// PostFillNil decides whether to post-fill nil value for any missing fields in output.
|
||||
PostFillNil: true,
|
||||
},
|
||||
// EnUSName is the name in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
|
||||
EnUSName: "JSON serialization",
|
||||
// EnUSDescription is the description in EN_US, will be displayed on Canvas if language of Coze-Studio is set to EnUS.
|
||||
EnUSDescription: "Convert variable to JSON string",
|
||||
},
|
||||
NodeTypeJsonDeserialization: {
|
||||
ID: 59,
|
||||
Key: NodeTypeJsonDeserialization,
|
||||
DisplayKey: "FromJSON",
|
||||
Name: "JSON 反序列化",
|
||||
Category: "utilities",
|
||||
Desc: "用于将JSON字符串解析为变量",
|
||||
Color: "F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
UseCtxCache: true,
|
||||
},
|
||||
EnUSName: "JSON deserialization",
|
||||
EnUSDescription: "Parse JSON string to variable",
|
||||
},
|
||||
NodeTypeKnowledgeDeleter: {
|
||||
ID: 60,
|
||||
Key: NodeTypeKnowledgeDeleter,
|
||||
DisplayKey: "KnowledgeDelete",
|
||||
Name: "知识库删除",
|
||||
Category: "data",
|
||||
Desc: "用于删除知识库中的文档",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
},
|
||||
EnUSName: "Knowledge delete",
|
||||
EnUSDescription: "The delete node can delete a document in knowledge base.",
|
||||
},
|
||||
NodeTypeLambda: {
|
||||
ID: 1000,
|
||||
Key: NodeTypeLambda,
|
||||
Name: "Lambda",
|
||||
EnUSName: "Comment",
|
||||
},
|
||||
}
|
||||
|
||||
// PluginNodeMetas holds metadata for specific plugin API entity.
|
||||
var PluginNodeMetas []*PluginNodeMeta
|
||||
|
||||
// PluginCategoryMetas holds metadata for plugin category entity.
|
||||
var PluginCategoryMetas []*PluginCategoryMeta
|
||||
|
||||
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
|
||||
if m, ok := NodeTypeMetas[t]; ok {
|
||||
return m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,867 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
var Categories = []Category{
|
||||
{
|
||||
Key: "", // this is the default category. some of the most important nodes belong here, such as LLM, plugin, sub-workflow
|
||||
Name: "",
|
||||
EnUSName: "",
|
||||
},
|
||||
{
|
||||
Key: "logic",
|
||||
Name: "业务逻辑",
|
||||
EnUSName: "Logic",
|
||||
},
|
||||
{
|
||||
Key: "input&output",
|
||||
Name: "输入&输出",
|
||||
EnUSName: "Input&Output",
|
||||
},
|
||||
{
|
||||
Key: "database",
|
||||
Name: "数据库",
|
||||
EnUSName: "Database",
|
||||
},
|
||||
{
|
||||
Key: "data",
|
||||
Name: "知识库&数据",
|
||||
EnUSName: "Data",
|
||||
},
|
||||
{
|
||||
Key: "image",
|
||||
Name: "图像处理",
|
||||
EnUSName: "Image",
|
||||
},
|
||||
{
|
||||
Key: "audio&video",
|
||||
Name: "音视频处理",
|
||||
EnUSName: "Audio&Video",
|
||||
},
|
||||
{
|
||||
Key: "utilities",
|
||||
Name: "组件",
|
||||
EnUSName: "Utilities",
|
||||
},
|
||||
{
|
||||
Key: "conversation_management",
|
||||
Name: "会话管理",
|
||||
EnUSName: "Conversation management",
|
||||
},
|
||||
{
|
||||
Key: "conversation_history",
|
||||
Name: "会话历史",
|
||||
EnUSName: "Conversation history",
|
||||
},
|
||||
{
|
||||
Key: "message",
|
||||
Name: "消息",
|
||||
EnUSName: "Message",
|
||||
},
|
||||
}
|
||||
|
||||
// NodeTypeMetas holds the metadata for all available node types.
|
||||
// It is initialized with built-in types and potentially extended by loading from external sources.
|
||||
var NodeTypeMetas = []*NodeTypeMeta{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "开始",
|
||||
Type: NodeTypeEntry,
|
||||
Category: "input&output", // Mapped from cate_list
|
||||
Desc: "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Start",
|
||||
EnUSDescription: "The starting node of the workflow, used to set the information needed to initiate the workflow.",
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "结束",
|
||||
Type: NodeTypeExit,
|
||||
Category: "input&output", // Mapped from cate_list
|
||||
Desc: "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
InputSourceAware: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
|
||||
StreamSourceEOFAware: true,
|
||||
IncrementalOutput: true,
|
||||
},
|
||||
EnUSName: "End",
|
||||
EnUSDescription: "The final node of the workflow, used to return the result information after the workflow runs.",
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Name: "大模型",
|
||||
Type: NodeTypeLLM,
|
||||
Category: "", // Mapped from cate_list
|
||||
Desc: "调用大语言模型,使用变量和提示词生成回复",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LLM-v2.jpg",
|
||||
SupportBatch: true, // supportBatch: 2
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
InputSourceAware: true,
|
||||
MayUseChatModel: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
|
||||
},
|
||||
EnUSName: "LLM",
|
||||
EnUSDescription: "Invoke the large language model, generate responses using variables and prompt words.",
|
||||
},
|
||||
|
||||
{
|
||||
ID: 4,
|
||||
Name: "插件",
|
||||
Type: NodeTypePlugin,
|
||||
Category: "", // Mapped from cate_list
|
||||
Desc: "通过添加工具访问实时数据和执行外部操作",
|
||||
Color: "#CA61FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Plugin-v2.jpg",
|
||||
SupportBatch: true, // supportBatch: 2
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 3 * 60 * 1000, // 3 minutes
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Plugin",
|
||||
EnUSDescription: "Used to access external real-time data and perform operations",
|
||||
},
|
||||
{
|
||||
ID: 5,
|
||||
Name: "代码",
|
||||
Type: NodeTypeCodeRunner,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "编写代码,处理输入变量来生成返回值",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Code-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Code",
|
||||
EnUSDescription: "Write code to process input variables to generate return values.",
|
||||
},
|
||||
{
|
||||
ID: 6,
|
||||
Name: "知识库检索",
|
||||
Type: NodeTypeKnowledgeRetriever,
|
||||
Category: "data", // Mapped from cate_list
|
||||
Desc: "在选定的知识中,根据输入变量召回最匹配的信息,并以列表形式返回",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeQuery-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Knowledge retrieval",
|
||||
EnUSDescription: "In the selected knowledge, the best matching information is recalled based on the input variable and returned as an Array.",
|
||||
},
|
||||
{
|
||||
ID: 8,
|
||||
Name: "选择器",
|
||||
Type: NodeTypeSelector,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "连接多个下游分支,若设定的条件成立则仅运行对应的分支,若均不成立则只运行“否则”分支",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Condition-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Condition",
|
||||
EnUSDescription: "Connect multiple downstream branches. Only the corresponding branch will be executed if the set conditions are met. If none are met, only the 'else' branch will be executed.",
|
||||
},
|
||||
{
|
||||
ID: 9,
|
||||
Name: "工作流",
|
||||
Type: NodeTypeSubWorkflow,
|
||||
Category: "", // Mapped from cate_list
|
||||
Desc: "集成已发布工作流,可以执行嵌套子任务",
|
||||
Color: "#00B83E",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Workflow-v2.jpg",
|
||||
SupportBatch: true, // supportBatch: 2
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Workflow",
|
||||
EnUSDescription: "Add published workflows to execute subtasks",
|
||||
},
|
||||
{
|
||||
ID: 12,
|
||||
Name: "SQL自定义",
|
||||
Type: NodeTypeDatabaseCustomSQL,
|
||||
Category: "database", // Mapped from cate_list
|
||||
Desc: "基于用户自定义的 SQL 完成对数据库的增删改查操作",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Database-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 2
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "SQL Customization",
|
||||
EnUSDescription: "Complete the operations of adding, deleting, modifying and querying the database based on user-defined SQL",
|
||||
},
|
||||
{
|
||||
ID: 13,
|
||||
Name: "输出",
|
||||
Type: NodeTypeOutputEmitter,
|
||||
Category: "input&output", // Mapped from cate_list
|
||||
Desc: "节点从“消息”更名为“输出”,支持中间过程的消息输出,支持流式和非流式两种方式",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Output-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
InputSourceAware: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Stream: true},
|
||||
StreamSourceEOFAware: true,
|
||||
IncrementalOutput: true,
|
||||
},
|
||||
EnUSName: "Output",
|
||||
EnUSDescription: "The node is renamed from \"message\" to \"output\", Supports message output in the intermediate process and streaming and non-streaming methods",
|
||||
},
|
||||
{
|
||||
ID: 15,
|
||||
Name: "文本处理",
|
||||
Type: NodeTypeTextProcessor,
|
||||
Category: "utilities", // Mapped from cate_list
|
||||
Desc: "用于处理多个字符串类型变量的格式",
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-StrConcat-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 2
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
InputSourceAware: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Text Processing",
|
||||
EnUSDescription: "The format used for handling multiple string-type variables.",
|
||||
},
|
||||
{
|
||||
ID: 18,
|
||||
Name: "问答",
|
||||
Type: NodeTypeQuestionAnswer,
|
||||
Category: "utilities", // Mapped from cate_list
|
||||
Desc: "支持中间向用户提问问题,支持预置选项提问和开放式问题提问两种方式",
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Direct-Question-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
MayUseChatModel: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Question",
|
||||
EnUSDescription: "Support asking questions to the user in the middle of the conversation, with both preset options and open-ended questions",
|
||||
},
|
||||
{
|
||||
ID: 19,
|
||||
Name: "终止循环",
|
||||
Type: NodeTypeBreak,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "用于立即终止当前所在的循环,跳出循环体",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Break-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Break",
|
||||
EnUSDescription: "Used to immediately terminate the current loop and jump out of the loop",
|
||||
},
|
||||
{
|
||||
ID: 20,
|
||||
Name: "设置变量",
|
||||
Type: NodeTypeVariableAssignerWithinLoop,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "用于重置循环变量的值,使其下次循环使用重置后的值",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-LoopSetVariable-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Set Variable",
|
||||
EnUSDescription: "Used to reset the value of the loop variable so that it uses the reset value in the next iteration",
|
||||
},
|
||||
{
|
||||
ID: 21,
|
||||
Name: "循环",
|
||||
Type: NodeTypeLoop,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "用于通过设定循环次数和逻辑,重复执行一系列任务",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Loop-v2.jpg",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Loop",
|
||||
EnUSDescription: "Used to repeatedly execute a series of tasks by setting the number of iterations and logic",
|
||||
},
|
||||
{
|
||||
ID: 22,
|
||||
Name: "意图识别",
|
||||
Type: NodeTypeIntentDetector,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "用于用户输入的意图识别,并将其与预设意图选项进行匹配。",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Intent-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
MayUseChatModel: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Intent recognition",
|
||||
EnUSDescription: "Used for recognizing the intent in user input and matching it with preset intent options.",
|
||||
},
|
||||
{
|
||||
ID: 27,
|
||||
Name: "知识库写入",
|
||||
Type: NodeTypeKnowledgeIndexer,
|
||||
Category: "data", // Mapped from cate_list
|
||||
Desc: "写入节点可以添加 文本类型 的知识库,仅可以添加一个知识库",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-KnowledgeWriting-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Knowledge writing",
|
||||
EnUSDescription: "The write node can add a knowledge base of type text. Only one knowledge base can be added.",
|
||||
},
|
||||
{
|
||||
ID: 28,
|
||||
Name: "批处理",
|
||||
Type: NodeTypeBatch,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "通过设定批量运行次数和逻辑,运行批处理体内的任务",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Batch-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1 (Corrected from previous assumption)
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
IsComposite: true,
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Batch",
|
||||
EnUSDescription: "By setting the number of batch runs and logic, run the tasks in the batch body.",
|
||||
},
|
||||
{
|
||||
ID: 29,
|
||||
Name: "继续循环",
|
||||
Type: NodeTypeContinue,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "用于终止当前循环,执行下次循环",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Continue-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Continue",
|
||||
EnUSDescription: "Used to immediately terminate the current loop and execute next loop",
|
||||
},
|
||||
{
|
||||
ID: 30,
|
||||
Name: "输入",
|
||||
Type: NodeTypeInputReceiver,
|
||||
Category: "input&output", // Mapped from cate_list
|
||||
Desc: "支持中间过程的信息输入",
|
||||
Color: "#5C62FF",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Input-v2.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Input",
|
||||
EnUSDescription: "Support intermediate information input",
|
||||
},
|
||||
{
|
||||
ID: 31,
|
||||
Name: "注释",
|
||||
Type: "",
|
||||
Category: "", // Not found in cate_list
|
||||
Desc: "comment_desc", // Placeholder from JSON
|
||||
Color: "",
|
||||
IconURL: "comment_icon", // Placeholder from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
EnUSName: "Comment",
|
||||
},
|
||||
{
|
||||
ID: 32,
|
||||
Name: "变量聚合",
|
||||
Type: NodeTypeVariableAggregator,
|
||||
Category: "logic", // Mapped from cate_list
|
||||
Desc: "对多个分支的输出进行聚合处理",
|
||||
Color: "#00B2B2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/VariableMerge-icon.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
InputSourceAware: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true, Transform: true},
|
||||
},
|
||||
EnUSName: "Variable Merge",
|
||||
EnUSDescription: "Aggregate the outputs of multiple branches.",
|
||||
},
|
||||
{
|
||||
ID: 37,
|
||||
Name: "查询消息列表",
|
||||
Type: NodeTypeMessageList,
|
||||
Category: "message", // Mapped from cate_list
|
||||
Desc: "用于查询消息列表",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-List.jpeg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Query message list",
|
||||
EnUSDescription: "Used to query the message list",
|
||||
},
|
||||
{
|
||||
ID: 38,
|
||||
Name: "清除上下文",
|
||||
Type: NodeTypeClearMessage,
|
||||
Category: "conversation_history", // Mapped from cate_list
|
||||
Desc: "用于清空会话历史,清空后LLM看到的会话历史为空",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Delete.jpeg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Clear conversation history",
|
||||
EnUSDescription: "Used to clear conversation history. After clearing, the conversation history visible to the LLM node will be empty.",
|
||||
},
|
||||
{
|
||||
ID: 39,
|
||||
Name: "创建会话",
|
||||
Type: NodeTypeCreateConversation,
|
||||
Category: "conversation_management", // Mapped from cate_list
|
||||
Desc: "用于创建会话",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Conversation-Create.jpeg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
Disabled: true,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Create conversation",
|
||||
EnUSDescription: "This node is used to create a conversation.",
|
||||
},
|
||||
{
|
||||
ID: 40,
|
||||
Name: "变量赋值",
|
||||
Type: NodeTypeVariableAssigner,
|
||||
Category: "data", // Mapped from cate_list
|
||||
Desc: "用于给支持写入的变量赋值,包括应用变量、用户变量",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/Variable.jpg",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Variable assign",
|
||||
EnUSDescription: "Assigns values to variables that support the write operation, including app and user variables.",
|
||||
},
|
||||
{
|
||||
ID: 42,
|
||||
Name: "更新数据",
|
||||
Type: NodeTypeDatabaseUpdate,
|
||||
Category: "database", // Mapped from cate_list
|
||||
Desc: "修改表中已存在的数据记录,用户指定更新条件和内容来更新数据",
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-update.jpg", // Corrected Icon URL from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Update Data",
|
||||
EnUSDescription: "Modify the existing data records in the table, and the user specifies the update conditions and contents to update the data",
|
||||
},
|
||||
{
|
||||
ID: 43,
|
||||
Name: "查询数据", // Corrected Name from JSON (was "insert data")
|
||||
Type: NodeTypeDatabaseQuery,
|
||||
Category: "database", // Mapped from cate_list
|
||||
Desc: "从表获取数据,用户可定义查询条件、选择列等,输出符合条件的数据", // Corrected Desc from JSON
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icaon-database-select.jpg", // Corrected Icon URL from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Query Data",
|
||||
EnUSDescription: "Query data from the table, and the user can define query conditions, select columns, etc., and output the data that meets the conditions",
|
||||
},
|
||||
{
|
||||
ID: 44,
|
||||
Name: "删除数据",
|
||||
Type: NodeTypeDatabaseDelete,
|
||||
Category: "database", // Mapped from cate_list
|
||||
Desc: "从表中删除数据记录,用户指定删除条件来删除符合条件的记录", // Corrected Desc from JSON
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-delete.jpg", // Corrected Icon URL from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Delete Data",
|
||||
EnUSDescription: "Delete data records from the table, and the user specifies the deletion conditions to delete the records that meet the conditions",
|
||||
},
|
||||
{
|
||||
ID: 45,
|
||||
Name: "HTTP 请求",
|
||||
Type: NodeTypeHTTPRequester,
|
||||
Category: "utilities", // Mapped from cate_list
|
||||
Desc: "用于发送API请求,从接口返回数据", // Corrected Desc from JSON
|
||||
Color: "#3071F2",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-HTTP.png", // Corrected Icon URL from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "HTTP request",
|
||||
EnUSDescription: "It is used to send API requests and return data from the interface.",
|
||||
},
|
||||
{
|
||||
ID: 46,
|
||||
Name: "新增数据", // Corrected Name from JSON (was "Query Data")
|
||||
Type: NodeTypeDatabaseInsert,
|
||||
Category: "database", // Mapped from cate_list
|
||||
Desc: "向表添加新数据记录,用户输入数据内容后插入数据库", // Corrected Desc from JSON
|
||||
Color: "#F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-database-insert.jpg", // Corrected Icon URL from JSON
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Add Data",
|
||||
EnUSDescription: "Add new data records to the table, and insert them into the database after the user enters the data content",
|
||||
},
|
||||
{
|
||||
ID: 58,
|
||||
Name: "JSON 序列化",
|
||||
Type: NodeTypeJsonSerialization,
|
||||
Category: "utilities",
|
||||
Desc: "用于把变量转化为JSON字符串",
|
||||
Color: "F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-to_json.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "JSON serialization",
|
||||
EnUSDescription: "Convert variable to JSON string",
|
||||
},
|
||||
{
|
||||
ID: 59,
|
||||
Name: "JSON 反序列化",
|
||||
Type: NodeTypeJsonDeserialization,
|
||||
Category: "utilities",
|
||||
Desc: "用于将JSON字符串解析为变量",
|
||||
Color: "F2B600",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-from_json.png",
|
||||
SupportBatch: false,
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
CallbackEnabled: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "JSON deserialization",
|
||||
EnUSDescription: "Parse JSON string to variable",
|
||||
},
|
||||
{
|
||||
ID: 60,
|
||||
Name: "知识库删除",
|
||||
Type: NodeTypeKnowledgeDeleter,
|
||||
Category: "data", // Mapped from cate_list
|
||||
Desc: "用于删除知识库中的文档",
|
||||
Color: "#FF811A",
|
||||
IconURL: "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icons-dataset-delete.png",
|
||||
SupportBatch: false, // supportBatch: 1
|
||||
ExecutableMeta: ExecutableMeta{
|
||||
DefaultTimeoutMS: 60 * 1000, // 1 minute
|
||||
PreFillZero: true,
|
||||
PostFillNil: true,
|
||||
StreamingParadigms: map[StreamingParadigm]bool{Invoke: true},
|
||||
},
|
||||
EnUSName: "Knowledge delete",
|
||||
EnUSDescription: "The delete node can delete a document in knowledge base.",
|
||||
},
|
||||
// --- End of nodes parsed from template_list ---
|
||||
}
|
||||
|
||||
// PluginNodeMetas holds metadata for specific plugin API entity.
|
||||
var PluginNodeMetas []*PluginNodeMeta
|
||||
|
||||
// PluginCategoryMetas holds metadata for plugin category entity.
|
||||
var PluginCategoryMetas []*PluginCategoryMeta
|
||||
|
||||
func NodeMetaByNodeType(t NodeType) *NodeTypeMeta {
|
||||
for _, meta := range NodeTypeMetas {
|
||||
if meta.Type == t {
|
||||
return meta
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const defaultZhCNInitCanvasJsonSchema = `{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"trigger_parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}`
|
||||
|
||||
const defaultEnUSInitCanvasJsonSchema = `{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
|
||||
"subTitle": "",
|
||||
"title": "Start"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"trigger_parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
|
||||
"subTitle": "",
|
||||
"title": "End"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}`
|
||||
|
||||
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
|
||||
return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema)
|
||||
}
|
||||
|
|
@ -19,24 +19,48 @@ package vo
|
|||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/i18n"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
)
|
||||
|
||||
// Canvas is the definition of FRONTEND schema for a workflow.
|
||||
type Canvas struct {
|
||||
Nodes []*Node `json:"nodes"`
|
||||
Edges []*Edge `json:"edges"`
|
||||
Versions any `json:"versions"`
|
||||
}
|
||||
|
||||
// Node represents a node within a workflow canvas.
|
||||
type Node struct {
|
||||
// ID is the unique node ID within the workflow.
|
||||
// In normal use cases, this ID is generated by frontend.
|
||||
// It does NOT need to be unique between parent workflow and sub workflows.
|
||||
// The Entry node and Exit node of a workflow always have fixed node IDs: 100001 and 900001.
|
||||
ID string `json:"id"`
|
||||
Type BlockType `json:"type"`
|
||||
|
||||
// Type is the Node Type of this node instance.
|
||||
// It corresponds to the string value of 'ID' field of NodeMeta.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Meta holds meta data used by frontend, such as the node's position within canvas.
|
||||
Meta any `json:"meta"`
|
||||
|
||||
// Data holds the actual configurations of a node, such as inputs, outputs and exception handling.
|
||||
// It also holds exclusive configurations for different node types, such as LLM configurations.
|
||||
Data *Data `json:"data"`
|
||||
|
||||
// Blocks holds the sub nodes of this node.
|
||||
// It is only used when the node type is Composite, such as NodeTypeBatch and NodeTypeLoop.
|
||||
Blocks []*Node `json:"blocks,omitempty"`
|
||||
|
||||
// Edges are the connections between nodes.
|
||||
// Strictly corresponds to connections drawn on canvas.
|
||||
Edges []*Edge `json:"edges,omitempty"`
|
||||
|
||||
// Version is the version of this node type's schema.
|
||||
Version string `json:"version,omitempty"`
|
||||
|
||||
parent *Node
|
||||
parent *Node // if this node is within a composite node, coze will set this. No need to set manually
|
||||
}
|
||||
|
||||
func (n *Node) SetParent(parent *Node) {
|
||||
|
|
@ -47,7 +71,7 @@ func (n *Node) Parent() *Node {
|
|||
return n.parent
|
||||
}
|
||||
|
||||
type NodeMeta struct {
|
||||
type NodeMetaFE struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
|
|
@ -62,52 +86,88 @@ type Edge struct {
|
|||
TargetPortID string `json:"targetPortID,omitempty"`
|
||||
}
|
||||
|
||||
// Data holds the actual configuration of a Node.
|
||||
type Data struct {
|
||||
Meta *NodeMeta `json:"nodeMeta,omitempty"`
|
||||
Outputs []any `json:"outputs,omitempty"` // either []*Variable or []*Param
|
||||
// Meta is the meta data of this node. Only used by frontend.
|
||||
Meta *NodeMetaFE `json:"nodeMeta,omitempty"`
|
||||
|
||||
// Outputs configures the output fields and their types.
|
||||
// Outputs can be either []*Variable (most of the cases, just fields and types),
|
||||
// or []*Param (used by composite nodes as they need to refer to outputs of sub nodes)
|
||||
Outputs []any `json:"outputs,omitempty"`
|
||||
|
||||
// Inputs configures ALL input information of a node,
|
||||
// including fixed input fields and dynamic input fields defined by user.
|
||||
Inputs *Inputs `json:"inputs,omitempty"`
|
||||
|
||||
// Size configures the size of this node in frontend.
|
||||
// Only used by NodeTypeComment.
|
||||
Size any `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
type Inputs struct {
|
||||
// InputParameters are the fields defined by user for this particular node.
|
||||
InputParameters []*Param `json:"inputParameters"`
|
||||
Content *BlockInput `json:"content"`
|
||||
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
|
||||
StreamingOutput bool `json:"streamingOutput,omitempty"`
|
||||
CallTransferVoice bool `json:"callTransferVoice,omitempty"`
|
||||
ChatHistoryWriting string `json:"chatHistoryWriting,omitempty"`
|
||||
LLMParam any `json:"llmParam,omitempty"` // The LLMParam type may be one of the LLMParam or IntentDetectorLLMParam type or QALLMParam type
|
||||
FCParam *FCParam `json:"fcParam,omitempty"`
|
||||
|
||||
// SettingOnError configures common error handling strategy for nodes.
|
||||
// NOTE: enable in frontend node's form first.
|
||||
SettingOnError *SettingOnError `json:"settingOnError,omitempty"`
|
||||
|
||||
// NodeBatchInfo configures batch mode for nodes.
|
||||
// NOTE: not to be confused with NodeTypeBatch.
|
||||
NodeBatchInfo *NodeBatch `json:"batch,omitempty"`
|
||||
|
||||
// LLMParam may be one of the LLMParam or IntentDetectorLLMParam or SimpleLLMParam.
|
||||
// Shared between most nodes requiring an ChatModel to function.
|
||||
LLMParam any `json:"llmParam,omitempty"`
|
||||
|
||||
*OutputEmitter // exclusive configurations for NodeTypeEmitter and NodeTypeExit in Answer mode
|
||||
*Exit // exclusive configurations for NodeTypeExit
|
||||
*LLM // exclusive configurations for NodeTypeLLM
|
||||
*Loop // exclusive configurations for NodeTypeLoop
|
||||
*Selector // exclusive configurations for NodeTypeSelector
|
||||
*TextProcessor // exclusive configurations for NodeTypeTextProcessor
|
||||
*SubWorkflow // exclusive configurations for NodeTypeSubWorkflow
|
||||
*IntentDetector // exclusive configurations for NodeTypeIntentDetector
|
||||
*DatabaseNode // exclusive configurations for various Database nodes
|
||||
*HttpRequestNode // exclusive configurations for NodeTypeHTTPRequester
|
||||
*Knowledge // exclusive configurations for various Knowledge nodes
|
||||
*CodeRunner // exclusive configurations for NodeTypeCodeRunner
|
||||
*PluginAPIParam // exclusive configurations for NodeTypePlugin
|
||||
*VariableAggregator // exclusive configurations for NodeTypeVariableAggregator
|
||||
*VariableAssigner // exclusive configurations for NodeTypeVariableAssigner
|
||||
*QA // exclusive configurations for NodeTypeQuestionAnswer
|
||||
*Batch // exclusive configurations for NodeTypeBatch
|
||||
*Comment // exclusive configurations for NodeTypeComment
|
||||
*InputReceiver // exclusive configurations for NodeTypeInputReceiver
|
||||
}
|
||||
|
||||
type OutputEmitter struct {
|
||||
Content *BlockInput `json:"content"`
|
||||
StreamingOutput bool `json:"streamingOutput,omitempty"`
|
||||
}
|
||||
|
||||
type Exit struct {
|
||||
TerminatePlan *TerminatePlan `json:"terminatePlan,omitempty"`
|
||||
}
|
||||
|
||||
type LLM struct {
|
||||
FCParam *FCParam `json:"fcParam,omitempty"`
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
LoopType LoopType `json:"loopType,omitempty"`
|
||||
LoopCount *BlockInput `json:"loopCount,omitempty"`
|
||||
VariableParameters []*Param `json:"variableParameters,omitempty"`
|
||||
}
|
||||
|
||||
type Selector struct {
|
||||
Branches []*struct {
|
||||
Condition struct {
|
||||
Logic LogicType `json:"logic"`
|
||||
Conditions []*Condition `json:"conditions"`
|
||||
} `json:"condition"`
|
||||
} `json:"branches,omitempty"`
|
||||
|
||||
NodeBatchInfo *NodeBatch `json:"batch,omitempty"` // node in batch mode
|
||||
|
||||
*TextProcessor
|
||||
*SubWorkflow
|
||||
*IntentDetector
|
||||
*DatabaseNode
|
||||
*HttpRequestNode
|
||||
*KnowledgeIndexer
|
||||
*CodeRunner
|
||||
*PluginAPIParam
|
||||
*VariableAggregator
|
||||
*VariableAssigner
|
||||
*QA
|
||||
*Batch
|
||||
*Comment
|
||||
|
||||
OutputSchema string `json:"outputSchema,omitempty"`
|
||||
}
|
||||
|
||||
type Comment struct {
|
||||
|
|
@ -127,7 +187,7 @@ type VariableAssigner struct {
|
|||
|
||||
type LLMParam = []*Param
|
||||
type IntentDetectorLLMParam = map[string]any
|
||||
type QALLMParam struct {
|
||||
type SimpleLLMParam struct {
|
||||
GenerationDiversity string `json:"generationDiversity"`
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
ModelName string `json:"modelName"`
|
||||
|
|
@ -248,7 +308,7 @@ type CodeRunner struct {
|
|||
Language int64 `json:"language"`
|
||||
}
|
||||
|
||||
type KnowledgeIndexer struct {
|
||||
type Knowledge struct {
|
||||
DatasetParam []*Param `json:"datasetParam,omitempty"`
|
||||
StrategyParam StrategyParam `json:"strategyParam,omitempty"`
|
||||
}
|
||||
|
|
@ -384,22 +444,53 @@ type ChatHistorySetting struct {
|
|||
type Intent struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Param is a node's field with type and source info.
|
||||
type Param struct {
|
||||
// Name is the field's name.
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
// Input is the configurations for normal, singular field.
|
||||
Input *BlockInput `json:"input,omitempty"`
|
||||
|
||||
// Left is the configurations for the left half of an expression,
|
||||
// such as an assignment in NodeTypeVariableAssigner.
|
||||
Left *BlockInput `json:"left,omitempty"`
|
||||
|
||||
// Right is the configuration for the right half of an expression.
|
||||
Right *BlockInput `json:"right,omitempty"`
|
||||
|
||||
// Variables are configurations for a group of fields.
|
||||
// Only used in NodeTypeVariableAggregator.
|
||||
Variables []*BlockInput `json:"variables,omitempty"`
|
||||
}
|
||||
|
||||
// Variable is the configuration of a node's field, either input or output.
|
||||
type Variable struct {
|
||||
// Name is the field's name as defined on canvas.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type is the field's data type, such as string, integer, number, object, array, etc.
|
||||
Type VariableType `json:"type"`
|
||||
|
||||
// Required is set to true if you checked the 'required box' on this field
|
||||
Required bool `json:"required,omitempty"`
|
||||
|
||||
// AssistType is the 'secondary' type of string fields, such as different types of file and image, or time.
|
||||
AssistType AssistType `json:"assistType,omitempty"`
|
||||
|
||||
// Schema contains detailed info for sub-fields of an object field, or element type of an array.
|
||||
Schema any `json:"schema,omitempty"` // either []*Variable (for object) or *Variable (for list)
|
||||
|
||||
// Description describes the field's intended use. Used on Entry node. Useful for workflow tools.
|
||||
Description string `json:"description,omitempty"`
|
||||
|
||||
// ReadOnly indicates a field is not to be set by Node's business logic.
|
||||
// e.g. the ErrorBody field when exception strategy is configured.
|
||||
ReadOnly bool `json:"readOnly,omitempty"`
|
||||
|
||||
// DefaultValue configures the 'default value' if this field is missing in input.
|
||||
// Effective only in Entry node.
|
||||
DefaultValue any `json:"defaultValue,omitempty"`
|
||||
}
|
||||
|
||||
|
|
@ -436,48 +527,6 @@ type SubWorkflow struct {
|
|||
SpaceID string `json:"spaceId,omitempty"`
|
||||
}
|
||||
|
||||
// BlockType is the enumeration of node types for front-end canvas schema.
|
||||
// To add a new BlockType, start from a really big number such as 1000, to avoid conflict with future extensions.
|
||||
type BlockType string
|
||||
|
||||
func (b BlockType) String() string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
const (
|
||||
BlockTypeBotStart BlockType = "1"
|
||||
BlockTypeBotEnd BlockType = "2"
|
||||
BlockTypeBotLLM BlockType = "3"
|
||||
BlockTypeBotAPI BlockType = "4"
|
||||
BlockTypeBotCode BlockType = "5"
|
||||
BlockTypeBotDataset BlockType = "6"
|
||||
BlockTypeCondition BlockType = "8"
|
||||
BlockTypeBotSubWorkflow BlockType = "9"
|
||||
BlockTypeDatabase BlockType = "12"
|
||||
BlockTypeBotMessage BlockType = "13"
|
||||
BlockTypeBotText BlockType = "15"
|
||||
BlockTypeQuestion BlockType = "18"
|
||||
BlockTypeBotBreak BlockType = "19"
|
||||
BlockTypeBotLoopSetVariable BlockType = "20"
|
||||
BlockTypeBotLoop BlockType = "21"
|
||||
BlockTypeBotIntent BlockType = "22"
|
||||
BlockTypeBotDatasetWrite BlockType = "27"
|
||||
BlockTypeBotInput BlockType = "30"
|
||||
BlockTypeBotBatch BlockType = "28"
|
||||
BlockTypeBotContinue BlockType = "29"
|
||||
BlockTypeBotComment BlockType = "31"
|
||||
BlockTypeBotVariableMerge BlockType = "32"
|
||||
BlockTypeBotAssignVariable BlockType = "40"
|
||||
BlockTypeDatabaseUpdate BlockType = "42"
|
||||
BlockTypeDatabaseSelect BlockType = "43"
|
||||
BlockTypeDatabaseDelete BlockType = "44"
|
||||
BlockTypeBotHttp BlockType = "45"
|
||||
BlockTypeDatabaseInsert BlockType = "46"
|
||||
BlockTypeJsonSerialization BlockType = "58"
|
||||
BlockTypeJsonDeserialization BlockType = "59"
|
||||
BlockTypeBotDatasetDelete BlockType = "60"
|
||||
)
|
||||
|
||||
type VariableType string
|
||||
|
||||
const (
|
||||
|
|
@ -536,19 +585,31 @@ const (
|
|||
type ErrorProcessType int
|
||||
|
||||
const (
|
||||
ErrorProcessTypeThrow ErrorProcessType = 1
|
||||
ErrorProcessTypeDefault ErrorProcessType = 2
|
||||
ErrorProcessTypeExceptionBranch ErrorProcessType = 3
|
||||
ErrorProcessTypeThrow ErrorProcessType = 1 // throws the error as usual
|
||||
ErrorProcessTypeReturnDefaultData ErrorProcessType = 2 // return DataOnErr configured in SettingOnError
|
||||
ErrorProcessTypeExceptionBranch ErrorProcessType = 3 // executes the exception branch on error
|
||||
)
|
||||
|
||||
// SettingOnError contains common error handling strategy.
|
||||
type SettingOnError struct {
|
||||
// DataOnErr defines the JSON result to be returned on error.
|
||||
DataOnErr string `json:"dataOnErr,omitempty"`
|
||||
// Switch defines whether ANY error handling strategy is active.
|
||||
// If set to false, it's equivalent to set ProcessType = ErrorProcessTypeThrow
|
||||
Switch bool `json:"switch,omitempty"`
|
||||
// ProcessType determines the error handling strategy for this node.
|
||||
ProcessType *ErrorProcessType `json:"processType,omitempty"`
|
||||
// RetryTimes determines how many times to retry. 0 means no retry.
|
||||
// If positive, any retries will be executed immediately after error.
|
||||
RetryTimes int64 `json:"retryTimes,omitempty"`
|
||||
// TimeoutMs sets the timeout duration in millisecond.
|
||||
// If any retry happens, ALL retry attempts accumulates to the same timeout threshold.
|
||||
TimeoutMs int64 `json:"timeoutMs,omitempty"`
|
||||
// Ext sets any extra settings specific to NodeType
|
||||
Ext *struct {
|
||||
BackupLLMParam string `json:"backupLLMParam,omitempty"` // only for LLM Node, marshaled from QALLMParam
|
||||
// BackupLLMParam is only for LLM Node, marshaled from SimpleLLMParam.
|
||||
// If retry happens, the backup LLM will be used instead of the main LLM.
|
||||
BackupLLMParam string `json:"backupLLMParam,omitempty"`
|
||||
} `json:"ext,omitempty"`
|
||||
}
|
||||
|
||||
|
|
@ -597,32 +658,8 @@ const (
|
|||
LoopTypeInfinite LoopType = "infinite"
|
||||
)
|
||||
|
||||
type WorkflowIdentity struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func (c *Canvas) GetAllSubWorkflowIdentities() []*WorkflowIdentity {
|
||||
workflowEntities := make([]*WorkflowIdentity, 0)
|
||||
|
||||
var collectSubWorkFlowEntities func(nodes []*Node)
|
||||
collectSubWorkFlowEntities = func(nodes []*Node) {
|
||||
for _, n := range nodes {
|
||||
if n.Type == BlockTypeBotSubWorkflow {
|
||||
workflowEntities = append(workflowEntities, &WorkflowIdentity{
|
||||
ID: n.Data.Inputs.WorkflowID,
|
||||
Version: n.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
}
|
||||
if len(n.Blocks) > 0 {
|
||||
collectSubWorkFlowEntities(n.Blocks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collectSubWorkFlowEntities(c.Nodes)
|
||||
|
||||
return workflowEntities
|
||||
type InputReceiver struct {
|
||||
OutputSchema string `json:"outputSchema,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateNodeIDForBatchMode(key string) string {
|
||||
|
|
@ -632,3 +669,163 @@ func GenerateNodeIDForBatchMode(key string) string {
|
|||
func IsGeneratedNodeForBatchMode(key string, parentKey string) bool {
|
||||
return key == GenerateNodeIDForBatchMode(parentKey)
|
||||
}
|
||||
|
||||
const defaultZhCNInitCanvasJsonSchema = `{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的起始节点,用于设定启动工作流需要的信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
|
||||
"subTitle": "",
|
||||
"title": "开始"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"trigger_parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "工作流的最终节点,用于返回工作流运行后的结果信息",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
|
||||
"subTitle": "",
|
||||
"title": "结束"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}`
|
||||
|
||||
const defaultEnUSInitCanvasJsonSchema = `{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "100001",
|
||||
"type": "1",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "The starting node of the workflow, used to set the information needed to initiate the workflow.",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-Start.png",
|
||||
"subTitle": "",
|
||||
"title": "Start"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
],
|
||||
"trigger_parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"name": "input",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "900001",
|
||||
"type": "2",
|
||||
"meta": {
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
"data": {
|
||||
"nodeMeta": {
|
||||
"description": "The final node of the workflow, used to return the result information after the workflow runs.",
|
||||
"icon": "https://lf3-static.bytednsdoc.com/obj/eden-cn/dvsmryvd_avi_dvsm/ljhwZthlaukjlkulzlp/icon/icon-End.png",
|
||||
"subTitle": "",
|
||||
"title": "End"
|
||||
},
|
||||
"inputs": {
|
||||
"terminatePlan": "returnVariables",
|
||||
"inputParameters": [
|
||||
{
|
||||
"name": "output",
|
||||
"input": {
|
||||
"type": "string",
|
||||
"value": {
|
||||
"type": "ref",
|
||||
"content": {
|
||||
"source": "block-output",
|
||||
"blockID": "",
|
||||
"name": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
"versions": {
|
||||
"loop": "v2"
|
||||
}
|
||||
}`
|
||||
|
||||
func GetDefaultInitCanvasJsonSchema(locale i18n.Locale) string {
|
||||
return ternary.IFElse(locale == i18n.LocaleEN, defaultEnUSInitCanvasJsonSchema, defaultZhCNInitCanvasJsonSchema)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,12 +47,6 @@ type FieldSource struct {
|
|||
Val any `json:"val,omitempty"`
|
||||
}
|
||||
|
||||
type ImplicitNodeDependency struct {
|
||||
NodeID string
|
||||
FieldPath compose.FieldPath
|
||||
TypeInfo *TypeInfo
|
||||
}
|
||||
|
||||
type TypeInfo struct {
|
||||
Type DataType `json:"type"`
|
||||
ElemTypeInfo *TypeInfo `json:"elem_type_info,omitempty"`
|
||||
|
|
|
|||
|
|
@ -69,13 +69,6 @@ type IDVersionPair struct {
|
|||
Version string
|
||||
}
|
||||
|
||||
type Stage uint8
|
||||
|
||||
const (
|
||||
StageDraft Stage = 1
|
||||
StagePublished Stage = 2
|
||||
)
|
||||
|
||||
type WorkflowBasic struct {
|
||||
ID int64
|
||||
Version string
|
||||
|
|
|
|||
|
|
@ -58,6 +58,11 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
RegisterAllNodeAdaptors()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestIntentDetectorAndDatabase(t *testing.T) {
|
||||
mockey.PatchConvey("intent detector & database custom sql", t, func() {
|
||||
data, err := os.ReadFile("../examples/intent_detector_database_custom_sql.json")
|
||||
|
|
|
|||
|
|
@ -26,10 +26,11 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
||||
*compose.WorkflowSchema, error) {
|
||||
*schema.WorkflowSchema, error) {
|
||||
var (
|
||||
n *vo.Node
|
||||
nodeFinder func(nodes []*vo.Node) *vo.Node
|
||||
|
|
@ -62,35 +63,27 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
|||
n = batchN
|
||||
}
|
||||
|
||||
implicitDependencies, err := extractImplicitDependency(n, c.Nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := make([]OptionFn, 0, 1)
|
||||
if len(implicitDependencies) > 0 {
|
||||
opts = append(opts, WithImplicitNodeDependencies(implicitDependencies))
|
||||
}
|
||||
nsList, hierarchy, err := NodeToNodeSchema(ctx, n, opts...)
|
||||
nsList, hierarchy, err := NodeToNodeSchema(ctx, n, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
ns *compose.NodeSchema
|
||||
innerNodes map[vo.NodeKey]*compose.NodeSchema // inner nodes of the composite node if nodeKey is composite
|
||||
connections []*compose.Connection
|
||||
ns *schema.NodeSchema
|
||||
innerNodes map[vo.NodeKey]*schema.NodeSchema // inner nodes of the composite node if nodeKey is composite
|
||||
connections []*schema.Connection
|
||||
)
|
||||
|
||||
if len(nsList) == 1 {
|
||||
ns = nsList[0]
|
||||
} else {
|
||||
innerNodes = make(map[vo.NodeKey]*compose.NodeSchema)
|
||||
innerNodes = make(map[vo.NodeKey]*schema.NodeSchema)
|
||||
for i := range nsList {
|
||||
one := nsList[i]
|
||||
if _, ok := hierarchy[one.Key]; ok {
|
||||
innerNodes[one.Key] = one
|
||||
if one.Type == entity.NodeTypeContinue || one.Type == entity.NodeTypeBreak {
|
||||
connections = append(connections, &compose.Connection{
|
||||
connections = append(connections, &schema.Connection{
|
||||
FromNode: one.Key,
|
||||
ToNode: vo.NodeKey(nodeID),
|
||||
})
|
||||
|
|
@ -106,13 +99,13 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
|||
}
|
||||
|
||||
const inputFillerKey = "input_filler"
|
||||
connections = append(connections, &compose.Connection{
|
||||
connections = append(connections, &schema.Connection{
|
||||
FromNode: einoCompose.START,
|
||||
ToNode: inputFillerKey,
|
||||
}, &compose.Connection{
|
||||
}, &schema.Connection{
|
||||
FromNode: inputFillerKey,
|
||||
ToNode: ns.Key,
|
||||
}, &compose.Connection{
|
||||
}, &schema.Connection{
|
||||
FromNode: ns.Key,
|
||||
ToNode: einoCompose.END,
|
||||
})
|
||||
|
|
@ -209,7 +202,7 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
|||
return newOutput, nil
|
||||
}
|
||||
|
||||
inputFiller := &compose.NodeSchema{
|
||||
inputFiller := &schema.NodeSchema{
|
||||
Key: inputFillerKey,
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: einoCompose.InvokableLambda(i),
|
||||
|
|
@ -227,10 +220,16 @@ func WorkflowSchemaFromNode(ctx context.Context, c *vo.Canvas, nodeID string) (
|
|||
OutputTypes: startOutputTypes,
|
||||
}
|
||||
|
||||
trimmedSC := &compose.WorkflowSchema{
|
||||
Nodes: append([]*compose.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...),
|
||||
branches, err := schema.BuildBranches(connections)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trimmedSC := &schema.WorkflowSchema{
|
||||
Nodes: append([]*schema.NodeSchema{ns, inputFiller}, maps.Values(innerNodes)...),
|
||||
Connections: connections,
|
||||
Hierarchy: hierarchy,
|
||||
Branches: branches,
|
||||
}
|
||||
|
||||
if enabled {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,663 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func CanvasVariableToTypeInfo(v *vo.Variable) (*vo.TypeInfo, error) {
|
||||
tInfo := &vo.TypeInfo{
|
||||
Required: v.Required,
|
||||
Desc: v.Description,
|
||||
}
|
||||
|
||||
switch v.Type {
|
||||
case vo.VariableTypeString:
|
||||
switch v.AssistType {
|
||||
case vo.AssistTypeTime:
|
||||
tInfo.Type = vo.DataTypeTime
|
||||
case vo.AssistTypeNotSet:
|
||||
tInfo.Type = vo.DataTypeString
|
||||
default:
|
||||
fileType, ok := assistTypeToFileType(v.AssistType)
|
||||
if ok {
|
||||
tInfo.Type = vo.DataTypeFile
|
||||
tInfo.FileType = &fileType
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeInteger:
|
||||
tInfo.Type = vo.DataTypeInteger
|
||||
case vo.VariableTypeFloat:
|
||||
tInfo.Type = vo.DataTypeNumber
|
||||
case vo.VariableTypeBoolean:
|
||||
tInfo.Type = vo.DataTypeBoolean
|
||||
case vo.VariableTypeObject:
|
||||
tInfo.Type = vo.DataTypeObject
|
||||
tInfo.Properties = make(map[string]*vo.TypeInfo)
|
||||
if v.Schema != nil {
|
||||
for _, subVAny := range v.Schema.([]any) {
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := CanvasVariableToTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[subV.Name] = subTInfo
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeList:
|
||||
tInfo.Type = vo.DataTypeArray
|
||||
subVAny := v.Schema
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := CanvasVariableToTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.ElemTypeInfo = subTInfo
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
|
||||
}
|
||||
|
||||
return tInfo, nil
|
||||
}
|
||||
|
||||
func CanvasBlockInputToTypeInfo(b *vo.BlockInput) (tInfo *vo.TypeInfo, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrSchemaConversionFail, err)
|
||||
}
|
||||
}()
|
||||
|
||||
tInfo = &vo.TypeInfo{}
|
||||
|
||||
if b == nil {
|
||||
return tInfo, nil
|
||||
}
|
||||
|
||||
switch b.Type {
|
||||
case vo.VariableTypeString:
|
||||
switch b.AssistType {
|
||||
case vo.AssistTypeTime:
|
||||
tInfo.Type = vo.DataTypeTime
|
||||
case vo.AssistTypeNotSet:
|
||||
tInfo.Type = vo.DataTypeString
|
||||
default:
|
||||
fileType, ok := assistTypeToFileType(b.AssistType)
|
||||
if ok {
|
||||
tInfo.Type = vo.DataTypeFile
|
||||
tInfo.FileType = &fileType
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeInteger:
|
||||
tInfo.Type = vo.DataTypeInteger
|
||||
case vo.VariableTypeFloat:
|
||||
tInfo.Type = vo.DataTypeNumber
|
||||
case vo.VariableTypeBoolean:
|
||||
tInfo.Type = vo.DataTypeBoolean
|
||||
case vo.VariableTypeObject:
|
||||
tInfo.Type = vo.DataTypeObject
|
||||
tInfo.Properties = make(map[string]*vo.TypeInfo)
|
||||
if b.Schema != nil {
|
||||
for _, subVAny := range b.Schema.([]any) {
|
||||
if b.Value.Type == vo.BlockInputValueTypeRef {
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := CanvasVariableToTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[subV.Name] = subTInfo
|
||||
} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
|
||||
subV, err := parseParam(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := CanvasBlockInputToTypeInfo(subV.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[subV.Name] = subTInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeList:
|
||||
tInfo.Type = vo.DataTypeArray
|
||||
subVAny := b.Schema
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := CanvasVariableToTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.ElemTypeInfo = subTInfo
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
|
||||
}
|
||||
|
||||
return tInfo, nil
|
||||
}
|
||||
|
||||
func CanvasBlockInputToFieldInfo(b *vo.BlockInput, path einoCompose.FieldPath, parentNode *vo.Node) (sources []*vo.FieldInfo, err error) {
|
||||
value := b.Value
|
||||
if value == nil {
|
||||
return nil, fmt.Errorf("input %v has no value, type= %s", path, b.Type)
|
||||
}
|
||||
|
||||
switch value.Type {
|
||||
case vo.BlockInputValueTypeObjectRef:
|
||||
sc := b.Schema
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("input %v has no schema, type= %s", path, b.Type)
|
||||
}
|
||||
|
||||
paramList, ok := sc.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("input %v schema not []any, type= %T", path, sc)
|
||||
}
|
||||
|
||||
for i := range paramList {
|
||||
paramAny := paramList[i]
|
||||
param, err := parseParam(paramAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
copied := make([]string, len(path))
|
||||
copy(copied, path)
|
||||
subFieldInfo, err := CanvasBlockInputToFieldInfo(param.Input, append(copied, param.Name), parentNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sources = append(sources, subFieldInfo...)
|
||||
}
|
||||
return sources, nil
|
||||
case vo.BlockInputValueTypeLiteral:
|
||||
content := value.Content
|
||||
if content == nil {
|
||||
return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
|
||||
}
|
||||
|
||||
switch b.Type {
|
||||
case vo.VariableTypeObject:
|
||||
m := make(map[string]any)
|
||||
if err = sonic.UnmarshalString(content.(string), &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = m
|
||||
case vo.VariableTypeList:
|
||||
l := make([]any, 0)
|
||||
if err = sonic.UnmarshalString(content.(string), &l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = l
|
||||
case vo.VariableTypeInteger:
|
||||
switch content.(type) {
|
||||
case string:
|
||||
content, err = strconv.ParseInt(content.(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case int64:
|
||||
content = content.(int64)
|
||||
case float64:
|
||||
content = int64(content.(float64))
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type fot integer: %s", b.Type)
|
||||
}
|
||||
case vo.VariableTypeFloat:
|
||||
switch content.(type) {
|
||||
case string:
|
||||
content, err = strconv.ParseFloat(content.(string), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case int64:
|
||||
content = float64(content.(int64))
|
||||
case float64:
|
||||
content = content.(float64)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type for float: %s", b.Type)
|
||||
}
|
||||
case vo.VariableTypeBoolean:
|
||||
switch content.(type) {
|
||||
case string:
|
||||
content, err = strconv.ParseBool(content.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case bool:
|
||||
content = content.(bool)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type for boolean: %s", b.Type)
|
||||
}
|
||||
default:
|
||||
}
|
||||
return []*vo.FieldInfo{
|
||||
{
|
||||
Path: path,
|
||||
Source: vo.FieldSource{
|
||||
Val: content,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
case vo.BlockInputValueTypeRef:
|
||||
content := value.Content
|
||||
if content == nil {
|
||||
return nil, fmt.Errorf("input %v is literal but has no value, type= %s", path, b.Type)
|
||||
}
|
||||
|
||||
ref, err := parseBlockInputRef(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fieldSource, err := CanvasBlockInputRefToFieldSource(ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parentNode != nil {
|
||||
if fieldSource.Ref != nil && len(fieldSource.Ref.FromNodeKey) > 0 && fieldSource.Ref.FromNodeKey == vo.NodeKey(parentNode.ID) {
|
||||
varRoot := fieldSource.Ref.FromPath[0]
|
||||
if parentNode.Data.Inputs.Loop != nil {
|
||||
for _, p := range parentNode.Data.Inputs.VariableParameters {
|
||||
if p.Name == varRoot {
|
||||
fieldSource.Ref.FromNodeKey = ""
|
||||
pi := vo.ParentIntermediate
|
||||
fieldSource.Ref.VariableType = &pi
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return []*vo.FieldInfo{
|
||||
{
|
||||
Path: path,
|
||||
Source: *fieldSource,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported value type: %s for blockInput type= %s", value.Type, b.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func parseBlockInputRef(content any) (*vo.BlockInputReference, error) {
|
||||
if bi, ok := content.(*vo.BlockInputReference); ok {
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
m, ok := content.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content type: %T when parse BlockInputRef", content)
|
||||
}
|
||||
|
||||
marshaled, err := sonic.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &vo.BlockInputReference{}
|
||||
if err := sonic.Unmarshal(marshaled, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func parseParam(v any) (*vo.Param, error) {
|
||||
if pa, ok := v.(*vo.Param); ok {
|
||||
return pa, nil
|
||||
}
|
||||
|
||||
m, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid content type: %T when parse Param", v)
|
||||
}
|
||||
|
||||
marshaled, err := sonic.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &vo.Param{}
|
||||
if err := sonic.Unmarshal(marshaled, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func CanvasBlockInputRefToFieldSource(r *vo.BlockInputReference) (*vo.FieldSource, error) {
|
||||
switch r.Source {
|
||||
case vo.RefSourceTypeBlockOutput:
|
||||
if len(r.BlockID) == 0 {
|
||||
return nil, fmt.Errorf("invalid BlockInputReference = %+v, BlockID is empty when source is block output", r)
|
||||
}
|
||||
|
||||
parts := strings.Split(r.Name, ".") // an empty r.Name signals an all-to-all mapping
|
||||
return &vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: vo.NodeKey(r.BlockID),
|
||||
FromPath: parts,
|
||||
},
|
||||
}, nil
|
||||
case vo.RefSourceTypeGlobalApp, vo.RefSourceTypeGlobalSystem, vo.RefSourceTypeGlobalUser:
|
||||
if len(r.Path) == 0 {
|
||||
return nil, fmt.Errorf("invalid BlockInputReference = %+v, Path is empty when source is variables", r)
|
||||
}
|
||||
|
||||
var varType vo.GlobalVarType
|
||||
switch r.Source {
|
||||
case vo.RefSourceTypeGlobalApp:
|
||||
varType = vo.GlobalAPP
|
||||
case vo.RefSourceTypeGlobalSystem:
|
||||
varType = vo.GlobalSystem
|
||||
case vo.RefSourceTypeGlobalUser:
|
||||
varType = vo.GlobalUser
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid BlockInputReference = %+v, Source is invalid", r)
|
||||
}
|
||||
|
||||
return &vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
VariableType: &varType,
|
||||
FromPath: r.Path,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported ref source type: %s", r.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func assistTypeToFileType(a vo.AssistType) (vo.FileSubType, bool) {
|
||||
switch a {
|
||||
case vo.AssistTypeNotSet:
|
||||
return "", false
|
||||
case vo.AssistTypeTime:
|
||||
return "", false
|
||||
case vo.AssistTypeImage:
|
||||
return vo.FileTypeImage, true
|
||||
case vo.AssistTypeAudio:
|
||||
return vo.FileTypeAudio, true
|
||||
case vo.AssistTypeVideo:
|
||||
return vo.FileTypeVideo, true
|
||||
case vo.AssistTypeDefault:
|
||||
return vo.FileTypeDefault, true
|
||||
case vo.AssistTypeDoc:
|
||||
return vo.FileTypeDocument, true
|
||||
case vo.AssistTypeExcel:
|
||||
return vo.FileTypeExcel, true
|
||||
case vo.AssistTypeCode:
|
||||
return vo.FileTypeCode, true
|
||||
case vo.AssistTypePPT:
|
||||
return vo.FileTypePPT, true
|
||||
case vo.AssistTypeTXT:
|
||||
return vo.FileTypeTxt, true
|
||||
case vo.AssistTypeSvg:
|
||||
return vo.FileTypeSVG, true
|
||||
case vo.AssistTypeVoice:
|
||||
return vo.FileTypeVoice, true
|
||||
case vo.AssistTypeZip:
|
||||
return vo.FileTypeZip, true
|
||||
default:
|
||||
panic("impossible")
|
||||
}
|
||||
}
|
||||
|
||||
func SetInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
|
||||
if n.Data.Inputs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputParams := n.Data.Inputs.InputParameters
|
||||
if len(inputParams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, param := range inputParams {
|
||||
name := param.Name
|
||||
tInfo, err := CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ns.SetInputType(name, tInfo)
|
||||
|
||||
sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetOutputTypesForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
|
||||
for _, vAny := range n.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tInfo, err := CanvasVariableToTypeInfo(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.ReadOnly {
|
||||
if v.Name == "errorBody" { // reserved output fields when exception happens
|
||||
continue
|
||||
}
|
||||
}
|
||||
ns.SetOutputType(v.Name, tInfo)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetOutputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) error {
|
||||
for _, vAny := range n.Data.Outputs {
|
||||
param, err := parseParam(vAny)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := param.Name
|
||||
tInfo, err := CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ns.SetOutputType(name, tInfo)
|
||||
|
||||
sources, err := CanvasBlockInputToFieldInfo(param.Input, einoCompose.FieldPath{name}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ns.AddOutputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func BlockInputToNamedTypeInfo(name string, b *vo.BlockInput) (*vo.NamedTypeInfo, error) {
|
||||
tInfo := &vo.NamedTypeInfo{
|
||||
Name: name,
|
||||
}
|
||||
if b == nil {
|
||||
return tInfo, nil
|
||||
}
|
||||
switch b.Type {
|
||||
case vo.VariableTypeString:
|
||||
switch b.AssistType {
|
||||
case vo.AssistTypeTime:
|
||||
tInfo.Type = vo.DataTypeTime
|
||||
case vo.AssistTypeNotSet:
|
||||
tInfo.Type = vo.DataTypeString
|
||||
default:
|
||||
fileType, ok := assistTypeToFileType(b.AssistType)
|
||||
if ok {
|
||||
tInfo.Type = vo.DataTypeFile
|
||||
tInfo.FileType = &fileType
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported assist type: %v", b.AssistType)
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeInteger:
|
||||
tInfo.Type = vo.DataTypeInteger
|
||||
case vo.VariableTypeFloat:
|
||||
tInfo.Type = vo.DataTypeNumber
|
||||
case vo.VariableTypeBoolean:
|
||||
tInfo.Type = vo.DataTypeBoolean
|
||||
case vo.VariableTypeObject:
|
||||
tInfo.Type = vo.DataTypeObject
|
||||
if b.Schema != nil {
|
||||
tInfo.Properties = make([]*vo.NamedTypeInfo, 0, len(b.Schema.([]any)))
|
||||
for _, subVAny := range b.Schema.([]any) {
|
||||
if b.Value.Type == vo.BlockInputValueTypeRef {
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subNInfo, err := VariableToNamedTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties = append(tInfo.Properties, subNInfo)
|
||||
} else if b.Value.Type == vo.BlockInputValueTypeObjectRef {
|
||||
subV, err := parseParam(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subNInfo, err := BlockInputToNamedTypeInfo(subV.Name, subV.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties = append(tInfo.Properties, subNInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeList:
|
||||
tInfo.Type = vo.DataTypeArray
|
||||
subVAny := b.Schema
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subNInfo, err := VariableToNamedTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.ElemTypeInfo = subNInfo
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type: %s", b.Type)
|
||||
}
|
||||
|
||||
return tInfo, nil
|
||||
}
|
||||
|
||||
func VariableToNamedTypeInfo(v *vo.Variable) (*vo.NamedTypeInfo, error) {
|
||||
nInfo := &vo.NamedTypeInfo{
|
||||
Required: v.Required,
|
||||
Name: v.Name,
|
||||
Desc: v.Description,
|
||||
}
|
||||
|
||||
switch v.Type {
|
||||
case vo.VariableTypeString:
|
||||
switch v.AssistType {
|
||||
case vo.AssistTypeTime:
|
||||
nInfo.Type = vo.DataTypeTime
|
||||
case vo.AssistTypeNotSet:
|
||||
nInfo.Type = vo.DataTypeString
|
||||
default:
|
||||
fileType, ok := assistTypeToFileType(v.AssistType)
|
||||
if ok {
|
||||
nInfo.Type = vo.DataTypeFile
|
||||
nInfo.FileType = &fileType
|
||||
} else {
|
||||
return nil, fmt.Errorf("unsupported assist type: %v", v.AssistType)
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeInteger:
|
||||
nInfo.Type = vo.DataTypeInteger
|
||||
case vo.VariableTypeFloat:
|
||||
nInfo.Type = vo.DataTypeNumber
|
||||
case vo.VariableTypeBoolean:
|
||||
nInfo.Type = vo.DataTypeBoolean
|
||||
case vo.VariableTypeObject:
|
||||
nInfo.Type = vo.DataTypeObject
|
||||
if v.Schema != nil {
|
||||
nInfo.Properties = make([]*vo.NamedTypeInfo, 0)
|
||||
for _, subVAny := range v.Schema.([]any) {
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := VariableToNamedTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nInfo.Properties = append(nInfo.Properties, subTInfo)
|
||||
|
||||
}
|
||||
}
|
||||
case vo.VariableTypeList:
|
||||
nInfo.Type = vo.DataTypeArray
|
||||
subVAny := v.Schema
|
||||
subV, err := vo.ParseVariable(subVAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subTInfo, err := VariableToNamedTypeInfo(subV)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nInfo.ElemTypeInfo = subTInfo
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported variable type: %s", v.Type)
|
||||
}
|
||||
|
||||
return nInfo, nil
|
||||
}
|
||||
|
|
@ -24,8 +24,11 @@ import (
|
|||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
|
@ -123,7 +126,7 @@ func (cv *CanvasValidator) ValidateConnections(ctx context.Context) (issues []*I
|
|||
return issues, nil
|
||||
}
|
||||
|
||||
func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issue, err error) {
|
||||
func (cv *CanvasValidator) CheckRefVariable(_ context.Context) (issues []*Issue, err error) {
|
||||
issues = make([]*Issue, 0)
|
||||
var checkRefVariable func(reachability *reachability, reachableNodes map[string]bool) error
|
||||
checkRefVariable = func(reachability *reachability, parentReachableNodes map[string]bool) error {
|
||||
|
|
@ -237,7 +240,7 @@ func (cv *CanvasValidator) CheckRefVariable(ctx context.Context) (issues []*Issu
|
|||
return issues, nil
|
||||
}
|
||||
|
||||
func (cv *CanvasValidator) ValidateNestedFlows(ctx context.Context) (issues []*Issue, err error) {
|
||||
func (cv *CanvasValidator) ValidateNestedFlows(_ context.Context) (issues []*Issue, err error) {
|
||||
issues = make([]*Issue, 0)
|
||||
for nodeID, node := range cv.reachability.reachableNodes {
|
||||
if nestedReachableNodes, ok := cv.reachability.nestedReachability[nodeID]; ok && len(nestedReachableNodes.nestedReachability) > 0 {
|
||||
|
|
@ -265,13 +268,13 @@ func (cv *CanvasValidator) CheckGlobalVariables(ctx context.Context) (issues []*
|
|||
|
||||
nVars := make([]*nodeVars, 0)
|
||||
for _, node := range cv.cfg.Canvas.Nodes {
|
||||
if node.Type == vo.BlockTypeBotComment {
|
||||
if node.Type == entity.NodeTypeComment.IDStr() {
|
||||
continue
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotAssignVariable {
|
||||
if node.Type == entity.NodeTypeVariableAssigner.IDStr() {
|
||||
v := &nodeVars{node: node, vars: make(map[string]*vo.TypeInfo)}
|
||||
for _, p := range node.Data.Inputs.InputParameters {
|
||||
v.vars[p.Name], err = adaptor.CanvasBlockInputToTypeInfo(p.Left)
|
||||
v.vars[p.Name], err = convert.CanvasBlockInputToTypeInfo(p.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -338,7 +341,7 @@ func (cv *CanvasValidator) CheckSubWorkFlowTerminatePlanType(ctx context.Context
|
|||
var collectSubWorkFlowNodes func(nodes []*vo.Node)
|
||||
collectSubWorkFlowNodes = func(nodes []*vo.Node) {
|
||||
for _, n := range nodes {
|
||||
if n.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
subWfMap = append(subWfMap, n)
|
||||
wID, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
|
|
@ -465,62 +468,28 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
|||
selectorPorts := make(map[string]map[string]bool)
|
||||
|
||||
for nodeID, node := range nodeMap {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeCondition:
|
||||
branches := node.Data.Inputs.Branches
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
selectorPorts[nodeID]["false"] = true
|
||||
for index := range branches {
|
||||
if index == 0 {
|
||||
selectorPorts[nodeID]["true"] = true
|
||||
} else {
|
||||
selectorPorts[nodeID][fmt.Sprintf("true_%v", index)] = true
|
||||
}
|
||||
}
|
||||
case vo.BlockTypeBotIntent:
|
||||
intents := node.Data.Inputs.Intents
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
for index := range intents {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
|
||||
}
|
||||
selectorPorts[nodeID]["default"] = true
|
||||
if node.Data.Inputs.SettingOnError != nil && node.Data.Inputs.SettingOnError.ProcessType != nil &&
|
||||
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
selectorPorts[nodeID]["branch_error"] = true
|
||||
}
|
||||
case vo.BlockTypeQuestion:
|
||||
if node.Data.Inputs.QA.AnswerType == vo.QAAnswerTypeOption {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
|
||||
for index := range node.Data.Inputs.QA.Options {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", index)] = true
|
||||
}
|
||||
}
|
||||
|
||||
if node.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
|
||||
selectorPorts[nodeID][fmt.Sprintf("branch_%v", 0)] = true
|
||||
}
|
||||
}
|
||||
default:
|
||||
if node.Data.Inputs != nil && node.Data.Inputs.SettingOnError != nil &&
|
||||
node.Data.Inputs.SettingOnError.ProcessType != nil &&
|
||||
*node.Data.Inputs.SettingOnError.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
selectorPorts[nodeID]["branch_error"] = true
|
||||
selectorPorts[nodeID]["default"] = true
|
||||
} else {
|
||||
outDegree[node.ID] = 0
|
||||
}
|
||||
selectorPorts[nodeID][schema.PortBranchError] = true
|
||||
selectorPorts[nodeID][schema.PortDefault] = true
|
||||
}
|
||||
|
||||
ba, ok := nodes.GetBranchAdaptor(entity.IDStrToNodeType(node.Type))
|
||||
if ok {
|
||||
expects := ba.ExpectPorts(ctx, node)
|
||||
if len(expects) > 0 {
|
||||
if _, exists := selectorPorts[nodeID]; !exists {
|
||||
selectorPorts[nodeID] = make(map[string]bool)
|
||||
}
|
||||
for _, e := range expects {
|
||||
selectorPorts[nodeID][e] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, edge := range c.Edges {
|
||||
|
|
@ -544,8 +513,8 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
|||
for nodeID, node := range nodeMap {
|
||||
nodeName := node.Data.Meta.Title
|
||||
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotStart:
|
||||
switch et := entity.IDStrToNodeType(node.Type); et {
|
||||
case entity.NodeTypeEntry:
|
||||
if outDegree[nodeID] == 0 {
|
||||
issues = append(issues, &Issue{
|
||||
NodeErr: &NodeErr{
|
||||
|
|
@ -555,13 +524,9 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
|||
Message: `node "start" not connected`,
|
||||
})
|
||||
}
|
||||
case vo.BlockTypeBotEnd:
|
||||
case entity.NodeTypeExit:
|
||||
default:
|
||||
if ports, isSelector := selectorPorts[nodeID]; isSelector {
|
||||
selectorIssues := &Issue{NodeErr: &NodeErr{
|
||||
NodeID: node.ID,
|
||||
NodeName: nodeName,
|
||||
}}
|
||||
message := ""
|
||||
for port := range ports {
|
||||
if portOutDegree[nodeID][port] == 0 {
|
||||
|
|
@ -569,12 +534,15 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
|||
}
|
||||
}
|
||||
if len(message) > 0 {
|
||||
selectorIssues.Message = message
|
||||
selectorIssues := &Issue{NodeErr: &NodeErr{
|
||||
NodeID: node.ID,
|
||||
NodeName: nodeName,
|
||||
}, Message: message}
|
||||
issues = append(issues, selectorIssues)
|
||||
}
|
||||
} else {
|
||||
// Break, continue without checking out degrees
|
||||
if node.Type == vo.BlockTypeBotBreak || node.Type == vo.BlockTypeBotContinue {
|
||||
if et == entity.NodeTypeBreak || et == entity.NodeTypeContinue {
|
||||
continue
|
||||
}
|
||||
if outDegree[nodeID] == 0 {
|
||||
|
|
@ -585,7 +553,6 @@ func validateConnections(ctx context.Context, c *vo.Canvas) (issues []*Issue, er
|
|||
},
|
||||
Message: fmt.Sprintf(`node "%v" not connected`, nodeName),
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -602,7 +569,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
startNode, endNode, err := findStartAndEndNodes(c.Nodes)
|
||||
startNode, _, err := findStartAndEndNodes(c.Nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -612,7 +579,7 @@ func analyzeCanvasReachability(c *vo.Canvas) (*reachability, error) {
|
|||
edgeMap[edge.SourceNodeID] = append(edgeMap[edge.SourceNodeID], edge.TargetNodeID)
|
||||
}
|
||||
|
||||
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode, endNode)
|
||||
reachable.reachableNodes, err = performReachabilityAnalysis(nodeMap, edgeMap, startNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -635,12 +602,12 @@ func processNestedReachability(c *vo.Canvas, r *reachability) error {
|
|||
Nodes: append([]*vo.Node{
|
||||
{
|
||||
ID: node.ID,
|
||||
Type: vo.BlockTypeBotStart,
|
||||
Type: entity.NodeTypeEntry.IDStr(),
|
||||
Data: node.Data,
|
||||
},
|
||||
{
|
||||
ID: node.ID,
|
||||
Type: vo.BlockTypeBotEnd,
|
||||
Type: entity.NodeTypeExit.IDStr(),
|
||||
},
|
||||
}, node.Blocks...),
|
||||
Edges: node.Edges,
|
||||
|
|
@ -663,9 +630,9 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
|
|||
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotStart:
|
||||
case entity.NodeTypeEntry.IDStr():
|
||||
startNode = node
|
||||
case vo.BlockTypeBotEnd:
|
||||
case entity.NodeTypeExit.IDStr():
|
||||
endNode = node
|
||||
}
|
||||
}
|
||||
|
|
@ -680,7 +647,7 @@ func findStartAndEndNodes(nodes []*vo.Node) (*vo.Node, *vo.Node, error) {
|
|||
return startNode, endNode, nil
|
||||
}
|
||||
|
||||
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node, endNode *vo.Node) (map[string]*vo.Node, error) {
|
||||
func performReachabilityAnalysis(nodeMap map[string]*vo.Node, edgeMap map[string][]string, startNode *vo.Node) (map[string]*vo.Node, error) {
|
||||
result := make(map[string]*vo.Node)
|
||||
result[startNode.ID] = startNode
|
||||
|
||||
|
|
|
|||
|
|
@ -1,181 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) OutputPortCount() (int, bool) {
|
||||
var hasExceptionPort bool
|
||||
if s.ExceptionConfigs != nil && s.ExceptionConfigs.ProcessType != nil &&
|
||||
*s.ExceptionConfigs.ProcessType == vo.ErrorProcessTypeExceptionBranch {
|
||||
hasExceptionPort = true
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
return len(mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)) + 1, hasExceptionPort
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", s.Configs.(map[string]any)) == qa.AnswerByChoices {
|
||||
if mustGetKey[qa.ChoiceType]("ChoiceType", s.Configs.(map[string]any)) == qa.FixedChoices {
|
||||
return len(mustGetKey[[]string]("FixedChoices", s.Configs.(map[string]any))) + 1, hasExceptionPort
|
||||
} else {
|
||||
return 2, hasExceptionPort
|
||||
}
|
||||
}
|
||||
return 1, hasExceptionPort
|
||||
case entity.NodeTypeIntentDetector:
|
||||
intents := mustGetKey[[]string]("Intents", s.Configs.(map[string]any))
|
||||
return len(intents) + 1, hasExceptionPort
|
||||
default:
|
||||
return 1, hasExceptionPort
|
||||
}
|
||||
}
|
||||
|
||||
type BranchMapping struct {
|
||||
Normal []map[string]bool
|
||||
Exception map[string]bool
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultBranch = "default"
|
||||
BranchFmt = "branch_%d"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) GetBranch(bMapping *BranchMapping) (*compose.GraphBranch, error) {
|
||||
if bMapping == nil {
|
||||
return nil, errors.New("no branch mapping")
|
||||
}
|
||||
|
||||
endNodes := make(map[string]bool)
|
||||
for i := range bMapping.Normal {
|
||||
for k := range bMapping.Normal[i] {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
if bMapping.Exception != nil {
|
||||
for k := range bMapping.Exception {
|
||||
endNodes[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeSelector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
choice := in[selector.SelectKey].(int)
|
||||
if choice < 0 || choice > len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("node %s choice out of range: %d", s.Key, choice)
|
||||
}
|
||||
|
||||
choices := make(map[string]bool, len((bMapping.Normal)[choice]))
|
||||
for k := range (bMapping.Normal)[choice] {
|
||||
choices[k] = true
|
||||
}
|
||||
|
||||
return choices, nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf := s.Configs.(map[string]any)
|
||||
if mustGetKey[qa.AnswerType]("AnswerType", conf) == qa.AnswerByChoices {
|
||||
choiceType := mustGetKey[qa.ChoiceType]("ChoiceType", conf)
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
optionID, ok := nodes.TakeMapValue(in, compose.FieldPath{qa.OptionIDKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
if optionID.(string) == "other" {
|
||||
return (bMapping.Normal)[len(bMapping.Normal)-1], nil
|
||||
}
|
||||
|
||||
if choiceType == qa.DynamicChoices { // all dynamic choices maps to branch 0
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
|
||||
optionIDInt, ok := qa.AlphabetToInt(optionID.(string))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert option id from input map: %v", optionID)
|
||||
}
|
||||
|
||||
if optionIDInt < 0 || optionIDInt >= len(bMapping.Normal) {
|
||||
return nil, fmt.Errorf("failed to take option id from input map: %v", in)
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[optionIDInt], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
return nil, fmt.Errorf("this qa node should not have branches: %s", s.Key)
|
||||
|
||||
case entity.NodeTypeIntentDetector:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
classificationId, ok := nodes.TakeMapValue(in, compose.FieldPath{"classificationId"})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take classification id from input map: %v", in)
|
||||
}
|
||||
|
||||
// Intent detector the node default branch uses classificationId=0. But currently scene, the implementation uses default as the last element of the array.
|
||||
// Therefore, when classificationId=0, it needs to be converted into the node corresponding to the last index of the array.
|
||||
// Other options also need to reduce the index by 1.
|
||||
id, err := cast.ToInt64E(classificationId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realID := id - 1
|
||||
|
||||
if realID >= int64(len(bMapping.Normal)) {
|
||||
return nil, fmt.Errorf("invalid classification id from input, classification id: %v", classificationId)
|
||||
}
|
||||
|
||||
if realID < 0 {
|
||||
realID = int64(len(bMapping.Normal)) - 1
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[realID], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
default:
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bMapping.Exception, nil
|
||||
}
|
||||
|
||||
return (bMapping.Normal)[0], nil
|
||||
}
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,194 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
)
|
||||
|
||||
type selectorCallbackField struct {
|
||||
Key string `json:"key"`
|
||||
Type vo.DataType `json:"type"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type selectorCondition struct {
|
||||
Left selectorCallbackField `json:"left"`
|
||||
Operator vo.OperatorType `json:"operator"`
|
||||
Right *selectorCallbackField `json:"right,omitempty"`
|
||||
}
|
||||
|
||||
type selectorBranch struct {
|
||||
Conditions []*selectorCondition `json:"conditions"`
|
||||
Logic vo.LogicType `json:"logic"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *NodeSchema) toSelectorCallbackInput(sc *WorkflowSchema) func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
return func(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
config := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
count := len(config)
|
||||
|
||||
output := make([]*selectorBranch, count)
|
||||
|
||||
for _, source := range s.InputSources {
|
||||
targetPath := source.Path
|
||||
if len(targetPath) == 2 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: []*selectorCondition{
|
||||
{
|
||||
Operator: config[index].Single.ToCanvasOperatorType(),
|
||||
},
|
||||
},
|
||||
Logic: selector.ClauseRelationAND.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[1] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[1] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[0].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
} else if len(targetPath) == 3 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multi := config[index].Multi
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: make([]*selectorCondition, len(multi.Clauses)),
|
||||
Logic: multi.Relation.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
clauseIndexStr := targetPath[1]
|
||||
clauseIndex, err := strconv.Atoi(clauseIndexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clause := multi.Clauses[clauseIndex]
|
||||
|
||||
if output[index].Conditions[clauseIndex] == nil {
|
||||
output[index].Conditions[clauseIndex] = &selectorCondition{
|
||||
Operator: clause.ToCanvasOperatorType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[2] == selector.LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := sc.Hierarchy[s.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.Key)
|
||||
}
|
||||
parentNode := sc.GetNode(parentNodeKey)
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: sc.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[2] == selector.RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
|
||||
Type: s.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"branches": output}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -31,7 +31,9 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
|
|
@ -53,7 +55,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
|||
rootHandler := execute.NewRootWorkflowHandler(
|
||||
wb,
|
||||
executeID,
|
||||
workflowSC.requireCheckPoint,
|
||||
workflowSC.RequireCheckpoint(),
|
||||
eventChan,
|
||||
resumedEvent,
|
||||
exeCfg,
|
||||
|
|
@ -67,7 +69,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
|||
var nodeOpt einoCompose.Option
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)))
|
||||
ptr.Of(ns.Configs.(*exit.Config).TerminatePlan))
|
||||
} else if ns.Type != entity.NodeTypeLambda {
|
||||
nodeOpt = nodeCallbackOption(key, ns.Name, eventChan, resumedEvent, nil)
|
||||
}
|
||||
|
|
@ -117,7 +119,7 @@ func (r *WorkflowRunner) designateOptions(ctx context.Context) (context.Context,
|
|||
}
|
||||
}
|
||||
|
||||
if workflowSC.requireCheckPoint {
|
||||
if workflowSC.RequireCheckpoint() {
|
||||
opts = append(opts, einoCompose.WithCheckPointID(strconv.FormatInt(executeID, 10)))
|
||||
}
|
||||
|
||||
|
|
@ -139,7 +141,7 @@ func WrapOptWithIndex(opt einoCompose.Option, parentNodeKey vo.NodeKey, index in
|
|||
|
||||
func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
||||
parentHandler *execute.WorkflowHandler,
|
||||
ns *NodeSchema,
|
||||
ns *schema2.NodeSchema,
|
||||
pathPrefix ...string) (opts []einoCompose.Option, err error) {
|
||||
var (
|
||||
resumeEvent = r.interruptEvent
|
||||
|
|
@ -163,7 +165,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
|||
var nodeOpt einoCompose.Option
|
||||
if subNS.Type == entity.NodeTypeExit {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent,
|
||||
ptr.Of(mustGetKey[vo.TerminatePlan]("TerminalPlan", subNS.Configs)))
|
||||
ptr.Of(subNS.Configs.(*exit.Config).TerminatePlan))
|
||||
} else {
|
||||
nodeOpt = nodeCallbackOption(key, subNS.Name, eventChan, resumeEvent, nil)
|
||||
}
|
||||
|
|
@ -219,7 +221,7 @@ func (r *WorkflowRunner) designateOptionsForSubWorkflow(ctx context.Context,
|
|||
return opts, nil
|
||||
}
|
||||
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan *execute.Event,
|
||||
func llmToolCallbackOptions(ctx context.Context, ns *schema2.NodeSchema, eventChan chan *execute.Event,
|
||||
sw *schema.StreamWriter[*entity.Message]) (
|
||||
opts []einoCompose.Option, err error) {
|
||||
// this is a LLM node.
|
||||
|
|
@ -229,7 +231,8 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
|||
panic("impossible: llmToolCallbackOptions is called on a non-LLM node")
|
||||
}
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", ns.Configs)
|
||||
cfg := ns.Configs.(*llm.Config)
|
||||
fcParams := cfg.FCParam
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
// TODO: try to avoid getting the workflow tool all over again
|
||||
|
|
@ -272,7 +275,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
|||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
|
|
@ -310,7 +313,7 @@ func llmToolCallbackOptions(ctx context.Context, ns *NodeSchema, eventChan chan
|
|||
|
||||
toolHandler := execute.NewToolHandler(eventChan, funcInfo)
|
||||
opt := einoCompose.WithCallbacks(toolHandler)
|
||||
opt = einoCompose.WithLambdaOption(llm.WithNestedWorkflowOptions(nodes.WithOptsForNested(opt))).DesignateNode(string(ns.Key))
|
||||
opt = einoCompose.WithLambdaOption(nodes.WithOptsForNested(opt)).DesignateNode(string(ns.Key))
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,12 +25,13 @@ import (
|
|||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// outputValueFiller will fill the output value with nil if the key is not present in the output map.
|
||||
// if a node emits stream as output, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
func outputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
if len(s.OutputTypes) == 0 {
|
||||
return func(ctx context.Context, output map[string]any) (map[string]any, error) {
|
||||
return output, nil
|
||||
|
|
@ -55,7 +56,7 @@ func (s *NodeSchema) outputValueFiller() func(ctx context.Context, output map[st
|
|||
|
||||
// inputValueFiller will fill the input value with default value(zero or nil) if the input value is not present in map.
|
||||
// if a node accepts stream as input, the node needs to handle these absent keys in stream themselves.
|
||||
func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
func inputValueFiller(s *schema2.NodeSchema) func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
if len(s.InputTypes) == 0 {
|
||||
return func(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
return input, nil
|
||||
|
|
@ -78,7 +79,7 @@ func (s *NodeSchema) inputValueFiller() func(ctx context.Context, input map[stri
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamInputValueFiller() func(ctx context.Context,
|
||||
func streamInputValueFiller(s *schema2.NodeSchema) func(ctx context.Context,
|
||||
input *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]any] {
|
||||
fn := func(ctx context.Context, i map[string]any) (map[string]any, error) {
|
||||
newI := make(map[string]any)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNodeSchema_OutputValueFiller(t *testing.T) {
|
||||
|
|
@ -282,11 +283,11 @@ func TestNodeSchema_OutputValueFiller(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &NodeSchema{
|
||||
s := &schema.NodeSchema{
|
||||
OutputTypes: tt.fields.Outputs,
|
||||
}
|
||||
|
||||
got, err := s.outputValueFiller()(context.Background(), tt.fields.In)
|
||||
got, err := outputValueFiller(s)(context.Background(), tt.fields.In)
|
||||
|
||||
if len(tt.wantErr) > 0 {
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
// New instantiates the actual node type from NodeSchema.
|
||||
func New(ctx context.Context, s *schema.NodeSchema,
|
||||
inner compose.Runnable[map[string]any, map[string]any], // inner workflow for composite node
|
||||
sc *schema.WorkflowSchema, // the workflow this NodeSchema is in
|
||||
deps *dependencyInfo, // the dependency for this node pre-calculated by workflow engine
|
||||
) (_ *Node, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
var fullSources map[string]*schema.SourceInfo
|
||||
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
|
||||
if fullSources, err = GetFullSources(s, sc, deps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.FullSources = fullSources
|
||||
}
|
||||
|
||||
// if NodeSchema's Configs implements NodeBuilder, will use it to build the node
|
||||
nb, ok := s.Configs.(schema.NodeBuilder)
|
||||
if ok {
|
||||
opts := []schema.BuildOption{
|
||||
schema.WithWorkflowSchema(sc),
|
||||
schema.WithInnerWorkflow(inner),
|
||||
}
|
||||
|
||||
// build the actual InvokableNode, etc.
|
||||
n, err := nb.Build(ctx, s, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// wrap InvokableNode, etc. within NodeRunner, converting to eino's Lambda
|
||||
return toNode(s, n), nil
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeLambda:
|
||||
if s.Lambda == nil {
|
||||
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
|
||||
}
|
||||
|
||||
return &Node{Lambda: s.Lambda}, nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
subWorkflow, err := buildSubWorkflow(ctx, s, sc.RequireCheckpoint())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toNode(s, subWorkflow), nil
|
||||
default:
|
||||
panic(fmt.Sprintf("node schema's Configs does not implement NodeBuilder. type: %v", s.Type))
|
||||
}
|
||||
}
|
||||
|
||||
func buildSubWorkflow(ctx context.Context, s *schema.NodeSchema, requireCheckpoint bool) (*subworkflow.SubWorkflow, error) {
|
||||
var opts []WorkflowOption
|
||||
opts = append(opts, WithIDAsName(s.Configs.(*subworkflow.Config).WorkflowID))
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, WithParentRequireCheckpoint())
|
||||
}
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subworkflow.SubWorkflow{
|
||||
Runner: wf.Runner,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -33,6 +33,8 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
|
|
@ -48,7 +50,6 @@ type nodeRunConfig[O any] struct {
|
|||
maxRetry int64
|
||||
errProcessType vo.ErrorProcessType
|
||||
dataOnErr func(ctx context.Context) map[string]any
|
||||
callbackEnabled bool
|
||||
preProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
streamPreProcessors []func(ctx context.Context,
|
||||
|
|
@ -58,12 +59,14 @@ type nodeRunConfig[O any] struct {
|
|||
init []func(context.Context) (context.Context, error)
|
||||
i compose.Invoke[map[string]any, map[string]any, O]
|
||||
s compose.Stream[map[string]any, map[string]any, O]
|
||||
c compose.Collect[map[string]any, map[string]any, O]
|
||||
t compose.Transform[map[string]any, map[string]any, O]
|
||||
}
|
||||
|
||||
func newNodeRunConfig[O any](ns *NodeSchema,
|
||||
func newNodeRunConfig[O any](ns *schema2.NodeSchema,
|
||||
i compose.Invoke[map[string]any, map[string]any, O],
|
||||
s compose.Stream[map[string]any, map[string]any, O],
|
||||
c compose.Collect[map[string]any, map[string]any, O],
|
||||
t compose.Transform[map[string]any, map[string]any, O],
|
||||
opts *newNodeOptions) *nodeRunConfig[O] {
|
||||
meta := entity.NodeMetaByNodeType(ns.Type)
|
||||
|
|
@ -92,12 +95,12 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
|||
keyFinishedMarkerTrimmer(),
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
preProcessors = append(preProcessors, ns.inputValueFiller())
|
||||
preProcessors = append(preProcessors, inputValueFiller(ns))
|
||||
}
|
||||
|
||||
var postProcessors []func(ctx context.Context, input map[string]any) (map[string]any, error)
|
||||
if meta.PostFillNil {
|
||||
postProcessors = append(postProcessors, ns.outputValueFiller())
|
||||
postProcessors = append(postProcessors, outputValueFiller(ns))
|
||||
}
|
||||
|
||||
streamPreProcessors := []func(ctx context.Context,
|
||||
|
|
@ -110,7 +113,15 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
|||
},
|
||||
}
|
||||
if meta.PreFillZero {
|
||||
streamPreProcessors = append(streamPreProcessors, ns.streamInputValueFiller())
|
||||
streamPreProcessors = append(streamPreProcessors, streamInputValueFiller(ns))
|
||||
}
|
||||
|
||||
if meta.UseCtxCache {
|
||||
opts.init = append([]func(ctx context.Context) (context.Context, error){
|
||||
func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
},
|
||||
}, opts.init...)
|
||||
}
|
||||
|
||||
opts.init = append(opts.init, func(ctx context.Context) (context.Context, error) {
|
||||
|
|
@ -129,7 +140,6 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
|||
maxRetry: maxRetry,
|
||||
errProcessType: errProcessType,
|
||||
dataOnErr: dataOnErr,
|
||||
callbackEnabled: meta.CallbackEnabled,
|
||||
preProcessors: preProcessors,
|
||||
postProcessors: postProcessors,
|
||||
streamPreProcessors: streamPreProcessors,
|
||||
|
|
@ -138,18 +148,21 @@ func newNodeRunConfig[O any](ns *NodeSchema,
|
|||
init: opts.init,
|
||||
i: i,
|
||||
s: s,
|
||||
c: c,
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func newNodeRunConfigWOOpt(ns *NodeSchema,
|
||||
func newNodeRunConfigWOOpt(ns *schema2.NodeSchema,
|
||||
i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
s compose.StreamWOOpt[map[string]any, map[string]any],
|
||||
c compose.CollectWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any],
|
||||
opts *newNodeOptions) *nodeRunConfig[any] {
|
||||
var (
|
||||
iWO compose.Invoke[map[string]any, map[string]any, any]
|
||||
sWO compose.Stream[map[string]any, map[string]any, any]
|
||||
cWO compose.Collect[map[string]any, map[string]any, any]
|
||||
tWO compose.Transform[map[string]any, map[string]any, any]
|
||||
)
|
||||
|
||||
|
|
@ -165,13 +178,19 @@ func newNodeRunConfigWOOpt(ns *NodeSchema,
|
|||
}
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cWO = func(ctx context.Context, in *schema.StreamReader[map[string]any], _ ...any) (out map[string]any, err error) {
|
||||
return c(ctx, in)
|
||||
}
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
tWO = func(ctx context.Context, input *schema.StreamReader[map[string]any], opts ...any) (output *schema.StreamReader[map[string]any], err error) {
|
||||
return t(ctx, input)
|
||||
}
|
||||
}
|
||||
|
||||
return newNodeRunConfig[any](ns, iWO, sWO, tWO, opts)
|
||||
return newNodeRunConfig[any](ns, iWO, sWO, cWO, tWO, opts)
|
||||
}
|
||||
|
||||
type newNodeOptions struct {
|
||||
|
|
@ -180,57 +199,100 @@ type newNodeOptions struct {
|
|||
init []func(context.Context) (context.Context, error)
|
||||
}
|
||||
|
||||
type newNodeOption func(*newNodeOptions)
|
||||
func toNode(ns *schema2.NodeSchema, r any) *Node {
|
||||
iWOpt, _ := r.(nodes.InvokableNodeWOpt)
|
||||
sWOpt, _ := r.(nodes.StreamableNodeWOpt)
|
||||
cWOpt, _ := r.(nodes.CollectableNodeWOpt)
|
||||
tWOpt, _ := r.(nodes.TransformableNodeWOpt)
|
||||
iWOOpt, _ := r.(nodes.InvokableNode)
|
||||
sWOOpt, _ := r.(nodes.StreamableNode)
|
||||
cWOOpt, _ := r.(nodes.CollectableNode)
|
||||
tWOOpt, _ := r.(nodes.TransformableNode)
|
||||
|
||||
func withCallbackInputConverter(f func(context.Context, map[string]any) (map[string]any, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackInputConverter = f
|
||||
var wOpt, wOOpt bool
|
||||
if iWOpt != nil || sWOpt != nil || cWOpt != nil || tWOpt != nil {
|
||||
wOpt = true
|
||||
}
|
||||
}
|
||||
func withCallbackOutputConverter(f func(context.Context, map[string]any) (*nodes.StructuredCallbackOutput, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.callbackOutputConverter = f
|
||||
if iWOOpt != nil || sWOOpt != nil || cWOOpt != nil || tWOOpt != nil {
|
||||
wOOpt = true
|
||||
}
|
||||
|
||||
if wOpt && wOOpt {
|
||||
panic("a node's different streaming methods needs to be consistent: " +
|
||||
"they should ALL have NodeOption or None should have them")
|
||||
}
|
||||
|
||||
if !wOpt && !wOOpt {
|
||||
panic("a node should implement at least one interface among: InvokableNodeWOpt, StreamableNodeWOpt, CollectableNodeWOpt, TransformableNodeWOpt, InvokableNode, StreamableNode, CollectableNode, TransformableNode")
|
||||
}
|
||||
}
|
||||
func withInit(f func(context.Context) (context.Context, error)) newNodeOption {
|
||||
return func(opts *newNodeOptions) {
|
||||
opts.init = append(opts.init, f)
|
||||
}
|
||||
}
|
||||
|
||||
func invokableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
ci, ok := r.(nodes.CallbackInputConverted)
|
||||
if ok {
|
||||
options.callbackInputConverter = ci.ToCallbackInput
|
||||
}
|
||||
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
co, ok := r.(nodes.CallbackOutputConverted)
|
||||
if ok {
|
||||
options.callbackOutputConverter = co.ToCallbackOutput
|
||||
}
|
||||
|
||||
return newNodeRunConfig(ns, i, nil, nil, options).toNode()
|
||||
}
|
||||
|
||||
func invokableTransformableNode(ns *NodeSchema, i compose.InvokeWOOpt[map[string]any, map[string]any],
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
init, ok := r.(nodes.Initializer)
|
||||
if ok {
|
||||
options.init = append(options.init, init.Init)
|
||||
}
|
||||
return newNodeRunConfigWOOpt(ns, i, nil, t, options).toNode()
|
||||
}
|
||||
|
||||
func invokableStreamableNodeWO[O any](ns *NodeSchema, i compose.Invoke[map[string]any, map[string]any, O], s compose.Stream[map[string]any, map[string]any, O], opts ...newNodeOption) *Node {
|
||||
options := &newNodeOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
if wOpt {
|
||||
var (
|
||||
i compose.Invoke[map[string]any, map[string]any, nodes.NodeOption]
|
||||
s compose.Stream[map[string]any, map[string]any, nodes.NodeOption]
|
||||
c compose.Collect[map[string]any, map[string]any, nodes.NodeOption]
|
||||
t compose.Transform[map[string]any, map[string]any, nodes.NodeOption]
|
||||
)
|
||||
|
||||
if iWOpt != nil {
|
||||
i = iWOpt.Invoke
|
||||
}
|
||||
return newNodeRunConfig(ns, i, s, nil, options).toNode()
|
||||
|
||||
if sWOpt != nil {
|
||||
s = sWOpt.Stream
|
||||
}
|
||||
|
||||
if cWOpt != nil {
|
||||
c = cWOpt.Collect
|
||||
}
|
||||
|
||||
if tWOpt != nil {
|
||||
t = tWOpt.Transform
|
||||
}
|
||||
|
||||
return newNodeRunConfig(ns, i, s, c, t, options).toNode()
|
||||
}
|
||||
|
||||
var (
|
||||
i compose.InvokeWOOpt[map[string]any, map[string]any]
|
||||
s compose.StreamWOOpt[map[string]any, map[string]any]
|
||||
c compose.CollectWOOpt[map[string]any, map[string]any]
|
||||
t compose.TransformWOOpts[map[string]any, map[string]any]
|
||||
)
|
||||
|
||||
if iWOOpt != nil {
|
||||
i = iWOOpt.Invoke
|
||||
}
|
||||
|
||||
if sWOOpt != nil {
|
||||
s = sWOOpt.Stream
|
||||
}
|
||||
|
||||
if cWOOpt != nil {
|
||||
c = cWOOpt.Collect
|
||||
}
|
||||
|
||||
if tWOOpt != nil {
|
||||
t = tWOOpt.Transform
|
||||
}
|
||||
|
||||
return newNodeRunConfigWOOpt(ns, i, s, c, t, options).toNode()
|
||||
}
|
||||
|
||||
func (nc *nodeRunConfig[O]) invoke() func(ctx context.Context, input map[string]any, opts ...O) (output map[string]any, err error) {
|
||||
|
|
@ -375,10 +437,8 @@ func (nc *nodeRunConfig[O]) transform() func(ctx context.Context, input *schema.
|
|||
func (nc *nodeRunConfig[O]) toNode() *Node {
|
||||
var opts []compose.LambdaOpt
|
||||
opts = append(opts, compose.WithLambdaType(string(nc.nodeType)))
|
||||
|
||||
if nc.callbackEnabled {
|
||||
opts = append(opts, compose.WithLambdaCallbackEnable(true))
|
||||
}
|
||||
|
||||
l, err := compose.AnyLambda(nc.invoke(), nc.stream(), nil, nc.transform(), opts...)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create lambda for node %s, err: %v", nc.nodeName, err))
|
||||
|
|
@ -406,9 +466,6 @@ func newNodeRunner[O any](ctx context.Context, cfg *nodeRunConfig[O]) (context.C
|
|||
}
|
||||
|
||||
func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (context.Context, error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, nil
|
||||
}
|
||||
if r.callbackInputConverter != nil {
|
||||
convertedInput, err := r.callbackInputConverter(ctx, input)
|
||||
if err != nil {
|
||||
|
|
@ -425,10 +482,6 @@ func (r *nodeRunner[O]) onStart(ctx context.Context, input map[string]any) (cont
|
|||
|
||||
func (r *nodeRunner[O]) onStartStream(ctx context.Context, input *schema.StreamReader[map[string]any]) (
|
||||
context.Context, *schema.StreamReader[map[string]any], error) {
|
||||
if !r.callbackEnabled {
|
||||
return ctx, input, nil
|
||||
}
|
||||
|
||||
if r.callbackInputConverter != nil {
|
||||
copied := input.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (map[string]any, error) {
|
||||
|
|
@ -580,14 +633,10 @@ func (r *nodeRunner[O]) transform(ctx context.Context, input *schema.StreamReade
|
|||
}
|
||||
|
||||
func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
|
||||
output["isSuccess"] = true
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
convertedOutput, err := r.callbackOutputConverter(ctx, output)
|
||||
if err != nil {
|
||||
|
|
@ -603,15 +652,11 @@ func (r *nodeRunner[O]) onEnd(ctx context.Context, output map[string]any) error
|
|||
|
||||
func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamReader[map[string]any]) (
|
||||
*schema.StreamReader[map[string]any], error) {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeDefault {
|
||||
if r.errProcessType == vo.ErrorProcessTypeExceptionBranch || r.errProcessType == vo.ErrorProcessTypeReturnDefaultData {
|
||||
flag := schema.StreamReaderFromArray([]map[string]any{{"isSuccess": true}})
|
||||
output = schema.MergeStreamReaders([]*schema.StreamReader[map[string]any]{flag, output})
|
||||
}
|
||||
|
||||
if !r.callbackEnabled {
|
||||
return output, nil
|
||||
}
|
||||
|
||||
if r.callbackOutputConverter != nil {
|
||||
copied := output.Copy(2)
|
||||
realConverter := func(ctx context.Context) func(map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
|
|
@ -632,9 +677,7 @@ func (r *nodeRunner[O]) onEndStream(ctx context.Context, output *schema.StreamRe
|
|||
|
||||
func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any, bool) {
|
||||
if r.interrupted {
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
|
|
@ -653,14 +696,13 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
|||
msg := sErr.Msg()
|
||||
|
||||
switch r.errProcessType {
|
||||
case vo.ErrorProcessTypeDefault:
|
||||
case vo.ErrorProcessTypeReturnDefaultData:
|
||||
d := r.dataOnErr(ctx)
|
||||
d["errorBody"] = map[string]any{
|
||||
"errorMessage": msg,
|
||||
"errorCode": code,
|
||||
}
|
||||
d["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: d,
|
||||
|
|
@ -668,7 +710,6 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
|||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
}
|
||||
return d, true
|
||||
case vo.ErrorProcessTypeExceptionBranch:
|
||||
s := make(map[string]any)
|
||||
|
|
@ -677,7 +718,6 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
|||
"errorCode": code,
|
||||
}
|
||||
s["isSuccess"] = false
|
||||
if r.callbackEnabled {
|
||||
sErr = sErr.ChangeErrLevel(vo.LevelWarn)
|
||||
sOutput := &nodes.StructuredCallbackOutput{
|
||||
Output: s,
|
||||
|
|
@ -685,12 +725,9 @@ func (r *nodeRunner[O]) onError(ctx context.Context, err error) (map[string]any,
|
|||
Error: sErr,
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, sOutput)
|
||||
}
|
||||
return s, true
|
||||
default:
|
||||
if r.callbackEnabled {
|
||||
_ = callbacks.OnError(ctx, sErr)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,580 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
)
|
||||
|
||||
type NodeSchema struct {
|
||||
Key vo.NodeKey `json:"key"`
|
||||
Name string `json:"name"`
|
||||
|
||||
Type entity.NodeType `json:"type"`
|
||||
|
||||
// Configs are node specific configurations with pre-defined config key and config value.
|
||||
// Will not participate in request-time field mapping, nor as node's static values.
|
||||
// In a word, these Configs are INTERNAL to node's implementation, the workflow layer is not aware of them.
|
||||
Configs any `json:"configs,omitempty"`
|
||||
|
||||
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
|
||||
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
|
||||
|
||||
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
|
||||
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"` // only applicable to composite nodes such as Batch or Loop
|
||||
|
||||
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"` // generic configurations applicable to most nodes
|
||||
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
|
||||
|
||||
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
|
||||
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
|
||||
|
||||
Lambda *compose.Lambda // not serializable, used for internal test.
|
||||
}
|
||||
|
||||
type ExceptionConfig struct {
|
||||
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
|
||||
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
|
||||
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
|
||||
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
|
||||
}
|
||||
|
||||
type StreamConfig struct {
|
||||
// whether this node has the ability to produce genuine streaming output.
|
||||
// not include nodes that only passes stream down as they receives them
|
||||
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
|
||||
// whether this node prioritize streaming input over none-streaming input.
|
||||
// not include nodes that can accept both and does not have preference.
|
||||
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
func (s *NodeSchema) New(ctx context.Context, inner compose.Runnable[map[string]any, map[string]any],
|
||||
sc *WorkflowSchema, deps *dependencyInfo) (_ *Node, err error) {
|
||||
defer func() {
|
||||
if panicErr := recover(); panicErr != nil {
|
||||
err = safego.NewPanicErr(panicErr, debug.Stack())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrCreateNodeFail, err, errorx.KV("node_name", s.Name), errorx.KV("cause", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
if m := entity.NodeMetaByNodeType(s.Type); m != nil && m.InputSourceAware {
|
||||
if err = s.SetFullSources(sc.GetAllNodes(), deps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
case entity.NodeTypeLambda:
|
||||
if s.Lambda == nil {
|
||||
return nil, fmt.Errorf("lambda is not defined for NodeTypeLambda")
|
||||
}
|
||||
|
||||
return &Node{Lambda: s.Lambda}, nil
|
||||
case entity.NodeTypeLLM:
|
||||
conf, err := s.ToLLMConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l, err := llm.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableStreamableNodeWO(s, l.Chat, l.ChatStream, withCallbackOutputConverter(l.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeSelector:
|
||||
conf := s.ToSelectorConfig()
|
||||
|
||||
sl, err := selector.NewSelector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, sl.Select, withCallbackInputConverter(s.toSelectorCallbackInput(sc)), withCallbackOutputConverter(sl.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeBatch:
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("inner workflow must not be nil when creating batch node")
|
||||
}
|
||||
|
||||
conf, err := s.ToBatchConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := batch.NewBatch(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNodeWO(s, b.Execute, withCallbackInputConverter(b.ToCallbackInput)), nil
|
||||
case entity.NodeTypeVariableAggregator:
|
||||
conf, err := s.ToVariableAggregatorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableaggregator.NewVariableAggregator(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, va.Invoke, va.Transform,
|
||||
withCallbackInputConverter(va.ToCallbackInput),
|
||||
withCallbackOutputConverter(va.ToCallbackOutput),
|
||||
withInit(va.Init)), nil
|
||||
case entity.NodeTypeTextProcessor:
|
||||
conf, err := s.ToTextProcessorConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, err := textprocessor.NewTextProcessor(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, tp.Invoke), nil
|
||||
case entity.NodeTypeHTTPRequester:
|
||||
conf, err := s.ToHTTPRequesterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hr, err := httprequester.NewHTTPRequester(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, hr.Invoke, withCallbackInputConverter(hr.ToCallbackInput), withCallbackOutputConverter(hr.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeContinue:
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
case entity.NodeTypeBreak:
|
||||
b, err := loop.NewBreak(ctx, &nodes.ParentIntermediateStore{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, b.DoBreak), nil
|
||||
case entity.NodeTypeVariableAssigner:
|
||||
handler := variable.GetVariableHandler()
|
||||
|
||||
conf, err := s.ToVariableAssignerConfig(handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
va, err := variableassigner.NewVariableAssigner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeVariableAssignerWithinLoop:
|
||||
conf, err := s.ToVariableAssignerInLoopConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
va, err := variableassigner.NewVariableAssignerInLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, va.Assign), nil
|
||||
case entity.NodeTypeLoop:
|
||||
conf, err := s.ToLoopConfig(inner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := loop.NewLoop(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNodeWO(s, l.Execute, withCallbackInputConverter(l.ToCallbackInput)), nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
conf, err := s.ToQAConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qA, err := qa.NewQuestionAnswer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, qA.Execute, withCallbackOutputConverter(qA.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeInputReceiver:
|
||||
conf, err := s.ToInputReceiverConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputR, err := receiver.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, inputR.Invoke, withCallbackOutputConverter(inputR.ToCallbackOutput)), nil
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeEntry:
|
||||
conf, err := s.ToEntryConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e, err := entry.NewEntry(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, e.Invoke), nil
|
||||
case entity.NodeTypeExit:
|
||||
terminalPlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
if terminalPlan == vo.ReturnVariables {
|
||||
i := func(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
if in == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
return invokableNode(s, i), nil
|
||||
}
|
||||
|
||||
conf, err := s.ToOutputEmitterConfig(sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := emitter.New(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableTransformableNode(s, e.Emit, e.EmitStream), nil
|
||||
case entity.NodeTypeDatabaseCustomSQL:
|
||||
conf, err := s.ToDatabaseCustomSQLConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlER, err := database.NewCustomSQL(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, sqlER.Execute), nil
|
||||
case entity.NodeTypeDatabaseQuery:
|
||||
conf, err := s.ToDatabaseQueryConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query, err := database.NewQuery(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, query.Query, withCallbackInputConverter(query.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseInsert:
|
||||
conf, err := s.ToDatabaseInsertConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insert, err := database.NewInsert(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, insert.Insert, withCallbackInputConverter(insert.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseUpdate:
|
||||
conf, err := s.ToDatabaseUpdateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
update, err := database.NewUpdate(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, update.Update, withCallbackInputConverter(update.ToCallbackInput)), nil
|
||||
case entity.NodeTypeDatabaseDelete:
|
||||
conf, err := s.ToDatabaseDeleteConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
del, err := database.NewDelete(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, del.Delete, withCallbackInputConverter(del.ToCallbackInput)), nil
|
||||
case entity.NodeTypeKnowledgeIndexer:
|
||||
conf, err := s.ToKnowledgeIndexerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, err := knowledge.NewKnowledgeIndexer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, w.Store), nil
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
conf, err := s.ToKnowledgeRetrieveConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeRetrieve(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Retrieve), nil
|
||||
case entity.NodeTypeKnowledgeDeleter:
|
||||
conf, err := s.ToKnowledgeDeleterConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := knowledge.NewKnowledgeDeleter(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Delete), nil
|
||||
case entity.NodeTypeCodeRunner:
|
||||
conf, err := s.ToCodeRunnerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := code.NewCodeRunner(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, r.RunCode, withCallbackOutputConverter(r.ToCallbackOutput), withInit(initFn)), nil
|
||||
case entity.NodeTypePlugin:
|
||||
conf, err := s.ToPluginConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := plugin.NewPlugin(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeCreateConversation:
|
||||
conf, err := s.ToCreateConversationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewCreateConversation(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Create), nil
|
||||
case entity.NodeTypeMessageList:
|
||||
conf, err := s.ToMessageListConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewMessageList(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.List), nil
|
||||
case entity.NodeTypeClearMessage:
|
||||
conf, err := s.ToClearMessageConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := conversation.NewClearMessage(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, r.Clear), nil
|
||||
case entity.NodeTypeIntentDetector:
|
||||
conf, err := s.ToIntentDetectorConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := intentdetector.NewIntentDetector(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return invokableNode(s, r.Invoke), nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
conf, err := s.ToSubWorkflowConfig(ctx, sc.requireCheckPoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := subworkflow.NewSubWorkflow(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableStreamableNodeWO(s, r.Invoke, r.Stream), nil
|
||||
case entity.NodeTypeJsonSerialization:
|
||||
conf, err := s.ToJsonSerializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := json.NewJsonSerializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invokableNode(s, js.Invoke), nil
|
||||
case entity.NodeTypeJsonDeserialization:
|
||||
conf, err := s.ToJsonDeserializationConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jd, err := json.NewJsonDeserializer(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
initFn := func(ctx context.Context) (context.Context, error) {
|
||||
return ctxcache.Init(ctx), nil
|
||||
}
|
||||
return invokableNode(s, jd.Invoke, withCallbackOutputConverter(jd.ToCallbackOutput), withInit(initFn)), nil
|
||||
default:
|
||||
panic("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableUserQuery() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.Type != entity.NodeTypeEntry {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.OutputSources) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, source := range s.OutputSources {
|
||||
fieldPath := source.Path
|
||||
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsEnableChatHistory() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
|
||||
case entity.NodeTypeLLM:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
case entity.NodeTypeIntentDetector:
|
||||
llmParam := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
return llmParam.EnableChatHistory
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsRefGlobalVariable() bool {
|
||||
for _, source := range s.InputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, source := range s.OutputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *NodeSchema) requireCheckpoint() bool {
|
||||
if s.Type == entity.NodeTypeQuestionAnswer || s.Type == entity.NodeTypeInputReceiver {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeLLM {
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil && fcParams.WorkflowFCParam != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type == entity.NodeTypeSubWorkflow {
|
||||
s.SubWorkflowSchema.Init()
|
||||
if s.SubWorkflowSchema.requireCheckPoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -32,8 +32,10 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
|
|
@ -47,7 +49,7 @@ type State struct {
|
|||
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
|
||||
|
||||
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
|
||||
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
|
||||
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
|
||||
|
||||
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
|
||||
|
|
@ -71,8 +73,8 @@ func init() {
|
|||
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
|
||||
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
|
||||
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
|
||||
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
|
||||
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
|
||||
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
|
||||
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
|
||||
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
|
||||
|
|
@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
|
|||
return s.GroupChoices[nodeKey]
|
||||
}
|
||||
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
|
||||
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
choice, ok := choices[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if choice == -1 { // this group picks none of the elements
|
||||
return nodes.FieldNotStream, nil
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
sInfos, ok := s.SourceInfos[nodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
|
||||
}
|
||||
|
||||
groupInfo, ok := sInfos[group]
|
||||
if !ok {
|
||||
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
|
||||
}
|
||||
|
||||
if groupInfo.SubSources == nil {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
|
||||
}
|
||||
|
||||
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
|
||||
}
|
||||
|
||||
if subInfo.FieldType != nodes.FieldMaybeStream {
|
||||
if subInfo.FieldType != schema2.FieldMaybeStream {
|
||||
return subInfo.FieldType, nil
|
||||
}
|
||||
|
||||
|
|
@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
|
|||
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
|
||||
}
|
||||
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
|
||||
result := make(map[string]nodes.FieldStreamType)
|
||||
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
|
||||
result := make(map[string]schema2.FieldStreamType)
|
||||
choices, ok := s.GroupChoices[nodeKey]
|
||||
if !ok {
|
||||
return result, nil
|
||||
|
|
@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
|
|||
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
|
||||
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
|
||||
ExecutedNodes: make(map[vo.NodeKey]bool),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
|
||||
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
|
||||
GroupChoices: make(map[vo.NodeKey]map[string]int),
|
||||
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
|
||||
LLMToResumeData: make(map[vo.NodeKey]string),
|
||||
|
|
@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePreHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
|
||||
|
|
@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
}
|
||||
return in, nil
|
||||
})
|
||||
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
|
||||
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
|
||||
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
|
||||
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
|
||||
state.Inputs[s.Key] = in
|
||||
|
|
@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
}
|
||||
|
||||
if len(handlers) > 0 || !stream {
|
||||
handlerForVars := s.statePreHandlerForVars()
|
||||
handlerForVars := statePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
handlers = append(handlers, handlerForVars)
|
||||
}
|
||||
|
|
@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
|
||||
if s.Type == entity.NodeTypeVariableAggregator {
|
||||
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
state.SourceInfos[s.Key] = s.FullSources
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
|
||||
handlerForVars := s.streamStatePreHandlerForVars()
|
||||
handlerForVars := streamStatePreHandlerForVars(s)
|
||||
if handlerForVars != nil {
|
||||
streamHandlers = append(streamHandlers, handlerForVars)
|
||||
}
|
||||
|
|
@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
|
||||
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
|
|
@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
|
||||
var vars []*vo.FieldInfo
|
||||
for _, input := range s.InputSources {
|
||||
|
|
@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
|
||||
// if it does not have source info, do not add this pre handler
|
||||
if s.Configs == nil {
|
||||
return nil
|
||||
|
|
@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
|
||||
return nil
|
||||
case entity.NodeTypeExit:
|
||||
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
|
||||
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
|
||||
if terminatePlan != vo.ReturnVariables {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
|
||||
}
|
||||
|
||||
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
|
||||
sourceInfo := s.FullSources
|
||||
if len(sourceInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
|
||||
var (
|
||||
anyStream bool
|
||||
checker func(source *nodes.SourceInfo) bool
|
||||
checker func(source *schema2.SourceInfo) bool
|
||||
)
|
||||
checker = func(source *nodes.SourceInfo) bool {
|
||||
if source.FieldType != nodes.FieldNotStream {
|
||||
checker = func(source *schema2.SourceInfo) bool {
|
||||
if source.FieldType != schema2.FieldNotStream {
|
||||
return true
|
||||
}
|
||||
for _, subSource := range source.SubSources {
|
||||
|
|
@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
|
||||
resolved := map[string]resolvedStreamSource{}
|
||||
|
||||
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
|
||||
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
|
||||
if source.IsIntermediate {
|
||||
result = &resolvedStreamSource{
|
||||
intermediate: true,
|
||||
|
|
@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
}
|
||||
|
||||
streamType := source.FieldType
|
||||
if streamType == nodes.FieldMaybeStream {
|
||||
if streamType == schema2.FieldMaybeStream {
|
||||
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if streamType == nodes.FieldNotStream {
|
||||
if streamType == schema2.FieldNotStream {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
||||
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
|
||||
var (
|
||||
handlers []compose.StatePostHandler[map[string]any, *State]
|
||||
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
|
||||
|
|
@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.streamStatePostHandlerForVars()
|
||||
forVars := streamStatePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
streamHandlers = append(streamHandlers, forVars)
|
||||
}
|
||||
|
|
@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
return out, nil
|
||||
})
|
||||
|
||||
forVars := s.statePostHandlerForVars()
|
||||
forVars := statePostHandlerForVars(s)
|
||||
if forVars != nil {
|
||||
handlers = append(handlers, forVars)
|
||||
}
|
||||
|
|
@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
|
|||
return compose.WithStatePostHandler(handler)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
|
||||
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variable,
|
||||
// use the state's variableHandler to get the variables and set them to the output
|
||||
var vars []*vo.FieldInfo
|
||||
|
|
@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
|
|||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
|
||||
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
|
||||
var vars []*vo.FieldInfo
|
||||
for _, output := range s.OutputSources {
|
||||
|
|
|
|||
|
|
@ -21,19 +21,20 @@ import (
|
|||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
// SetFullSources calculates REAL input sources for a node.
|
||||
// GetFullSources calculates REAL input sources for a node.
|
||||
// It may be different from a NodeSchema's InputSources because of the following reasons:
|
||||
// 1. a inner node under composite node may refer to a field from a node in its parent workflow,
|
||||
// this is instead routed to and sourced from the inner workflow's start node.
|
||||
// 2. at the same time, the composite node needs to delegate the input source to the inner workflow.
|
||||
// 3. also, some node may have implicit input sources not defined in its NodeSchema's InputSources.
|
||||
func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *dependencyInfo) error {
|
||||
fullSource := make(map[string]*nodes.SourceInfo)
|
||||
func GetFullSources(s *schema.NodeSchema, sc *schema.WorkflowSchema, dep *dependencyInfo) (
|
||||
map[string]*schema.SourceInfo, error) {
|
||||
fullSource := make(map[string]*schema.SourceInfo)
|
||||
var fieldInfos []vo.FieldInfo
|
||||
for _, s := range dep.staticValues {
|
||||
fieldInfos = append(fieldInfos, vo.FieldInfo{
|
||||
|
|
@ -113,14 +114,14 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
|||
tInfo = tInfo.Properties[path[j]]
|
||||
}
|
||||
if current, ok := currentSource[path[j]]; !ok {
|
||||
currentSource[path[j]] = &nodes.SourceInfo{
|
||||
currentSource[path[j]] = &schema.SourceInfo{
|
||||
IsIntermediate: true,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
FieldType: schema.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
SubSources: make(map[string]*nodes.SourceInfo),
|
||||
SubSources: make(map[string]*schema.SourceInfo),
|
||||
}
|
||||
} else if !current.IsIntermediate {
|
||||
return fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
|
||||
return nil, fmt.Errorf("existing sourceInfo for path %s is not intermediate, conflict", path[:j+1])
|
||||
}
|
||||
|
||||
currentSource = currentSource[path[j]].SubSources
|
||||
|
|
@ -135,9 +136,9 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
|||
|
||||
// static values or variables
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
currentSource[lastPath] = &schema.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: nodes.FieldNotStream,
|
||||
FieldType: schema.FieldNotStream,
|
||||
TypeInfo: tInfo,
|
||||
}
|
||||
continue
|
||||
|
|
@ -145,25 +146,25 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
|||
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
var (
|
||||
streamType nodes.FieldStreamType
|
||||
streamType schema.FieldStreamType
|
||||
err error
|
||||
)
|
||||
if len(fromNodeKey) > 0 {
|
||||
if fromNodeKey == compose.START {
|
||||
streamType = nodes.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
|
||||
streamType = schema.FieldNotStream // TODO: set start node to not stream for now until composite node supports transform
|
||||
} else {
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("node %s not found", fromNodeKey)
|
||||
fromNode := sc.GetNode(fromNodeKey)
|
||||
if fromNode == nil {
|
||||
return nil, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
streamType, err = fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
streamType, err = nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentSource[lastPath] = &nodes.SourceInfo{
|
||||
currentSource[lastPath] = &schema.SourceInfo{
|
||||
IsIntermediate: false,
|
||||
FieldType: streamType,
|
||||
FromNodeKey: fromNodeKey,
|
||||
|
|
@ -172,121 +173,5 @@ func (s *NodeSchema) SetFullSources(allNS map[vo.NodeKey]*NodeSchema, dep *depen
|
|||
}
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)["FullSources"] = fullSource
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) IsStreamingField(path compose.FieldPath, allNS map[vo.NodeKey]*NodeSchema) (nodes.FieldStreamType, error) {
|
||||
if s.Type == entity.NodeTypeExit {
|
||||
if mustGetKey[nodes.Mode]("Mode", s.Configs) == nodes.Streaming {
|
||||
if len(path) == 1 && path[0] == "output" {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
} else if s.Type == entity.NodeTypeSubWorkflow { // TODO: why not use sub workflow's Mode configuration directly?
|
||||
subSC := s.SubWorkflowSchema
|
||||
subExit := subSC.GetNode(entity.ExitNodeKey)
|
||||
subStreamType, err := subExit.IsStreamingField(path, nil)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
return subStreamType, nil
|
||||
} else if s.Type == entity.NodeTypeVariableAggregator {
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range s.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode, ok := allNS[fromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range s.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode, ok := allNS[fInfo.Source.Ref.FromNodeKey]
|
||||
if !ok {
|
||||
return nodes.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := fromNode.IsStreamingField(fInfo.Source.Ref.FromPath, allNS)
|
||||
if err != nil {
|
||||
return nodes.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == nodes.FieldMaybeStream {
|
||||
return nodes.FieldMaybeStream, nil
|
||||
} else if subStreamType == nodes.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldMaybeStream, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.Type != entity.NodeTypeLLM {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) != 1 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
outputs := s.OutputTypes
|
||||
if len(outputs) != 1 && len(outputs) != 2 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
var outputKey string
|
||||
for key, output := range outputs {
|
||||
if output.Type != vo.DataTypeString {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if key != "reasoning_content" {
|
||||
if len(outputKey) > 0 {
|
||||
return nodes.FieldNotStream, nil
|
||||
}
|
||||
outputKey = key
|
||||
}
|
||||
}
|
||||
|
||||
field := path[0]
|
||||
if field == "reasoning_content" || field == outputKey {
|
||||
return nodes.FieldIsStream, nil
|
||||
}
|
||||
|
||||
return nodes.FieldNotStream, nil
|
||||
return fullSource, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,9 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestBatch(t *testing.T) {
|
||||
|
|
@ -52,7 +55,7 @@ func TestBatch(t *testing.T) {
|
|||
return in, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
lambdaNode1 := &schema.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
|
|
@ -86,7 +89,7 @@ func TestBatch(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
lambdaNode2 := &compose2.NodeSchema{
|
||||
lambdaNode2 := &schema.NodeSchema{
|
||||
Key: "index",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
|
|
@ -103,7 +106,7 @@ func TestBatch(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
lambdaNode3 := &schema.NodeSchema{
|
||||
Key: "consumer",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
|
|
@ -135,23 +138,22 @@ func TestBatch(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "batch_node_key",
|
||||
Type: entity.NodeTypeBatch,
|
||||
Configs: &batch.Config{},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"array_1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"array_1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -160,7 +162,7 @@ func TestBatch(t *testing.T) {
|
|||
Path: compose.FieldPath{"array_2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"array_2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -214,11 +216,11 @@ func TestBatch(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -246,18 +248,18 @@ func TestBatch(t *testing.T) {
|
|||
return map[string]any{"success": true}, nil
|
||||
}
|
||||
|
||||
parentLambdaNode := &compose2.NodeSchema{
|
||||
parentLambdaNode := &schema.NodeSchema{
|
||||
Key: "parent_predecessor_1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(parentLambda),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
parentLambdaNode,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambdaNode1,
|
||||
lambdaNode2,
|
||||
lambdaNode3,
|
||||
|
|
@ -267,7 +269,7 @@ func TestBatch(t *testing.T) {
|
|||
"index": "batch_node_key",
|
||||
"consumer": "batch_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entity.EntryNodeKey,
|
||||
ToNode: "parent_predecessor_1",
|
||||
|
|
|
|||
|
|
@ -40,7 +40,11 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/testutil"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
|
|
@ -108,22 +112,20 @@ func TestLLM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "{{sys_prompt}}",
|
||||
"UserPrompt": "{{query}}",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "{{sys_prompt}}",
|
||||
UserPrompt: "{{query}}",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
|
|
@ -132,7 +134,7 @@ func TestLLM(t *testing.T) {
|
|||
Path: compose.FieldPath{"sys_prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"sys_prompt"},
|
||||
},
|
||||
},
|
||||
|
|
@ -141,7 +143,7 @@ func TestLLM(t *testing.T) {
|
|||
Path: compose.FieldPath{"query"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
|
|
@ -162,11 +164,11 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -181,20 +183,20 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -228,27 +230,20 @@ func TestLLM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "what's the largest country in the world and it's area size in square kilometers?",
|
||||
"OutputFormat": llm.FormatJSON,
|
||||
"IgnoreException": true,
|
||||
"DefaultOutput": map[string]any{
|
||||
"country_name": "unknown",
|
||||
"area_size": int64(0),
|
||||
},
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "what's the largest country in the world and it's area size in square kilometers?",
|
||||
OutputFormat: llm.FormatJSON,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
|
|
@ -264,11 +259,11 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -292,20 +287,20 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -337,22 +332,20 @@ func TestLLM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
llmNode := &compose2.NodeSchema{
|
||||
llmNode := &schema2.NodeSchema{
|
||||
Key: "llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "list the top 5 largest countries in the world",
|
||||
"OutputFormat": llm.FormatMarkdown,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "list the top 5 largest countries in the world",
|
||||
OutputFormat: llm.FormatMarkdown,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
|
|
@ -363,11 +356,11 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -382,20 +375,20 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
llmNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: llmNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: llmNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -456,22 +449,20 @@ func TestLLM(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
openaiNode := &compose2.NodeSchema{
|
||||
openaiNode := &schema2.NodeSchema{
|
||||
Key: "openai_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "plan a 10 day family visit to China.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "plan a 10 day family visit to China.",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
|
|
@ -482,14 +473,14 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
deepseekNode := &compose2.NodeSchema{
|
||||
deepseekNode := &schema2.NodeSchema{
|
||||
Key: "deepseek_llm_node_key",
|
||||
Type: entity.NodeTypeLLM,
|
||||
Configs: map[string]any{
|
||||
"SystemPrompt": "you are a helpful assistant",
|
||||
"UserPrompt": "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
|
||||
"OutputFormat": llm.FormatText,
|
||||
"LLMParams": &model.LLMParams{
|
||||
Configs: &llm.Config{
|
||||
SystemPrompt: "you are a helpful assistant",
|
||||
UserPrompt: "thoroughly plan a 10 day family visit to China. Use your reasoning ability.",
|
||||
OutputFormat: llm.FormatText,
|
||||
LLMParams: &model.LLMParams{
|
||||
ModelName: modelName,
|
||||
},
|
||||
},
|
||||
|
|
@ -503,12 +494,11 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
emitterNode := &compose2.NodeSchema{
|
||||
emitterNode := &schema2.NodeSchema{
|
||||
Key: "emitter_node_key",
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Configs: map[string]any{
|
||||
"Template": "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
|
||||
"Mode": nodes.Streaming,
|
||||
Configs: &emitter.Config{
|
||||
Template: "prefix {{inputObj.field1}} {{input2}} {{deepseek_reasoning}} \n\n###\n\n {{openai_output}} \n\n###\n\n {{deepseek_output}} {{inputObj.field2}} suffix",
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -542,7 +532,7 @@ func TestLLM(t *testing.T) {
|
|||
Path: compose.FieldPath{"inputObj"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"inputObj"},
|
||||
},
|
||||
},
|
||||
|
|
@ -551,7 +541,7 @@ func TestLLM(t *testing.T) {
|
|||
Path: compose.FieldPath{"input2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"input2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -559,11 +549,11 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.UseAnswerContent,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.UseAnswerContent,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -596,17 +586,17 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
openaiNode,
|
||||
deepseekNode,
|
||||
emitterNode,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: openaiNode.Key,
|
||||
},
|
||||
{
|
||||
|
|
@ -614,7 +604,7 @@ func TestLLM(t *testing.T) {
|
|||
ToNode: emitterNode.Key,
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: deepseekNode.Key,
|
||||
},
|
||||
{
|
||||
|
|
@ -623,7 +613,7 @@ func TestLLM(t *testing.T) {
|
|||
},
|
||||
{
|
||||
FromNode: emitterNode.Key,
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,15 +26,20 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
|
||||
_continue "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/continue"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
t.Run("by iteration", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->continue] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
|
@ -54,31 +59,30 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
continueNode := &compose2.NodeSchema{
|
||||
continueNode := &schema.NodeSchema{
|
||||
Key: "continueNode",
|
||||
Type: entity.NodeTypeContinue,
|
||||
Configs: &_continue.Config{},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByIteration,
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.ByIteration,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{loop.Count},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
},
|
||||
},
|
||||
|
|
@ -97,11 +101,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -116,11 +120,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
continueNode,
|
||||
},
|
||||
|
|
@ -128,7 +132,7 @@ func TestLoop(t *testing.T) {
|
|||
"innerNode": "loop_node_key",
|
||||
"continueNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
|
|
@ -142,12 +146,12 @@ func TestLoop(t *testing.T) {
|
|||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -168,7 +172,7 @@ func TestLoop(t *testing.T) {
|
|||
|
||||
t.Run("infinite", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->break] -> end
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
|
@ -188,24 +192,23 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
breakNode := &compose2.NodeSchema{
|
||||
breakNode := &schema.NodeSchema{
|
||||
Key: "breakNode",
|
||||
Type: entity.NodeTypeBreak,
|
||||
Configs: &_break.Config{},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.Infinite,
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.Infinite,
|
||||
},
|
||||
OutputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -220,11 +223,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -239,11 +242,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
breakNode,
|
||||
},
|
||||
|
|
@ -251,7 +254,7 @@ func TestLoop(t *testing.T) {
|
|||
"innerNode": "loop_node_key",
|
||||
"breakNode": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
|
|
@ -265,12 +268,12 @@ func TestLoop(t *testing.T) {
|
|||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -290,14 +293,14 @@ func TestLoop(t *testing.T) {
|
|||
t.Run("by array", func(t *testing.T) {
|
||||
// start-> loop_node_key[innerNode->variable_assign] -> end
|
||||
|
||||
innerNode := &compose2.NodeSchema{
|
||||
innerNode := &schema.NodeSchema{
|
||||
Key: "innerNode",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
item1 := in["item1"].(string)
|
||||
item2 := in["item2"].(string)
|
||||
count := in["count"].(int)
|
||||
return map[string]any{"total": int(count) + len(item1) + len(item2)}, nil
|
||||
return map[string]any{"total": count + len(item1) + len(item2)}, nil
|
||||
}),
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -330,10 +333,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
assigner := &compose2.NodeSchema{
|
||||
assigner := &schema.NodeSchema{
|
||||
Key: "assigner",
|
||||
Type: entity.NodeTypeVariableAssignerWithinLoop,
|
||||
Configs: []*variableassigner.Pair{
|
||||
Configs: &variableassigner.InLoopConfig{
|
||||
Pairs: []*variableassigner.Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"count"},
|
||||
|
|
@ -342,6 +346,7 @@ func TestLoop(t *testing.T) {
|
|||
Right: compose.FieldPath{"total"},
|
||||
},
|
||||
},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"total"},
|
||||
|
|
@ -355,19 +360,17 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -382,12 +385,13 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
loopNode := &compose2.NodeSchema{
|
||||
loopNode := &schema.NodeSchema{
|
||||
Key: "loop_node_key",
|
||||
Type: entity.NodeTypeLoop,
|
||||
Configs: map[string]any{
|
||||
"LoopType": loop.ByArray,
|
||||
"IntermediateVars": map[string]*vo.TypeInfo{
|
||||
Configs: &loop.Config{
|
||||
LoopType: loop.ByArray,
|
||||
InputArrays: []string{"items1", "items2"},
|
||||
IntermediateVars: map[string]*vo.TypeInfo{
|
||||
"count": {
|
||||
Type: vo.DataTypeInteger,
|
||||
},
|
||||
|
|
@ -408,7 +412,7 @@ func TestLoop(t *testing.T) {
|
|||
Path: compose.FieldPath{"items1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"items1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -417,7 +421,7 @@ func TestLoop(t *testing.T) {
|
|||
Path: compose.FieldPath{"items2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"items2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -442,11 +446,11 @@ func TestLoop(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
loopNode,
|
||||
exit,
|
||||
exitN,
|
||||
innerNode,
|
||||
assigner,
|
||||
},
|
||||
|
|
@ -454,7 +458,7 @@ func TestLoop(t *testing.T) {
|
|||
"innerNode": "loop_node_key",
|
||||
"assigner": "loop_node_key",
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: "innerNode",
|
||||
|
|
@ -468,12 +472,12 @@ func TestLoop(t *testing.T) {
|
|||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "loop_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "loop_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,8 +43,11 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
storageMock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/storage"
|
||||
|
|
@ -106,26 +109,25 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
mockey.Mock(workflow.GetRepository).Return(repo).Build()
|
||||
|
||||
t.Run("answer directly, no structured output", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerDirectly,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
|
|
@ -133,11 +135,11 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -152,20 +154,20 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -210,30 +212,28 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(oneChatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.FixedChoices,
|
||||
"FixedChoices": []string{"{{choice1}}", "{{choice2}}"},
|
||||
"LLMParams": &model.LLMParams{},
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerByChoices,
|
||||
ChoiceType: qa.FixedChoices,
|
||||
FixedChoices: []string{"{{choice1}}", "{{choice2}}"},
|
||||
LLMParams: &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
|
|
@ -242,7 +242,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
Path: compose.FieldPath{"choice1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choice1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -251,7 +251,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
Path: compose.FieldPath{"choice2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choice2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -259,11 +259,11 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -287,7 +287,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
lambda := &schema2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
|
@ -295,26 +295,26 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_1"),
|
||||
},
|
||||
{
|
||||
|
|
@ -324,11 +324,15 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
branches, err := schema2.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
|
|
@ -362,28 +366,26 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("answer with dynamic choices", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerByChoices,
|
||||
"ChoiceType": qa.DynamicChoices,
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerByChoices,
|
||||
ChoiceType: qa.DynamicChoices,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
|
|
@ -392,7 +394,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
Path: compose.FieldPath{qa.DynamicChoicesKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"choices"},
|
||||
},
|
||||
},
|
||||
|
|
@ -400,11 +402,11 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -428,7 +430,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
lambda := &compose2.NodeSchema{
|
||||
lambda := &schema2.NodeSchema{
|
||||
Key: "lambda",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(func(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
|
@ -436,26 +438,26 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
}),
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
lambda,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("branch_0"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
|
|
@ -465,6 +467,10 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
branches, err := schema2.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
wf, err := compose2.NewWorkflow(context.Background(), ws)
|
||||
|
|
@ -522,31 +528,29 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
mockModelManager.EXPECT().GetModel(gomock.Any(), gomock.Any()).Return(chatModel, nil, nil).Times(1)
|
||||
}
|
||||
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema2.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: "qa_node_key",
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Configs: map[string]any{
|
||||
"QuestionTpl": "{{input}}",
|
||||
"AnswerType": qa.AnswerDirectly,
|
||||
"ExtractFromAnswer": true,
|
||||
"AdditionalSystemPromptTpl": "{{prompt}}",
|
||||
"MaxAnswerCount": 2,
|
||||
"LLMParams": &model.LLMParams{},
|
||||
Configs: &qa.Config{
|
||||
QuestionTpl: "{{input}}",
|
||||
AnswerType: qa.AnswerDirectly,
|
||||
ExtractFromAnswer: true,
|
||||
AdditionalSystemPromptTpl: "{{prompt}}",
|
||||
MaxAnswerCount: 2,
|
||||
LLMParams: &model.LLMParams{},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"input"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"query"},
|
||||
},
|
||||
},
|
||||
|
|
@ -555,7 +559,7 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
Path: compose.FieldPath{"prompt"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"prompt"},
|
||||
},
|
||||
},
|
||||
|
|
@ -573,11 +577,11 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema2.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -610,20 +614,20 @@ func TestQuestionAnswer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema2.WorkflowSchema{
|
||||
Nodes: []*schema2.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema2.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "qa_node_key",
|
||||
},
|
||||
{
|
||||
FromNode: "qa_node_key",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,26 +26,28 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
compose2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func TestAddSelector(t *testing.T) {
|
||||
// start -> selector, selector.condition1 -> lambda1 -> end, selector.condition2 -> [lambda2, lambda3] -> end, selector default -> end
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
}}
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -84,7 +86,7 @@ func TestAddSelector(t *testing.T) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode1 := &compose2.NodeSchema{
|
||||
lambdaNode1 := &schema.NodeSchema{
|
||||
Key: "lambda1",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda1),
|
||||
|
|
@ -96,7 +98,7 @@ func TestAddSelector(t *testing.T) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
LambdaNode2 := &compose2.NodeSchema{
|
||||
LambdaNode2 := &schema.NodeSchema{
|
||||
Key: "lambda2",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda2),
|
||||
|
|
@ -108,16 +110,16 @@ func TestAddSelector(t *testing.T) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
lambdaNode3 := &compose2.NodeSchema{
|
||||
lambdaNode3 := &schema.NodeSchema{
|
||||
Key: "lambda3",
|
||||
Type: entity.NodeTypeLambda,
|
||||
Lambda: compose.InvokableLambda(lambda3),
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "selector",
|
||||
Type: entity.NodeTypeSelector,
|
||||
Configs: map[string]any{"Clauses": []*selector.OneClauseSchema{
|
||||
Configs: &selector.Config{Clauses: []*selector.OneClauseSchema{
|
||||
{
|
||||
Single: ptr.Of(selector.OperatorEqual),
|
||||
},
|
||||
|
|
@ -136,7 +138,7 @@ func TestAddSelector(t *testing.T) {
|
|||
Path: compose.FieldPath{"0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -151,7 +153,7 @@ func TestAddSelector(t *testing.T) {
|
|||
Path: compose.FieldPath{"1", "0", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -160,7 +162,7 @@ func TestAddSelector(t *testing.T) {
|
|||
Path: compose.FieldPath{"1", "0", selector.RightKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key3"},
|
||||
},
|
||||
},
|
||||
|
|
@ -169,7 +171,7 @@ func TestAddSelector(t *testing.T) {
|
|||
Path: compose.FieldPath{"1", "1", selector.LeftKey},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"key4"},
|
||||
},
|
||||
},
|
||||
|
|
@ -214,18 +216,18 @@ func TestAddSelector(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
lambdaNode1,
|
||||
LambdaNode2,
|
||||
lambdaNode3,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "selector",
|
||||
},
|
||||
{
|
||||
|
|
@ -245,24 +247,28 @@ func TestAddSelector(t *testing.T) {
|
|||
},
|
||||
{
|
||||
FromNode: "selector",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
FromPort: ptr.Of("default"),
|
||||
},
|
||||
{
|
||||
FromNode: "lambda1",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda2",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
{
|
||||
FromNode: "lambda3",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
branches, err := schema.BuildBranches(ws.Connections)
|
||||
assert.NoError(t, err)
|
||||
ws.Branches = branches
|
||||
|
||||
ws.Init()
|
||||
|
||||
ctx := context.Background()
|
||||
|
|
@ -303,19 +309,17 @@ func TestAddSelector(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestVariableAggregator(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -339,16 +343,16 @@ func TestVariableAggregator(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "va",
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Configs: map[string]any{
|
||||
"MergeStrategy": variableaggregator.FirstNotNullValue,
|
||||
"GroupToLen": map[string]int{
|
||||
Configs: &variableaggregator.Config{
|
||||
MergeStrategy: variableaggregator.FirstNotNullValue,
|
||||
GroupLen: map[string]int{
|
||||
"Group1": 1,
|
||||
"Group2": 1,
|
||||
},
|
||||
"GroupOrder": []string{
|
||||
GroupOrder: []string{
|
||||
"Group1",
|
||||
"Group2",
|
||||
},
|
||||
|
|
@ -358,7 +362,7 @@ func TestVariableAggregator(t *testing.T) {
|
|||
Path: compose.FieldPath{"Group1", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -367,7 +371,7 @@ func TestVariableAggregator(t *testing.T) {
|
|||
Path: compose.FieldPath{"Group2", "0"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Int1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -401,20 +405,20 @@ func TestVariableAggregator(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
entry,
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
entryN,
|
||||
ns,
|
||||
exit,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "va",
|
||||
},
|
||||
{
|
||||
FromNode: "va",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -448,19 +452,17 @@ func TestVariableAggregator(t *testing.T) {
|
|||
|
||||
func TestTextProcessor(t *testing.T) {
|
||||
t.Run("split", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -475,19 +477,19 @@ func TestTextProcessor(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.SplitText,
|
||||
"Separators": []string{"|"},
|
||||
Configs: &textprocessor.Config{
|
||||
Type: textprocessor.SplitText,
|
||||
Separators: []string{"|"},
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str"},
|
||||
},
|
||||
},
|
||||
|
|
@ -495,20 +497,20 @@ func TestTextProcessor(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
entryN,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -527,19 +529,17 @@ func TestTextProcessor(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("concat", func(t *testing.T) {
|
||||
entry := &compose2.NodeSchema{
|
||||
entryN := &schema.NodeSchema{
|
||||
Key: entity.EntryNodeKey,
|
||||
Type: entity.NodeTypeEntry,
|
||||
Configs: map[string]any{
|
||||
"DefaultValues": map[string]any{},
|
||||
},
|
||||
Configs: &entry.Config{},
|
||||
}
|
||||
|
||||
exit := &compose2.NodeSchema{
|
||||
exitN := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Configs: map[string]any{
|
||||
"TerminalPlan": vo.ReturnVariables,
|
||||
Configs: &exit.Config{
|
||||
TerminatePlan: vo.ReturnVariables,
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
|
|
@ -554,20 +554,20 @@ func TestTextProcessor(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ns := &compose2.NodeSchema{
|
||||
ns := &schema.NodeSchema{
|
||||
Key: "tp",
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Configs: map[string]any{
|
||||
"Type": textprocessor.ConcatText,
|
||||
"Tpl": "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
|
||||
"ConcatChar": "\t",
|
||||
Configs: &textprocessor.Config{
|
||||
Type: textprocessor.ConcatText,
|
||||
Tpl: "{{String1}}_{{String2.f1}}_{{String3.f2[1]}}",
|
||||
ConcatChar: "\t",
|
||||
},
|
||||
InputSources: []*vo.FieldInfo{
|
||||
{
|
||||
Path: compose.FieldPath{"String1"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str1"},
|
||||
},
|
||||
},
|
||||
|
|
@ -576,7 +576,7 @@ func TestTextProcessor(t *testing.T) {
|
|||
Path: compose.FieldPath{"String2"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str2"},
|
||||
},
|
||||
},
|
||||
|
|
@ -585,7 +585,7 @@ func TestTextProcessor(t *testing.T) {
|
|||
Path: compose.FieldPath{"String3"},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: entry.Key,
|
||||
FromNodeKey: entryN.Key,
|
||||
FromPath: compose.FieldPath{"Str3"},
|
||||
},
|
||||
},
|
||||
|
|
@ -593,20 +593,20 @@ func TestTextProcessor(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
ws := &compose2.WorkflowSchema{
|
||||
Nodes: []*compose2.NodeSchema{
|
||||
ws := &schema.WorkflowSchema{
|
||||
Nodes: []*schema.NodeSchema{
|
||||
ns,
|
||||
entry,
|
||||
exit,
|
||||
entryN,
|
||||
exitN,
|
||||
},
|
||||
Connections: []*compose2.Connection{
|
||||
Connections: []*schema.Connection{
|
||||
{
|
||||
FromNode: entry.Key,
|
||||
FromNode: entryN.Key,
|
||||
ToNode: "tp",
|
||||
},
|
||||
{
|
||||
FromNode: "tp",
|
||||
ToNode: exit.Key,
|
||||
ToNode: exitN.Key,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,652 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
einomodel "github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crosscode "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/conversation"
|
||||
crossdatabase "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
crossplugin "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/batch"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/entry"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/httprequester"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/json"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/selector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/subworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/textprocessor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableaggregator"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/variableassigner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
func (s *NodeSchema) ToEntryConfig(_ context.Context) (*entry.Config, error) {
|
||||
return &entry.Config{
|
||||
DefaultValues: getKeyOrZero[map[string]any]("DefaultValues", s.Configs),
|
||||
OutputTypes: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLLMConfig(ctx context.Context) (*llm.Config, error) {
|
||||
llmConf := &llm.Config{
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
UserPrompt: getKeyOrZero[string]("UserPrompt", s.Configs),
|
||||
OutputFormat: mustGetKey[llm.Format]("OutputFormat", s.Configs),
|
||||
InputFields: s.InputTypes,
|
||||
OutputFields: s.OutputTypes,
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
|
||||
if llmParams == nil {
|
||||
return nil, fmt.Errorf("llm node llmParams is required")
|
||||
}
|
||||
var (
|
||||
err error
|
||||
chatModel, fallbackM einomodel.BaseChatModel
|
||||
info, fallbackI *modelmgr.Model
|
||||
modelWithInfo llm.ModelWithInfo
|
||||
)
|
||||
|
||||
chatModel, info, err = model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metaConfigs := s.ExceptionConfigs
|
||||
if metaConfigs != nil && metaConfigs.MaxRetry > 0 {
|
||||
backupModelParams := getKeyOrZero[*model.LLMParams]("BackupLLMParams", s.Configs)
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = model.GetManager().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackM == nil {
|
||||
modelWithInfo = llm.NewModel(chatModel, info)
|
||||
} else {
|
||||
modelWithInfo = llm.NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
|
||||
}
|
||||
llmConf.ChatModel = modelWithInfo
|
||||
|
||||
fcParams := getKeyOrZero[*vo.FCParam]("FCParam", s.Configs)
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
|
||||
workflowToolConfig := vo.WorkflowToolConfig{}
|
||||
if wf.FCSetting != nil {
|
||||
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
|
||||
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow2.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, workflowToolConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.Tools = append(llmConf.Tools, wfTool)
|
||||
if wfTool.TerminatePlan() == vo.UseAnswerContent {
|
||||
if llmConf.ToolsReturnDirectly == nil {
|
||||
llmConf.ToolsReturnDirectly = make(map[string]bool)
|
||||
}
|
||||
toolInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llmConf.ToolsReturnDirectly[toolInfo.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.PluginFCParam != nil {
|
||||
pluginToolsInvokableReq := make(map[int64]*crossplugin.ToolsInvokableRequest)
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
|
||||
var (
|
||||
requestParameters []*workflow3.APIParameter
|
||||
responseParameters []*workflow3.APIParameter
|
||||
)
|
||||
if p.FCSetting != nil {
|
||||
requestParameters = p.FCSetting.RequestParameters
|
||||
responseParameters = p.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
if req, ok := pluginToolsInvokableReq[pid]; ok {
|
||||
req.ToolsInvokableInfo[toolID] = &crossplugin.ToolsInvokableInfo{
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
}
|
||||
} else {
|
||||
pluginToolsInfoRequest := &crossplugin.ToolsInvokableRequest{
|
||||
PluginEntity: crossplugin.Entity{
|
||||
PluginID: pid,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolsInvokableInfo: map[int64]*crossplugin.ToolsInvokableInfo{
|
||||
toolID: {
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
},
|
||||
},
|
||||
IsDraft: p.IsDraft,
|
||||
}
|
||||
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
|
||||
}
|
||||
}
|
||||
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
|
||||
for _, req := range pluginToolsInvokableReq {
|
||||
toolMap, err := crossplugin.GetPluginService().GetPluginInvokableTools(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range toolMap {
|
||||
inInvokableTools = append(inInvokableTools, crossplugin.NewInvokableTool(t))
|
||||
}
|
||||
}
|
||||
if len(inInvokableTools) > 0 {
|
||||
llmConf.Tools = inInvokableTools
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
|
||||
kwChatModel := workflow2.GetRepository().GetKnowledgeRecallChatModel()
|
||||
if kwChatModel == nil {
|
||||
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
knowledgeOperator := crossknowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
cfg := &llm.KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := totRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.RetrievalStrategy = &llm.RetrievalStrategy{
|
||||
RetrievalStrategy: &crossknowledge.RetrievalStrategy{
|
||||
TopK: ptr.Of(setting.TopK),
|
||||
MinScore: ptr.Of(setting.MinScore),
|
||||
SearchType: searchType,
|
||||
EnableNL2SQL: setting.UseNL2SQL,
|
||||
EnableQueryRewrite: setting.UseRewrite,
|
||||
EnableRerank: setting.UseRerank,
|
||||
},
|
||||
NoReCallReplyMode: llm.NoReCallReplyMode(setting.NoRecallReplyMode),
|
||||
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
|
||||
}
|
||||
|
||||
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
|
||||
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
|
||||
kid, err := strconv.ParseInt(kw.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx, &crossknowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
|
||||
llmConf.KnowledgeRecallConfig = cfg
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return llmConf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSelectorConfig() *selector.Config {
|
||||
return &selector.Config{
|
||||
Clauses: mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SelectorInputConverter(in map[string]any) (out []selector.Operants, err error) {
|
||||
conf := mustGetKey[[]*selector.OneClauseSchema]("Clauses", s.Configs)
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d", in, i)
|
||||
}
|
||||
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), selector.RightKey})
|
||||
if ok {
|
||||
out = append(out, selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
out = append(out, selector.Operants{Left: left})
|
||||
}
|
||||
} else if oneConf.Multi != nil {
|
||||
multiClause := make([]*selector.Operants, 0)
|
||||
for j := range oneConf.Multi.Clauses {
|
||||
left, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.LeftKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left operant from input map: %v, clause index= %d, single clause index= %d", in, i, j)
|
||||
}
|
||||
right, ok := nodes.TakeMapValue(in, compose.FieldPath{strconv.Itoa(i), strconv.Itoa(j), selector.RightKey})
|
||||
if ok {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left, Right: right})
|
||||
} else {
|
||||
multiClause = append(multiClause, &selector.Operants{Left: left})
|
||||
}
|
||||
}
|
||||
out = append(out, selector.Operants{Multi: multiClause})
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid clause config, both single and multi are nil: %v", oneConf)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToBatchConfig(inner compose.Runnable[map[string]any, map[string]any]) (*batch.Config, error) {
|
||||
conf := &batch.Config{
|
||||
BatchNodeKey: s.Key,
|
||||
InnerWorkflow: inner,
|
||||
Outputs: s.OutputSources,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAggregatorConfig() (*variableaggregator.Config, error) {
|
||||
return &variableaggregator.Config{
|
||||
MergeStrategy: s.Configs.(map[string]any)["MergeStrategy"].(variableaggregator.MergeStrategy),
|
||||
GroupLen: s.Configs.(map[string]any)["GroupToLen"].(map[string]int),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
NodeKey: s.Key,
|
||||
InputSources: s.InputSources,
|
||||
GroupOrder: mustGetKey[[]string]("GroupOrder", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorInputConverter(in map[string]any) (converted map[string]map[int]any) {
|
||||
converted = make(map[string]map[int]any)
|
||||
|
||||
for k, value := range in {
|
||||
m, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
panic(errors.New("value is not a map[string]any"))
|
||||
}
|
||||
converted[k] = make(map[int]any, len(m))
|
||||
for i, sv := range m {
|
||||
index, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf(" converting %s to int failed, err=%v", i, err))
|
||||
}
|
||||
converted[k][index] = sv
|
||||
}
|
||||
}
|
||||
|
||||
return converted
|
||||
}
|
||||
|
||||
func (s *NodeSchema) variableAggregatorStreamInputConverter(in *schema.StreamReader[map[string]any]) *schema.StreamReader[map[string]map[int]any] {
|
||||
converter := func(input map[string]any) (output map[string]map[int]any, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = safego.NewPanicErr(r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
return s.variableAggregatorInputConverter(input), nil
|
||||
}
|
||||
return schema.StreamReaderWithConvert(in, converter)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToTextProcessorConfig() (*textprocessor.Config, error) {
|
||||
return &textprocessor.Config{
|
||||
Type: s.Configs.(map[string]any)["Type"].(textprocessor.Type),
|
||||
Tpl: getKeyOrZero[string]("Tpl", s.Configs.(map[string]any)),
|
||||
ConcatChar: getKeyOrZero[string]("ConcatChar", s.Configs.(map[string]any)),
|
||||
Separators: getKeyOrZero[[]string]("Separators", s.Configs.(map[string]any)),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonSerializationConfig() (*json.SerializationConfig, error) {
|
||||
return &json.SerializationConfig{
|
||||
InputTypes: s.InputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToJsonDeserializationConfig() (*json.DeserializationConfig, error) {
|
||||
return &json.DeserializationConfig{
|
||||
OutputFields: s.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToHTTPRequesterConfig() (*httprequester.Config, error) {
|
||||
return &httprequester.Config{
|
||||
URLConfig: mustGetKey[httprequester.URLConfig]("URLConfig", s.Configs),
|
||||
AuthConfig: getKeyOrZero[*httprequester.AuthenticationConfig]("AuthConfig", s.Configs),
|
||||
BodyConfig: mustGetKey[httprequester.BodyConfig]("BodyConfig", s.Configs),
|
||||
Method: mustGetKey[string]("Method", s.Configs),
|
||||
Timeout: mustGetKey[time.Duration]("Timeout", s.Configs),
|
||||
RetryTimes: mustGetKey[uint64]("RetryTimes", s.Configs),
|
||||
MD5FieldMapping: mustGetKey[httprequester.MD5FieldMapping]("MD5FieldMapping", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerConfig(handler *variable.Handler) (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
Handler: handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToVariableAssignerInLoopConfig() (*variableassigner.Config, error) {
|
||||
return &variableassigner.Config{
|
||||
Pairs: s.Configs.([]*variableassigner.Pair),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToLoopConfig(inner compose.Runnable[map[string]any, map[string]any]) (*loop.Config, error) {
|
||||
conf := &loop.Config{
|
||||
LoopNodeKey: s.Key,
|
||||
LoopType: mustGetKey[loop.Type]("LoopType", s.Configs),
|
||||
IntermediateVars: getKeyOrZero[map[string]*vo.TypeInfo]("IntermediateVars", s.Configs),
|
||||
Outputs: s.OutputSources,
|
||||
Inner: inner,
|
||||
}
|
||||
|
||||
for key, tInfo := range s.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[key]; ok { // exclude arrays in intermediate vars
|
||||
continue
|
||||
}
|
||||
|
||||
conf.InputArrays = append(conf.InputArrays, key)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToQAConfig(ctx context.Context) (*qa.Config, error) {
|
||||
conf := &qa.Config{
|
||||
QuestionTpl: mustGetKey[string]("QuestionTpl", s.Configs),
|
||||
AnswerType: mustGetKey[qa.AnswerType]("AnswerType", s.Configs),
|
||||
ChoiceType: getKeyOrZero[qa.ChoiceType]("ChoiceType", s.Configs),
|
||||
FixedChoices: getKeyOrZero[[]string]("FixedChoices", s.Configs),
|
||||
ExtractFromAnswer: getKeyOrZero[bool]("ExtractFromAnswer", s.Configs),
|
||||
MaxAnswerCount: getKeyOrZero[int]("MaxAnswerCount", s.Configs),
|
||||
AdditionalSystemPromptTpl: getKeyOrZero[string]("AdditionalSystemPromptTpl", s.Configs),
|
||||
OutputFields: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
}
|
||||
|
||||
llmParams := getKeyOrZero[*model.LLMParams]("LLMParams", s.Configs)
|
||||
if llmParams != nil {
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conf.Model = m
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToInputReceiverConfig() (*receiver.Config, error) {
|
||||
return &receiver.Config{
|
||||
OutputTypes: s.OutputTypes,
|
||||
NodeKey: s.Key,
|
||||
OutputSchema: mustGetKey[string]("OutputSchema", s.Configs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToOutputEmitterConfig(sc *WorkflowSchema) (*emitter.Config, error) {
|
||||
conf := &emitter.Config{
|
||||
Template: getKeyOrZero[string]("Template", s.Configs),
|
||||
FullSources: getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs),
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseCustomSQLConfig() (*database.CustomSQLConfig, error) {
|
||||
return &database.CustomSQLConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
SQLTemplate: mustGetKey[string]("SQLTemplate", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
CustomSQLExecutor: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseQueryConfig() (*database.QueryConfig, error) {
|
||||
return &database.QueryConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
QueryFields: getKeyOrZero[[]string]("QueryFields", s.Configs),
|
||||
OrderClauses: getKeyOrZero[[]*crossdatabase.OrderClause]("OrderClauses", s.Configs),
|
||||
ClauseGroup: getKeyOrZero[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Limit: mustGetKey[int64]("Limit", s.Configs),
|
||||
Op: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseInsertConfig() (*database.InsertConfig, error) {
|
||||
return &database.InsertConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Inserter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseDeleteConfig() (*database.DeleteConfig, error) {
|
||||
return &database.DeleteConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Deleter: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToDatabaseUpdateConfig() (*database.UpdateConfig, error) {
|
||||
return &database.UpdateConfig{
|
||||
DatabaseInfoID: mustGetKey[int64]("DatabaseInfoID", s.Configs),
|
||||
ClauseGroup: mustGetKey[*crossdatabase.ClauseGroup]("ClauseGroup", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Updater: crossdatabase.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeIndexerConfig() (*knowledge.IndexerConfig, error) {
|
||||
return &knowledge.IndexerConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
ParsingStrategy: mustGetKey[*crossknowledge.ParsingStrategy]("ParsingStrategy", s.Configs),
|
||||
ChunkingStrategy: mustGetKey[*crossknowledge.ChunkingStrategy]("ChunkingStrategy", s.Configs),
|
||||
KnowledgeIndexer: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeRetrieveConfig() (*knowledge.RetrieveConfig, error) {
|
||||
return &knowledge.RetrieveConfig{
|
||||
KnowledgeIDs: mustGetKey[[]int64]("KnowledgeIDs", s.Configs),
|
||||
RetrievalStrategy: mustGetKey[*crossknowledge.RetrievalStrategy]("RetrievalStrategy", s.Configs),
|
||||
Retriever: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToKnowledgeDeleterConfig() (*knowledge.DeleterConfig, error) {
|
||||
return &knowledge.DeleterConfig{
|
||||
KnowledgeID: mustGetKey[int64]("KnowledgeID", s.Configs),
|
||||
KnowledgeDeleter: crossknowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToPluginConfig() (*plugin.Config, error) {
|
||||
return &plugin.Config{
|
||||
PluginID: mustGetKey[int64]("PluginID", s.Configs),
|
||||
ToolID: mustGetKey[int64]("ToolID", s.Configs),
|
||||
PluginVersion: mustGetKey[string]("PluginVersion", s.Configs),
|
||||
PluginService: crossplugin.GetPluginService(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCodeRunnerConfig() (*code.Config, error) {
|
||||
return &code.Config{
|
||||
Code: mustGetKey[string]("Code", s.Configs),
|
||||
Language: mustGetKey[coderunner.Language]("Language", s.Configs),
|
||||
OutputConfig: s.OutputTypes,
|
||||
Runner: crosscode.GetCodeRunner(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToCreateConversationConfig() (*conversation.CreateConversationConfig, error) {
|
||||
return &conversation.CreateConversationConfig{
|
||||
Creator: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToClearMessageConfig() (*conversation.ClearMessageConfig, error) {
|
||||
return &conversation.ClearMessageConfig{
|
||||
Clearer: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToMessageListConfig() (*conversation.MessageListConfig, error) {
|
||||
return &conversation.MessageListConfig{
|
||||
Lister: crossconversation.ConversationManagerImpl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToIntentDetectorConfig(ctx context.Context) (*intentdetector.Config, error) {
|
||||
cfg := &intentdetector.Config{
|
||||
Intents: mustGetKey[[]string]("Intents", s.Configs),
|
||||
SystemPrompt: getKeyOrZero[string]("SystemPrompt", s.Configs),
|
||||
IsFastMode: getKeyOrZero[bool]("IsFastMode", s.Configs),
|
||||
}
|
||||
|
||||
llmParams := mustGetKey[*model.LLMParams]("LLMParams", s.Configs)
|
||||
m, _, err := model.GetManager().GetModel(ctx, llmParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.ChatModel = m
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *NodeSchema) ToSubWorkflowConfig(ctx context.Context, requireCheckpoint bool) (*subworkflow.Config, error) {
|
||||
var opts []WorkflowOption
|
||||
opts = append(opts, WithIDAsName(mustGetKey[int64]("WorkflowID", s.Configs)))
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, WithParentRequireCheckpoint())
|
||||
}
|
||||
if s := execute.GetStaticConfig(); s != nil && s.MaxNodeCountPerWorkflow > 0 {
|
||||
opts = append(opts, WithMaxNodeCount(s.MaxNodeCountPerWorkflow))
|
||||
}
|
||||
wf, err := NewWorkflow(ctx, s.SubWorkflowSchema, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subworkflow.Config{
|
||||
Runner: wf.Runner,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func totRetrievalSearchType(s int64) (crossknowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return crossknowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return crossknowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return crossknowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
func getKeyOrZero[T any](key string, cfg any) T {
|
||||
var zero T
|
||||
if cfg == nil {
|
||||
return zero
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if len(m) == 0 {
|
||||
return zero
|
||||
}
|
||||
|
||||
if v, ok := m[key]; ok {
|
||||
return v.(T)
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
func mustGetKey[T any](key string, cfg any) T {
|
||||
if cfg == nil {
|
||||
panic(fmt.Sprintf("mustGetKey[*any] is nil, key=%s", key))
|
||||
}
|
||||
|
||||
m, ok := cfg.(map[string]any)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("m is not a map[string]any, actual type: %v", reflect.TypeOf(cfg)))
|
||||
}
|
||||
|
||||
if _, ok := m[key]; !ok {
|
||||
panic(fmt.Sprintf("key %s does not exist in map: %v", key, m))
|
||||
}
|
||||
|
||||
v, ok := m[key].(T)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("key %s is not a %v, actual type: %v", key, reflect.TypeOf(v), reflect.TypeOf(m[key])))
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetConfigKV(key string, value any) {
|
||||
if s.Configs == nil {
|
||||
s.Configs = make(map[string]any)
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)[key] = value
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
|
||||
if s.InputTypes == nil {
|
||||
s.InputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.InputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
|
||||
s.InputSources = append(s.InputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
if s.OutputTypes == nil {
|
||||
s.OutputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.OutputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) GetSubWorkflowIdentity() (int64, string, bool) {
|
||||
if s.Type != entity.NodeTypeSubWorkflow {
|
||||
return 0, "", false
|
||||
}
|
||||
|
||||
return mustGetKey[int64]("WorkflowID", s.Configs), mustGetKey[string]("WorkflowVersion", s.Configs), true
|
||||
}
|
||||
|
|
@ -29,6 +29,8 @@ import (
|
|||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
|
|
@ -37,7 +39,7 @@ type workflow = compose.Workflow[map[string]any, map[string]any]
|
|||
type Workflow struct { // TODO: too many fields in this struct, cut them down to the absolutely essentials
|
||||
*workflow
|
||||
hierarchy map[vo.NodeKey]vo.NodeKey
|
||||
connections []*Connection
|
||||
connections []*schema.Connection
|
||||
requireCheckpoint bool
|
||||
entry *compose.WorkflowNode
|
||||
inner bool
|
||||
|
|
@ -47,7 +49,7 @@ type Workflow struct { // TODO: too many fields in this struct, cut them down to
|
|||
input map[string]*vo.TypeInfo
|
||||
output map[string]*vo.TypeInfo
|
||||
terminatePlan vo.TerminatePlan
|
||||
schema *WorkflowSchema
|
||||
schema *schema.WorkflowSchema
|
||||
}
|
||||
|
||||
type workflowOptions struct {
|
||||
|
|
@ -78,7 +80,7 @@ func WithMaxNodeCount(c int) WorkflowOption {
|
|||
}
|
||||
}
|
||||
|
||||
func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
func NewWorkflow(ctx context.Context, sc *schema.WorkflowSchema, opts ...WorkflowOption) (*Workflow, error) {
|
||||
sc.Init()
|
||||
|
||||
wf := &Workflow{
|
||||
|
|
@ -88,8 +90,8 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
|||
schema: sc,
|
||||
}
|
||||
|
||||
wf.streamRun = sc.requireStreaming
|
||||
wf.requireCheckpoint = sc.requireCheckPoint
|
||||
wf.streamRun = sc.RequireStreaming()
|
||||
wf.requireCheckpoint = sc.RequireCheckpoint()
|
||||
|
||||
wfOpts := &workflowOptions{}
|
||||
for _, opt := range opts {
|
||||
|
|
@ -125,7 +127,6 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
|||
processedNodeKey[child.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// add all nodes other than composite nodes and their children
|
||||
for _, ns := range sc.Nodes {
|
||||
if _, ok := processedNodeKey[ns.Key]; !ok {
|
||||
|
|
@ -135,7 +136,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
|||
}
|
||||
|
||||
if ns.Type == entity.NodeTypeExit {
|
||||
wf.terminatePlan = mustGetKey[vo.TerminatePlan]("TerminalPlan", ns.Configs)
|
||||
wf.terminatePlan = ns.Configs.(*exit.Config).TerminatePlan
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -147,7 +148,7 @@ func NewWorkflow(ctx context.Context, sc *WorkflowSchema, opts ...WorkflowOption
|
|||
compileOpts = append(compileOpts, compose.WithGraphName(strconv.FormatInt(wfOpts.wfID, 10)))
|
||||
}
|
||||
|
||||
fanInConfigs := sc.fanInMergeConfigs()
|
||||
fanInConfigs := sc.FanInMergeConfigs()
|
||||
if len(fanInConfigs) > 0 {
|
||||
compileOpts = append(compileOpts, compose.WithFanInMergeConfig(fanInConfigs))
|
||||
}
|
||||
|
|
@ -199,12 +200,12 @@ type innerWorkflowInfo struct {
|
|||
carryOvers map[vo.NodeKey][]*compose.FieldMapping
|
||||
}
|
||||
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *NodeSchema) error {
|
||||
func (w *Workflow) AddNode(ctx context.Context, ns *schema.NodeSchema) error {
|
||||
_, err := w.addNodeInternal(ctx, ns, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) error {
|
||||
func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *schema.CompositeNode) error {
|
||||
inner, err := w.getInnerWorkflow(ctx, cNode)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -213,11 +214,11 @@ func (w *Workflow) AddCompositeNode(ctx context.Context, cNode *CompositeNode) e
|
|||
return err
|
||||
}
|
||||
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addInnerNode(ctx context.Context, cNode *schema.NodeSchema) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
return w.addNodeInternal(ctx, cNode, nil)
|
||||
}
|
||||
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
func (w *Workflow) addNodeInternal(ctx context.Context, ns *schema.NodeSchema, inner *innerWorkflowInfo) (map[vo.NodeKey][]*compose.FieldMapping, error) {
|
||||
key := ns.Key
|
||||
var deps *dependencyInfo
|
||||
|
||||
|
|
@ -237,7 +238,7 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
|||
innerWorkflow = inner.inner
|
||||
}
|
||||
|
||||
ins, err := ns.New(ctx, innerWorkflow, w.schema, deps)
|
||||
ins, err := New(ctx, ns, innerWorkflow, w.schema, deps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -245,12 +246,12 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
|||
var opts []compose.GraphAddNodeOpt
|
||||
opts = append(opts, compose.WithNodeName(string(ns.Key)))
|
||||
|
||||
preHandler := ns.StatePreHandler(w.streamRun)
|
||||
preHandler := statePreHandler(ns, w.streamRun)
|
||||
if preHandler != nil {
|
||||
opts = append(opts, preHandler)
|
||||
}
|
||||
|
||||
postHandler := ns.StatePostHandler(w.streamRun)
|
||||
postHandler := statePostHandler(ns, w.streamRun)
|
||||
if postHandler != nil {
|
||||
opts = append(opts, postHandler)
|
||||
}
|
||||
|
|
@ -297,19 +298,23 @@ func (w *Workflow) addNodeInternal(ctx context.Context, ns *NodeSchema, inner *i
|
|||
w.entry = wNode
|
||||
}
|
||||
|
||||
outputPortCount, hasExceptionPort := ns.OutputPortCount()
|
||||
if outputPortCount > 1 || hasExceptionPort {
|
||||
bMapping, err := w.resolveBranch(key, outputPortCount)
|
||||
b := w.schema.GetBranch(ns.Key)
|
||||
if b != nil {
|
||||
if b.OnlyException() {
|
||||
_ = w.AddBranch(string(key), b.GetExceptionBranch())
|
||||
} else {
|
||||
bb, ok := ns.Configs.(schema.BranchBuilder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node schema's Configs should implement BranchBuilder, node type= %v", ns.Type)
|
||||
}
|
||||
|
||||
br, err := b.GetFullBranch(ctx, bb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch, err := ns.GetBranch(bMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
_ = w.AddBranch(string(key), br)
|
||||
}
|
||||
|
||||
_ = w.AddBranch(string(key), branch)
|
||||
}
|
||||
|
||||
return deps.inputsForParent, nil
|
||||
|
|
@ -328,15 +333,15 @@ func (w *Workflow) Compile(ctx context.Context, opts ...compose.GraphCompileOpti
|
|||
return w.workflow.Compile(ctx, opts...)
|
||||
}
|
||||
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*NodeSchema)
|
||||
func (w *Workflow) getInnerWorkflow(ctx context.Context, cNode *schema.CompositeNode) (*innerWorkflowInfo, error) {
|
||||
innerNodes := make(map[vo.NodeKey]*schema.NodeSchema)
|
||||
for _, n := range cNode.Children {
|
||||
innerNodes[n.Key] = n
|
||||
}
|
||||
|
||||
// trim the connections, only keep the connections that are related to the inner workflow
|
||||
// ignore the cases when we have nested inner workflows, because we do not support nested composite nodes
|
||||
innerConnections := make([]*Connection, 0)
|
||||
innerConnections := make([]*schema.Connection, 0)
|
||||
for i := range w.schema.Connections {
|
||||
conn := w.schema.Connections[i]
|
||||
if _, ok := innerNodes[conn.FromNode]; ok {
|
||||
|
|
@ -510,7 +515,7 @@ func (d *dependencyInfo) merge(mappings map[vo.NodeKey][]*compose.FieldMapping)
|
|||
// For example, if the 'from path' is ['a', 'b', 'c'], and 'b' is an array, we will take value using a.b[0].c.
|
||||
// As a counter example, if the 'from path' is ['a', 'b', 'c'], and 'b' is not an array, but 'c' is an array,
|
||||
// we will not try to drill, instead, just take value using a.b.c.
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*NodeSchema) error {
|
||||
func (d *dependencyInfo) arrayDrillDown(allNS map[vo.NodeKey]*schema.NodeSchema) error {
|
||||
for nKey, fms := range d.inputs {
|
||||
if nKey == compose.START { // reference to START node would NEVER need to do array drill down
|
||||
continue
|
||||
|
|
@ -638,55 +643,6 @@ type variableInfo struct {
|
|||
toPath compose.FieldPath
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveBranch(n vo.NodeKey, portCount int) (*BranchMapping, error) {
|
||||
m := make([]map[string]bool, portCount)
|
||||
var exception map[string]bool
|
||||
|
||||
for _, conn := range w.connections {
|
||||
if conn.FromNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if conn.FromPort == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if *conn.FromPort == "default" { // default condition
|
||||
if m[portCount-1] == nil {
|
||||
m[portCount-1] = make(map[string]bool)
|
||||
}
|
||||
m[portCount-1][string(conn.ToNode)] = true
|
||||
} else if *conn.FromPort == "branch_error" {
|
||||
if exception == nil {
|
||||
exception = make(map[string]bool)
|
||||
}
|
||||
exception[string(conn.ToNode)] = true
|
||||
} else {
|
||||
if !strings.HasPrefix(*conn.FromPort, "branch_") {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port= %s", *conn.FromPort)
|
||||
}
|
||||
|
||||
index := (*conn.FromPort)[7:]
|
||||
i, err := strconv.Atoi(index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index= %s", *conn.FromPort)
|
||||
}
|
||||
if i < 0 || i >= portCount {
|
||||
return nil, fmt.Errorf("outgoing connections has invalid port index range= %d, condition count= %d", i, portCount)
|
||||
}
|
||||
if m[i] == nil {
|
||||
m[i] = make(map[string]bool)
|
||||
}
|
||||
m[i][string(conn.ToNode)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return &BranchMapping{
|
||||
Normal: m,
|
||||
Exception: exception,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.FieldInfo) (*dependencyInfo, error) {
|
||||
var (
|
||||
inputs = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
|
|
@ -701,7 +657,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
|||
inputsForParent = make(map[vo.NodeKey][]*compose.FieldMapping)
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
|
|
@ -734,7 +690,7 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
|||
continue
|
||||
}
|
||||
|
||||
if ok := isInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsInSameWorkflow(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
if len(swp.Source.Ref.FromPath) == 0 && len(swp.Path) == 0 {
|
||||
if inputFull == nil {
|
||||
|
|
@ -755,10 +711,10 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
|||
compose.MapFieldPaths(swp.Source.Ref.FromPath, swp.Path))
|
||||
}
|
||||
}
|
||||
} else if ok := isBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
} else if ok := schema.IsBelowOneLevel(w.hierarchy, n, fromNode); ok {
|
||||
firstNodesInInnerWorkflow := true
|
||||
for _, conn := range connMap {
|
||||
if isInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, n, conn.FromNode) {
|
||||
// there is another node 'conn.FromNode' that connects to this node, while also at the same level
|
||||
firstNodesInInnerWorkflow = false
|
||||
break
|
||||
|
|
@ -805,9 +761,9 @@ func (w *Workflow) resolveDependencies(n vo.NodeKey, sourceWithPaths []*vo.Field
|
|||
continue
|
||||
}
|
||||
|
||||
if isBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
if schema.IsBelowOneLevel(w.hierarchy, n, fromNodeKey) {
|
||||
fromNodeKey = compose.START
|
||||
} else if !isInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
} else if !schema.IsInSameWorkflow(w.hierarchy, n, fromNodeKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -864,13 +820,13 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
|||
variableInfos []*variableInfo
|
||||
)
|
||||
|
||||
connMap := make(map[vo.NodeKey]Connection)
|
||||
connMap := make(map[vo.NodeKey]schema.Connection)
|
||||
for _, conn := range w.connections {
|
||||
if conn.ToNode != n {
|
||||
continue
|
||||
}
|
||||
|
||||
if isInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
if schema.IsInSameWorkflow(w.hierarchy, conn.FromNode, n) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -899,7 +855,7 @@ func (w *Workflow) resolveDependenciesAsParent(n vo.NodeKey, sourceWithPaths []*
|
|||
swp.Source.Ref.FromPath, swp.Path)
|
||||
}
|
||||
|
||||
if ok := isParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if ok := schema.IsParentOf(w.hierarchy, n, fromNode); ok {
|
||||
if _, ok := connMap[fromNode]; ok { // direct dependency
|
||||
inputs[fromNode] = append(inputs[fromNode], compose.MapFieldPaths(swp.Source.Ref.FromPath, append(compose.FieldPath{string(fromNode)}, swp.Source.Ref.FromPath...)))
|
||||
} else { // indirect dependency
|
||||
|
|
|
|||
|
|
@ -23,9 +23,10 @@ import (
|
|||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
|
||||
func NewWorkflowFromNode(ctx context.Context, sc *schema.WorkflowSchema, nodeKey vo.NodeKey, opts ...compose.GraphCompileOption) (
|
||||
*Workflow, error) {
|
||||
sc.Init()
|
||||
ns := sc.GetNode(nodeKey)
|
||||
|
|
@ -37,7 +38,7 @@ func NewWorkflowFromNode(ctx context.Context, sc *WorkflowSchema, nodeKey vo.Nod
|
|||
schema: sc,
|
||||
fromNode: true,
|
||||
streamRun: false, // single node run can only invoke
|
||||
requireCheckpoint: sc.requireCheckPoint,
|
||||
requireCheckpoint: sc.RequireCheckpoint(),
|
||||
input: ns.InputTypes,
|
||||
output: ns.OutputTypes,
|
||||
terminatePlan: vo.ReturnVariables,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
|
@ -42,7 +43,7 @@ type WorkflowRunner struct {
|
|||
basic *entity.WorkflowBasic
|
||||
input string
|
||||
resumeReq *entity.ResumeRequest
|
||||
schema *WorkflowSchema
|
||||
schema *schema2.WorkflowSchema
|
||||
streamWriter *schema.StreamWriter[*entity.Message]
|
||||
config vo.ExecuteConfig
|
||||
|
||||
|
|
@ -76,7 +77,7 @@ func WithStreamWriter(sw *schema.StreamWriter[*entity.Message]) WorkflowRunnerOp
|
|||
}
|
||||
}
|
||||
|
||||
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
|
||||
func NewWorkflowRunner(b *entity.WorkflowBasic, sc *schema2.WorkflowSchema, config vo.ExecuteConfig, opts ...WorkflowRunnerOption) *WorkflowRunner {
|
||||
options := &workflowRunOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
|
@ -41,7 +42,7 @@ type invokableWorkflow struct {
|
|||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ func NewInvokableWorkflow(info *schema.ToolInfo,
|
|||
invoke func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (map[string]any, error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
sc *schema2.WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &invokableWorkflow{
|
||||
|
|
@ -112,7 +113,7 @@ func (i *invokableWorkflow) InvokableRun(ctx context.Context, argumentsInJSON st
|
|||
return "", err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range i.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
|
|
@ -190,7 +191,7 @@ type streamableWorkflow struct {
|
|||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error)
|
||||
terminatePlan vo.TerminatePlan
|
||||
wfEntity *entity.Workflow
|
||||
sc *WorkflowSchema
|
||||
sc *schema2.WorkflowSchema
|
||||
repo wf.Repository
|
||||
}
|
||||
|
||||
|
|
@ -198,7 +199,7 @@ func NewStreamableWorkflow(info *schema.ToolInfo,
|
|||
stream func(ctx context.Context, input map[string]any, opts ...einoCompose.Option) (*schema.StreamReader[map[string]any], error),
|
||||
terminatePlan vo.TerminatePlan,
|
||||
wfEntity *entity.Workflow,
|
||||
sc *WorkflowSchema,
|
||||
sc *schema2.WorkflowSchema,
|
||||
repo wf.Repository,
|
||||
) wf.ToolFromWorkflow {
|
||||
return &streamableWorkflow{
|
||||
|
|
@ -261,7 +262,7 @@ func (s *streamableWorkflow) StreamableRun(ctx context.Context, argumentsInJSON
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var entryNode *NodeSchema
|
||||
var entryNode *schema2.NodeSchema
|
||||
for _, node := range s.sc.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry {
|
||||
entryNode = node
|
||||
|
|
|
|||
|
|
@ -30,50 +30,108 @@ import (
|
|||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type Batch struct {
|
||||
config *Config
|
||||
outputs map[string]*vo.FieldSource
|
||||
innerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
key vo.NodeKey
|
||||
inputArrays []string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
BatchNodeKey vo.NodeKey `json:"batch_node_key"`
|
||||
InnerWorkflow compose.Runnable[map[string]any, map[string]any]
|
||||
type Config struct{}
|
||||
|
||||
InputArrays []string `json:"input_arrays"`
|
||||
Outputs []*vo.FieldInfo `json:"outputs"`
|
||||
}
|
||||
|
||||
func NewBatch(_ context.Context, config *Config) (*Batch, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("batch node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if len(config.InputArrays) == 0 {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeBatch,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
batchSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.BatchSize,
|
||||
compose.FieldPath{MaxBatchSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(batchSizeField...)
|
||||
concurrentSizeField, err := convert.CanvasBlockInputToFieldInfo(n.Data.Inputs.ConcurrentSize,
|
||||
compose.FieldPath{ConcurrentSizeKey}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(concurrentSizeField...)
|
||||
|
||||
batchSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.BatchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(MaxBatchSizeKey, batchSizeType)
|
||||
concurrentSizeType, err := convert.CanvasBlockInputToTypeInfo(n.Data.Inputs.ConcurrentSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(ConcurrentSizeKey, concurrentSizeType)
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
var inputArrays []string
|
||||
for key, tInfo := range ns.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
inputArrays = append(inputArrays, key)
|
||||
}
|
||||
|
||||
if len(inputArrays) == 0 {
|
||||
return nil, errors.New("need to have at least one incoming array for batch")
|
||||
}
|
||||
|
||||
if len(config.Outputs) == 0 {
|
||||
if len(ns.OutputSources) == 0 {
|
||||
return nil, errors.New("need to have at least one output variable for batch")
|
||||
}
|
||||
|
||||
b := &Batch{
|
||||
config: config,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
bo := schema.GetBuildOptions(opts...)
|
||||
if bo.Inner == nil {
|
||||
return nil, errors.New("need to have inner workflow for batch")
|
||||
}
|
||||
|
||||
for i := range config.Outputs {
|
||||
source := config.Outputs[i]
|
||||
b := &Batch{
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
innerWorkflow: bo.Inner,
|
||||
key: ns.Key,
|
||||
inputArrays: inputArrays,
|
||||
}
|
||||
|
||||
for i := range ns.OutputSources {
|
||||
source := ns.OutputSources[i]
|
||||
path := source.Path
|
||||
if len(path) != 1 {
|
||||
return nil, fmt.Errorf("invalid path %q", path)
|
||||
}
|
||||
|
||||
// from which inner node's which field does the batch's output fields come from
|
||||
b.outputs[path[0]] = &source.Source
|
||||
}
|
||||
|
||||
|
|
@ -97,11 +155,11 @@ func (b *Batch) initOutput(length int) map[string]any {
|
|||
return out
|
||||
}
|
||||
|
||||
func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (
|
||||
func (b *Batch) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
|
||||
out map[string]any, err error) {
|
||||
arrays := make(map[string]any, len(b.config.InputArrays))
|
||||
arrays := make(map[string]any, len(b.inputArrays))
|
||||
minLen := math.MaxInt64
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
|
|
@ -160,13 +218,13 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
}
|
||||
}
|
||||
|
||||
input[string(b.config.BatchNodeKey)+"#index"] = int64(i)
|
||||
input[string(b.key)+"#index"] = int64(i)
|
||||
|
||||
items := make(map[string]any)
|
||||
for arrayKey, array := range arrays {
|
||||
ele := reflect.ValueOf(array).Index(i).Interface()
|
||||
items[arrayKey] = []any{ele}
|
||||
currentKey := string(b.config.BatchNodeKey) + "#" + arrayKey
|
||||
currentKey := string(b.key) + "#" + arrayKey
|
||||
|
||||
// Recursively expand map[string]any elements
|
||||
var expand func(prefix string, val interface{})
|
||||
|
|
@ -200,15 +258,11 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
return nil
|
||||
}
|
||||
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
var existingCState *nodes.NestedWorkflowState
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
|
||||
var e error
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.config.BatchNodeKey)
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(b.key)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
|
@ -280,7 +334,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
mu.Unlock()
|
||||
if subCheckpointID != "" {
|
||||
logs.CtxInfof(ctx, "[testInterrupt] prepare %d th run for batch node %s, subCheckPointID %s",
|
||||
i, b.config.BatchNodeKey, subCheckpointID)
|
||||
i, b.key, subCheckpointID)
|
||||
ithOpts = append(ithOpts, compose.WithCheckPointID(subCheckpointID))
|
||||
}
|
||||
|
||||
|
|
@ -298,7 +352,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
|
||||
// if the innerWorkflow has output emitter that requires stream output, then we need to stream the inner workflow
|
||||
// the output then needs to be concatenated.
|
||||
taskOutput, err := b.config.InnerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
taskOutput, err := b.innerWorkflow.Invoke(subCtx, input, ithOpts...)
|
||||
if err != nil {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
|
|
@ -376,17 +430,17 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
|
||||
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
|
||||
}
|
||||
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, iEvent)
|
||||
return setter.SetInterruptEvent(b.key, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -398,7 +452,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(b.config.BatchNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(b.key, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
|
|
@ -409,8 +463,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
// although this invocation does not have new interruptions,
|
||||
// this batch node previously have interrupts yet to be resumed.
|
||||
// we overwrite the interrupt events, keeping only the interrupts yet to be resumed.
|
||||
return setter.SetInterruptEvent(b.config.BatchNodeKey, &entity.InterruptEvent{
|
||||
NodeKey: b.config.BatchNodeKey,
|
||||
return setter.SetInterruptEvent(b.key, &entity.InterruptEvent{
|
||||
NodeKey: b.key,
|
||||
NodeType: entity.NodeTypeBatch,
|
||||
NestedInterruptInfo: existingCState.Index2InterruptInfo,
|
||||
})
|
||||
|
|
@ -424,7 +478,7 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
|
||||
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
|
||||
logs.CtxInfof(ctx, "no interrupt thrown this round, but has historical interrupt events yet to be resumed, "+
|
||||
"nodeKey: %v. indexes: %v", b.config.BatchNodeKey, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
"nodeKey: %v. indexes: %v", b.key, maps.Keys(existingCState.Index2InterruptInfo))
|
||||
return nil, compose.InterruptAndRerun // interrupt again to wait for resuming of previously interrupted index runs
|
||||
}
|
||||
|
||||
|
|
@ -432,8 +486,8 @@ func (b *Batch) Execute(ctx context.Context, in map[string]any, opts ...nodes.Ne
|
|||
}
|
||||
|
||||
func (b *Batch) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(b.config.InputArrays))
|
||||
for _, arrayKey := range b.config.InputArrays {
|
||||
trimmed := make(map[string]any, len(b.inputArrays))
|
||||
for _, arrayKey := range b.inputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ import (
|
|||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
code2 "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/code"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
|
|
@ -115,48 +119,75 @@ var pythonThirdPartyWhitelist = map[string]struct{}{
|
|||
type Config struct {
|
||||
Code string
|
||||
Language coderunner.Language
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
|
||||
Runner coderunner.Runner
|
||||
}
|
||||
|
||||
type CodeRunner struct {
|
||||
config *Config
|
||||
importError error
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeCodeRunner,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
code := inputs.Code
|
||||
c.Code = code
|
||||
|
||||
language, err := convertCodeLanguage(inputs.Language)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Language = language
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewCodeRunner(ctx context.Context, cfg *Config) (*CodeRunner, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
func convertCodeLanguage(l int64) (coderunner.Language, error) {
|
||||
switch l {
|
||||
case 5:
|
||||
return coderunner.JavaScript, nil
|
||||
case 3:
|
||||
return coderunner.Python, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid language: %d", l)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Language == "" {
|
||||
return nil, errors.New("language is required")
|
||||
}
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
|
||||
if cfg.Code == "" {
|
||||
return nil, errors.New("code is required")
|
||||
}
|
||||
|
||||
if cfg.Language != coderunner.Python {
|
||||
if c.Language != coderunner.Python {
|
||||
return nil, errors.New("only support python language")
|
||||
}
|
||||
|
||||
if len(cfg.OutputConfig) == 0 {
|
||||
return nil, errors.New("output config is required")
|
||||
}
|
||||
importErr := validatePythonImports(c.Code)
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("run coder is required")
|
||||
}
|
||||
|
||||
importErr := validatePythonImports(cfg.Code)
|
||||
|
||||
return &CodeRunner{
|
||||
config: cfg,
|
||||
return &Runner{
|
||||
code: c.Code,
|
||||
language: c.Language,
|
||||
outputConfig: ns.OutputTypes,
|
||||
runner: code2.GetCodeRunner(),
|
||||
importError: importErr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
outputConfig map[string]*vo.TypeInfo
|
||||
code string
|
||||
language coderunner.Language
|
||||
runner coderunner.Runner
|
||||
importError error
|
||||
}
|
||||
|
||||
func validatePythonImports(code string) error {
|
||||
imports := parsePythonImports(code)
|
||||
importErrors := make([]string, 0)
|
||||
|
|
@ -191,11 +222,11 @@ func validatePythonImports(code string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
|
||||
func (c *Runner) Invoke(ctx context.Context, input map[string]any) (ret map[string]any, err error) {
|
||||
if c.importError != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, c.importError, errorx.KV("detail", c.importError.Error()))
|
||||
}
|
||||
response, err := c.config.Runner.Run(ctx, &coderunner.RunRequest{Code: c.config.Code, Language: c.config.Language, Params: input})
|
||||
response, err := c.runner.Run(ctx, &coderunner.RunRequest{Code: c.code, Language: c.language, Params: input})
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
|
|
@ -203,7 +234,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
|
|||
result := response.Result
|
||||
ctxcache.Store(ctx, coderRunnerRawOutputCtxKey, result)
|
||||
|
||||
output, ws, err := nodes.ConvertInputs(ctx, result, c.config.OutputConfig)
|
||||
output, ws, err := nodes.ConvertInputs(ctx, result, c.outputConfig)
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrCodeExecuteFail, err, errorx.KV("detail", err.Error()))
|
||||
}
|
||||
|
|
@ -217,7 +248,7 @@ func (c *CodeRunner) RunCode(ctx context.Context, input map[string]any) (ret map
|
|||
|
||||
}
|
||||
|
||||
func (c *CodeRunner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
func (c *Runner) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
rawOutput, ok := ctxcache.Get[map[string]any](ctx, coderRunnerRawOutputCtxKey)
|
||||
if !ok {
|
||||
return nil, errors.New("raw output config is required")
|
||||
|
|
|
|||
|
|
@ -75,30 +75,29 @@ async def main(args:Args)->Output:
|
|||
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Language: coderunner.Python,
|
||||
Code: codeTpl,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
c := &Runner{
|
||||
language: coderunner.Python,
|
||||
code: codeTpl,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
}},
|
||||
},
|
||||
},
|
||||
"key4": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject}},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
|
|
@ -145,38 +144,36 @@ async def main(args:Args)->Output:
|
|||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
ctx := t.Context()
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: coderunner.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
c := &Runner{
|
||||
code: codeTpl,
|
||||
language: coderunner.Python,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeString}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
}},
|
||||
}},
|
||||
"key4": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
|
|
@ -219,30 +216,28 @@ async def main(args:Args)->Output:
|
|||
}
|
||||
mockRunner.EXPECT().Run(gomock.Any(), gomock.Any()).Return(response, nil)
|
||||
|
||||
c := &CodeRunner{
|
||||
config: &Config{
|
||||
Code: codeTpl,
|
||||
Language: coderunner.Python,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
c := &Runner{
|
||||
code: codeTpl,
|
||||
language: coderunner.Python,
|
||||
outputConfig: map[string]*vo.TypeInfo{
|
||||
"key0": {Type: vo.DataTypeInteger},
|
||||
"key1": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key2": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key3": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key31": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key32": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key33": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key342": &vo.TypeInfo{Type: vo.DataTypeString},
|
||||
"key343": &vo.TypeInfo{Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key31": {Type: vo.DataTypeString},
|
||||
"key32": {Type: vo.DataTypeString},
|
||||
"key33": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
"key34": {Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"key341": {Type: vo.DataTypeString},
|
||||
"key342": {Type: vo.DataTypeString},
|
||||
"key343": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
Runner: mockRunner,
|
||||
},
|
||||
runner: mockRunner,
|
||||
}
|
||||
ret, err := c.RunCode(ctx, map[string]any{
|
||||
ret, err := c.Invoke(ctx, map[string]any{
|
||||
"input": "1123",
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func setDatabaseInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema) (err error) {
|
||||
selectParam := n.Data.Inputs.SelectParam
|
||||
if selectParam != nil {
|
||||
err = applyDBConditionToSchema(ns, selectParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
insertParam := n.Data.Inputs.InsertParam
|
||||
if insertParam != nil {
|
||||
err = applyInsetFieldInfoToSchema(ns, insertParam.FieldInfo, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
deleteParam := n.Data.Inputs.DeleteParam
|
||||
if deleteParam != nil {
|
||||
err = applyDBConditionToSchema(ns, &deleteParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
updateParam := n.Data.Inputs.UpdateParam
|
||||
if updateParam != nil {
|
||||
err = applyDBConditionToSchema(ns, &updateParam.Condition, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = applyInsetFieldInfoToSchema(ns, updateParam.FieldInfo, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyDBConditionToSchema(ns *schema.NodeSchema, condition *vo.DBCondition, parentNode *vo.Node) error {
|
||||
if condition.ConditionList == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for idx, params := range condition.ConditionList {
|
||||
var right *vo.Param
|
||||
for _, param := range params {
|
||||
if param == nil {
|
||||
continue
|
||||
}
|
||||
if param.Name == "right" {
|
||||
right = param
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if right == nil {
|
||||
continue
|
||||
}
|
||||
name := fmt.Sprintf("__condition_right_%d", idx)
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(right.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(right.Input, einoCompose.FieldPath{name}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyInsetFieldInfoToSchema(ns *schema.NodeSchema, fieldInfo [][]*vo.Param, parentNode *vo.Node) error {
|
||||
if len(fieldInfo) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, params := range fieldInfo {
|
||||
// Each FieldInfo is list params, containing two elements.
|
||||
// The first is to set the name of the field and the second is the corresponding value.
|
||||
p0 := params[0]
|
||||
p1 := params[1]
|
||||
name := p0.Input.Value.Content.(string) // must string type
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(p1.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name = "__setting_field_" + name
|
||||
ns.SetInputType(name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(p1.Input, einoCompose.FieldPath{name}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildClauseGroupFromCondition(condition *vo.DBCondition) (*database.ClauseGroup, error) {
|
||||
clauseGroup := &database.ClauseGroup{}
|
||||
if len(condition.ConditionList) == 1 {
|
||||
params := condition.ConditionList[0]
|
||||
clause, err := buildClauseFromParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Single = clause
|
||||
} else {
|
||||
relation, err := convertLogicTypeToRelation(condition.Logic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Multi = &database.MultiClause{
|
||||
Clauses: make([]*database.Clause, 0, len(condition.ConditionList)),
|
||||
Relation: relation,
|
||||
}
|
||||
for i := range condition.ConditionList {
|
||||
params := condition.ConditionList[i]
|
||||
clause, err := buildClauseFromParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clauseGroup.Multi.Clauses = append(clauseGroup.Multi.Clauses, clause)
|
||||
}
|
||||
}
|
||||
|
||||
return clauseGroup, nil
|
||||
}
|
||||
|
||||
func buildClauseFromParams(params []*vo.Param) (*database.Clause, error) {
|
||||
var left, operation *vo.Param
|
||||
for _, p := range params {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.Name == "left" {
|
||||
left = p
|
||||
continue
|
||||
}
|
||||
if p.Name == "operation" {
|
||||
operation = p
|
||||
continue
|
||||
}
|
||||
}
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("left clause is required")
|
||||
}
|
||||
if operation == nil {
|
||||
return nil, fmt.Errorf("operation clause is required")
|
||||
}
|
||||
operator, err := operationToOperator(operation.Input.Value.Content.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clause := &database.Clause{
|
||||
Left: left.Input.Value.Content.(string),
|
||||
Operator: operator,
|
||||
}
|
||||
|
||||
return clause, nil
|
||||
}
|
||||
|
||||
func convertLogicTypeToRelation(logicType vo.DatabaseLogicType) (database.ClauseRelation, error) {
|
||||
switch logicType {
|
||||
case vo.DatabaseLogicAnd:
|
||||
return database.ClauseRelationAND, nil
|
||||
case vo.DatabaseLogicOr:
|
||||
return database.ClauseRelationOR, nil
|
||||
default:
|
||||
return "", fmt.Errorf("logic type %v is invalid", logicType)
|
||||
}
|
||||
}
|
||||
|
||||
func operationToOperator(s string) (database.Operator, error) {
|
||||
switch s {
|
||||
case "EQUAL":
|
||||
return database.OperatorEqual, nil
|
||||
case "NOT_EQUAL":
|
||||
return database.OperatorNotEqual, nil
|
||||
case "GREATER_THAN":
|
||||
return database.OperatorGreater, nil
|
||||
case "LESS_THAN":
|
||||
return database.OperatorLesser, nil
|
||||
case "GREATER_EQUAL":
|
||||
return database.OperatorGreaterOrEqual, nil
|
||||
case "LESS_EQUAL":
|
||||
return database.OperatorLesserOrEqual, nil
|
||||
case "IN":
|
||||
return database.OperatorIn, nil
|
||||
case "NOT_IN":
|
||||
return database.OperatorNotIn, nil
|
||||
case "IS_NULL":
|
||||
return database.OperatorIsNull, nil
|
||||
case "IS_NOT_NULL":
|
||||
return database.OperatorIsNotNull, nil
|
||||
case "LIKE":
|
||||
return database.OperatorLike, nil
|
||||
case "NOT_LIKE":
|
||||
return database.OperatorNotLike, nil
|
||||
}
|
||||
return "", fmt.Errorf("not a valid Operation string")
|
||||
}
|
||||
|
|
@ -342,7 +342,7 @@ func responseFormatted(configOutput map[string]*vo.TypeInfo, response *database.
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
|
||||
func convertClauseGroupToConditionGroup(_ context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*database.ConditionGroup, error) {
|
||||
var (
|
||||
rightValue any
|
||||
ok bool
|
||||
|
|
@ -394,13 +394,13 @@ func convertClauseGroupToConditionGroup(ctx context.Context, clauseGroup *databa
|
|||
return conditionGroup, nil
|
||||
}
|
||||
|
||||
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*UpdateInventory, error) {
|
||||
func convertClauseGroupToUpdateInventory(ctx context.Context, clauseGroup *database.ClauseGroup, input map[string]any) (*updateInventory, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, clauseGroup, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fields := parseToInput(input)
|
||||
inventory := &UpdateInventory{
|
||||
inventory := &updateInventory{
|
||||
ConditionGroup: conditionGroup,
|
||||
Fields: fields,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,48 +19,89 @@ package database
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type CustomSQLConfig struct {
|
||||
DatabaseInfoID int64
|
||||
SQLTemplate string
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
CustomSQLExecutor database.DatabaseOperator
|
||||
}
|
||||
|
||||
func NewCustomSQL(_ context.Context, cfg *CustomSQLConfig) (*CustomSQL, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *CustomSQLConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseCustomSQL,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.DatabaseInfoID = dsID
|
||||
|
||||
sql := n.Data.Inputs.SQL
|
||||
if len(sql) == 0 {
|
||||
return nil, fmt.Errorf("sql is requird")
|
||||
}
|
||||
|
||||
c.SQLTemplate = sql
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *CustomSQLConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
if cfg.SQLTemplate == "" {
|
||||
if c.SQLTemplate == "" {
|
||||
return nil, errors.New("sql template is required")
|
||||
}
|
||||
if cfg.CustomSQLExecutor == nil {
|
||||
return nil, errors.New("custom sqler is required")
|
||||
}
|
||||
|
||||
return &CustomSQL{
|
||||
config: cfg,
|
||||
databaseInfoID: c.DatabaseInfoID,
|
||||
sqlTemplate: c.SQLTemplate,
|
||||
outputTypes: ns.OutputTypes,
|
||||
customSQLExecutor: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type CustomSQL struct {
|
||||
config *CustomSQLConfig
|
||||
databaseInfoID int64
|
||||
sqlTemplate string
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
customSQLExecutor database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
|
||||
func (c *CustomSQL) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
req := &database.CustomSQLRequest{
|
||||
DatabaseInfoID: c.config.DatabaseInfoID,
|
||||
DatabaseInfoID: c.databaseInfoID,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
|
@ -71,7 +112,7 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
|
|||
}
|
||||
|
||||
templateSQL := ""
|
||||
templateParts := nodes.ParseTemplate(c.config.SQLTemplate)
|
||||
templateParts := nodes.ParseTemplate(c.sqlTemplate)
|
||||
sqlParams := make([]database.SQLParam, 0, len(templateParts))
|
||||
var nilError = errors.New("field is nil")
|
||||
for _, templatePart := range templateParts {
|
||||
|
|
@ -113,12 +154,12 @@ func (c *CustomSQL) Execute(ctx context.Context, input map[string]any) (map[stri
|
|||
templateSQL = strings.Replace(templateSQL, "`?`", "?", -1)
|
||||
req.SQL = templateSQL
|
||||
req.Params = sqlParams
|
||||
response, err := c.config.CustomSQLExecutor.Execute(ctx, req)
|
||||
response, err := c.customSQLExecutor.Execute(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(c.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(c.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type mockCustomSQLer struct {
|
||||
|
|
@ -39,7 +40,7 @@ func (m mockCustomSQLer) Execute() func(ctx context.Context, request *database.C
|
|||
m.validate(request)
|
||||
r := &database.Response{
|
||||
Objects: []database.Object{
|
||||
database.Object{
|
||||
{
|
||||
"v1": "v1_ret",
|
||||
"v2": "v2_ret",
|
||||
},
|
||||
|
|
@ -58,9 +59,9 @@ func TestCustomSQL_Execute(t *testing.T) {
|
|||
validate: func(req *database.CustomSQLRequest) {
|
||||
assert.Equal(t, int64(111), req.DatabaseInfoID)
|
||||
ps := []database.SQLParam{
|
||||
database.SQLParam{Value: "v1_value"},
|
||||
database.SQLParam{Value: "v2_value"},
|
||||
database.SQLParam{Value: "v3_value"},
|
||||
{Value: "v1_value"},
|
||||
{Value: "v2_value"},
|
||||
{Value: "v3_value"},
|
||||
}
|
||||
assert.Equal(t, ps, req.Params)
|
||||
assert.Equal(t, "select * from v1 where v1 = ? and v2 = ? and v3 = ?", req.SQL)
|
||||
|
|
@ -80,23 +81,25 @@ func TestCustomSQL_Execute(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(mockSQLer.Execute()).AnyTimes()
|
||||
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
cfg := &CustomSQLConfig{
|
||||
DatabaseInfoID: 111,
|
||||
SQLTemplate: "select * from v1 where v1 = {{v1}} and v2 = '{{v2}}' and v3 = `{{v3}}`",
|
||||
CustomSQLExecutor: mockDatabaseOperator,
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
}
|
||||
|
||||
c1, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}
|
||||
cl := &CustomSQL{
|
||||
config: cfg,
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
ret, err := cl.Execute(t.Context(), map[string]any{
|
||||
ret, err := c1.(*CustomSQL).Invoke(t.Context(), map[string]any{
|
||||
"v1": "v1_value",
|
||||
"v2": "v2_value",
|
||||
"v3": "v3_value",
|
||||
|
|
|
|||
|
|
@ -20,61 +20,102 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type DeleteConfig struct {
|
||||
DatabaseInfoID int64
|
||||
ClauseGroup *database.ClauseGroup
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
|
||||
Deleter database.DatabaseOperator
|
||||
}
|
||||
type Delete struct {
|
||||
config *DeleteConfig
|
||||
}
|
||||
|
||||
func NewDelete(_ context.Context, cfg *DeleteConfig) (*Delete, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (d *DeleteConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseDelete,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.DatabaseInfoID = dsID
|
||||
|
||||
deleteParam := n.Data.Inputs.DeleteParam
|
||||
|
||||
clauseGroup, err := buildClauseGroupFromCondition(&deleteParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (d *DeleteConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if d.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.ClauseGroup == nil {
|
||||
if d.ClauseGroup == nil {
|
||||
return nil, errors.New("clauseGroup is required")
|
||||
}
|
||||
if cfg.Deleter == nil {
|
||||
return nil, errors.New("deleter is required")
|
||||
}
|
||||
|
||||
return &Delete{
|
||||
config: cfg,
|
||||
databaseInfoID: d.DatabaseInfoID,
|
||||
clauseGroup: d.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
deleter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.config.ClauseGroup, in)
|
||||
type Delete struct {
|
||||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
deleter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (d *Delete) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, d.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request := &database.DeleteRequest{
|
||||
DatabaseInfoID: d.config.DatabaseInfoID,
|
||||
DatabaseInfoID: d.databaseInfoID,
|
||||
ConditionGroup: conditionGroup,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := d.config.Deleter.Delete(ctx, request)
|
||||
response, err := d.deleter.Delete(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(d.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(d.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -82,7 +123,7 @@ func (d *Delete) Delete(ctx context.Context, in map[string]any) (map[string]any,
|
|||
}
|
||||
|
||||
func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.config.ClauseGroup, in)
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(context.Background(), d.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -90,7 +131,7 @@ func (d *Delete) ToCallbackInput(_ context.Context, in map[string]any) (map[stri
|
|||
}
|
||||
|
||||
func (d *Delete) toDatabaseDeleteCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
databaseID := d.config.DatabaseInfoID
|
||||
databaseID := d.databaseInfoID
|
||||
result := make(map[string]any)
|
||||
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
|
|
|
|||
|
|
@ -20,54 +20,84 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type InsertConfig struct {
|
||||
DatabaseInfoID int64
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
Inserter database.DatabaseOperator
|
||||
}
|
||||
|
||||
type Insert struct {
|
||||
config *InsertConfig
|
||||
}
|
||||
|
||||
func NewInsert(_ context.Context, cfg *InsertConfig) (*Insert, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (i *InsertConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseInsert,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.DatabaseInfoID = dsID
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (i *InsertConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if i.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Inserter == nil {
|
||||
return nil, errors.New("inserter is required")
|
||||
}
|
||||
return &Insert{
|
||||
config: cfg,
|
||||
databaseInfoID: i.DatabaseInfoID,
|
||||
outputTypes: ns.OutputTypes,
|
||||
inserter: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Insert struct {
|
||||
databaseInfoID int64
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
inserter database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (is *Insert) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
fields := parseToInput(input)
|
||||
req := &database.InsertRequest{
|
||||
DatabaseInfoID: is.config.DatabaseInfoID,
|
||||
DatabaseInfoID: is.databaseInfoID,
|
||||
Fields: fields,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := is.config.Inserter.Insert(ctx, req)
|
||||
response, err := is.inserter.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(is.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(is.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -76,7 +106,7 @@ func (is *Insert) Insert(ctx context.Context, input map[string]any) (map[string]
|
|||
}
|
||||
|
||||
func (is *Insert) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
databaseID := is.config.DatabaseInfoID
|
||||
databaseID := is.databaseInfoID
|
||||
fs := parseToInput(input)
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
|
|
|
|||
|
|
@ -20,68 +20,137 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type QueryConfig struct {
|
||||
DatabaseInfoID int64
|
||||
QueryFields []string
|
||||
OrderClauses []*database.OrderClause
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
ClauseGroup *database.ClauseGroup
|
||||
Limit int64
|
||||
Op database.DatabaseOperator
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
config *QueryConfig
|
||||
}
|
||||
|
||||
func NewQuery(_ context.Context, cfg *QueryConfig) (*Query, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (q *QueryConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseQuery,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: q,
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q.DatabaseInfoID = dsID
|
||||
|
||||
selectParam := n.Data.Inputs.SelectParam
|
||||
q.Limit = selectParam.Limit
|
||||
|
||||
queryFields := make([]string, 0)
|
||||
for _, v := range selectParam.FieldList {
|
||||
queryFields = append(queryFields, strconv.FormatInt(v.FieldID, 10))
|
||||
}
|
||||
q.QueryFields = queryFields
|
||||
|
||||
orderClauses := make([]*database.OrderClause, 0, len(selectParam.OrderByList))
|
||||
for _, o := range selectParam.OrderByList {
|
||||
orderClauses = append(orderClauses, &database.OrderClause{
|
||||
FieldID: strconv.FormatInt(o.FieldID, 10),
|
||||
IsAsc: o.IsAsc,
|
||||
})
|
||||
}
|
||||
q.OrderClauses = orderClauses
|
||||
|
||||
clauseGroup := &database.ClauseGroup{}
|
||||
|
||||
if selectParam.Condition != nil {
|
||||
clauseGroup, err = buildClauseGroupFromCondition(selectParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
q.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (q *QueryConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if q.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Limit == 0 {
|
||||
if q.Limit == 0 {
|
||||
return nil, errors.New("limit is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Op == nil {
|
||||
return nil, errors.New("op is required")
|
||||
}
|
||||
|
||||
return &Query{config: cfg}, nil
|
||||
|
||||
return &Query{
|
||||
databaseInfoID: q.DatabaseInfoID,
|
||||
queryFields: q.QueryFields,
|
||||
orderClauses: q.OrderClauses,
|
||||
outputTypes: ns.OutputTypes,
|
||||
clauseGroup: q.ClauseGroup,
|
||||
limit: q.Limit,
|
||||
op: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ds *Query) Query(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
|
||||
type Query struct {
|
||||
databaseInfoID int64
|
||||
queryFields []string
|
||||
orderClauses []*database.OrderClause
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
clauseGroup *database.ClauseGroup
|
||||
limit int64
|
||||
op database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (ds *Query) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &database.QueryRequest{
|
||||
DatabaseInfoID: ds.config.DatabaseInfoID,
|
||||
OrderClauses: ds.config.OrderClauses,
|
||||
SelectFields: ds.config.QueryFields,
|
||||
Limit: ds.config.Limit,
|
||||
DatabaseInfoID: ds.databaseInfoID,
|
||||
OrderClauses: ds.orderClauses,
|
||||
SelectFields: ds.queryFields,
|
||||
Limit: ds.limit,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
req.ConditionGroup = conditionGroup
|
||||
|
||||
response, err := ds.config.Op.Query(ctx, req)
|
||||
response, err := ds.op.Query(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(ds.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(ds.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -93,18 +162,18 @@ func notNeedTakeMapValue(op database.Operator) bool {
|
|||
}
|
||||
|
||||
func (ds *Query) ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.config.ClauseGroup, in)
|
||||
conditionGroup, err := convertClauseGroupToConditionGroup(ctx, ds.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return toDatabaseQueryCallbackInput(ds.config, conditionGroup)
|
||||
return ds.toDatabaseQueryCallbackInput(conditionGroup)
|
||||
}
|
||||
|
||||
func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
func (ds *Query) toDatabaseQueryCallbackInput(conditionGroup *database.ConditionGroup) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
databaseID := config.DatabaseInfoID
|
||||
databaseID := ds.databaseInfoID
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["selectParam"] = map[string]any{}
|
||||
|
||||
|
|
@ -116,8 +185,8 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
|||
FieldID string `json:"fieldId"`
|
||||
IsDistinct bool `json:"isDistinct"`
|
||||
}
|
||||
fieldList := make([]Field, 0, len(config.QueryFields))
|
||||
for _, f := range config.QueryFields {
|
||||
fieldList := make([]Field, 0, len(ds.queryFields))
|
||||
for _, f := range ds.queryFields {
|
||||
fieldList = append(fieldList, Field{FieldID: f})
|
||||
}
|
||||
type Order struct {
|
||||
|
|
@ -126,7 +195,7 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
|||
}
|
||||
|
||||
OrderList := make([]Order, 0)
|
||||
for _, c := range config.OrderClauses {
|
||||
for _, c := range ds.orderClauses {
|
||||
OrderList = append(OrderList, Order{
|
||||
FieldID: c.FieldID,
|
||||
IsAsc: c.IsAsc,
|
||||
|
|
@ -135,12 +204,11 @@ func toDatabaseQueryCallbackInput(config *QueryConfig, conditionGroup *database.
|
|||
result["selectParam"] = map[string]any{
|
||||
"condition": condition,
|
||||
"fieldList": fieldList,
|
||||
"limit": config.Limit,
|
||||
"limit": ds.limit,
|
||||
"orderByList": OrderList,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
}
|
||||
|
||||
type ConditionItem struct {
|
||||
|
|
@ -216,6 +284,5 @@ func convertToLogic(rel database.ClauseRelation) (string, error) {
|
|||
return "AND", nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown clause relation %v", rel)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database/databasemock"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type mockDsSelect struct {
|
||||
|
|
@ -82,16 +83,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
mockQuery := &mockDsSelect{objects: objects, t: t, validate: func(request *database.QueryRequest) {
|
||||
|
|
@ -106,17 +98,27 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query())
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]interface{}{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
assert.Equal(t, "2", result["outputList"].([]any)[0].(database.Object)["v2"])
|
||||
|
|
@ -137,17 +139,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
|
|
@ -170,18 +162,28 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeString},
|
||||
"v2": {Type: vo.DataTypeString},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
"__condition_right_1": 2,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
|
|
@ -199,17 +201,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
|
|
@ -230,17 +222,27 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, map[string]any{
|
||||
|
|
@ -261,18 +263,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
"v3": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
objects := make([]database.Object, 0)
|
||||
objects = append(objects, database.Object{
|
||||
|
|
@ -290,15 +281,26 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeInteger},
|
||||
"v3": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{"__condition_right_0": 1}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
fmt.Println(result)
|
||||
assert.Equal(t, int64(1), result["outputList"].([]any)[0].(database.Object)["v1"])
|
||||
|
|
@ -321,22 +323,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeNumber},
|
||||
"v3": {Type: vo.DataTypeBoolean},
|
||||
"v4": {Type: vo.DataTypeBoolean},
|
||||
"v5": {Type: vo.DataTypeTime},
|
||||
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
|
||||
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
|
||||
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
|
|
@ -363,17 +350,32 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{
|
||||
"v1": {Type: vo.DataTypeInteger},
|
||||
"v2": {Type: vo.DataTypeNumber},
|
||||
"v3": {Type: vo.DataTypeBoolean},
|
||||
"v4": {Type: vo.DataTypeBoolean},
|
||||
"v5": {Type: vo.DataTypeTime},
|
||||
"v6": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger}},
|
||||
"v7": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeBoolean}},
|
||||
"v8": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeNumber}},
|
||||
},
|
||||
}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
object := result["outputList"].([]any)[0].(database.Object)
|
||||
|
||||
|
|
@ -400,10 +402,7 @@ func TestDataset_Query(t *testing.T) {
|
|||
},
|
||||
OrderClauses: []*database.OrderClause{{FieldID: "v1", IsAsc: false}},
|
||||
QueryFields: []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"},
|
||||
OutputConfig: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
Limit: 10,
|
||||
}
|
||||
|
||||
objects := make([]database.Object, 0)
|
||||
|
|
@ -429,16 +428,21 @@ func TestDataset_Query(t *testing.T) {
|
|||
mockDatabaseOperator := databasemock.NewMockDatabaseOperator(ctrl)
|
||||
mockDatabaseOperator.EXPECT().Query(gomock.Any(), gomock.Any()).DoAndReturn(mockQuery.Query()).AnyTimes()
|
||||
|
||||
cfg.Op = mockDatabaseOperator
|
||||
ds := Query{
|
||||
config: cfg,
|
||||
}
|
||||
defer mockey.Mock(database.GetDatabaseOperator).Return(mockDatabaseOperator).Build().UnPatch()
|
||||
|
||||
ds, err := cfg.Build(context.Background(), &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"outputList": {Type: vo.DataTypeArray, ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeObject, Properties: map[string]*vo.TypeInfo{}}},
|
||||
"rowNum": {Type: vo.DataTypeInteger},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
in := map[string]any{
|
||||
"__condition_right_0": 1,
|
||||
}
|
||||
|
||||
result, err := ds.Query(t.Context(), in)
|
||||
result, err := ds.(*Query).Invoke(t.Context(), in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["outputList"].([]any)[0].(database.Object), database.Object{
|
||||
"v1": "1",
|
||||
|
|
|
|||
|
|
@ -20,47 +20,93 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/database"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type UpdateConfig struct {
|
||||
DatabaseInfoID int64
|
||||
ClauseGroup *database.ClauseGroup
|
||||
OutputConfig map[string]*vo.TypeInfo
|
||||
Updater database.DatabaseOperator
|
||||
}
|
||||
|
||||
func (u *UpdateConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeDatabaseUpdate,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: u,
|
||||
}
|
||||
|
||||
dsList := n.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return nil, fmt.Errorf("database info is requird")
|
||||
}
|
||||
databaseInfo := dsList[0]
|
||||
|
||||
dsID, err := strconv.ParseInt(databaseInfo.DatabaseInfoID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.DatabaseInfoID = dsID
|
||||
|
||||
updateParam := n.Data.Inputs.UpdateParam
|
||||
if updateParam == nil {
|
||||
return nil, fmt.Errorf("update param is requird")
|
||||
}
|
||||
clauseGroup, err := buildClauseGroupFromCondition(&updateParam.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.ClauseGroup = clauseGroup
|
||||
|
||||
if err = setDatabaseInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (u *UpdateConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if u.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if u.ClauseGroup == nil {
|
||||
return nil, errors.New("clause group is required and greater than 0")
|
||||
}
|
||||
|
||||
return &Update{
|
||||
databaseInfoID: u.DatabaseInfoID,
|
||||
clauseGroup: u.ClauseGroup,
|
||||
outputTypes: ns.OutputTypes,
|
||||
updater: database.GetDatabaseOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Update struct {
|
||||
config *UpdateConfig
|
||||
databaseInfoID int64
|
||||
clauseGroup *database.ClauseGroup
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
updater database.DatabaseOperator
|
||||
}
|
||||
type UpdateInventory struct {
|
||||
|
||||
type updateInventory struct {
|
||||
ConditionGroup *database.ConditionGroup
|
||||
Fields map[string]any
|
||||
}
|
||||
|
||||
func NewUpdate(_ context.Context, cfg *UpdateConfig) (*Update, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
if cfg.DatabaseInfoID == 0 {
|
||||
return nil, errors.New("database info id is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.ClauseGroup == nil {
|
||||
return nil, errors.New("clause group is required and greater than 0")
|
||||
}
|
||||
|
||||
if cfg.Updater == nil {
|
||||
return nil, errors.New("updater is required")
|
||||
}
|
||||
|
||||
return &Update{config: cfg}, nil
|
||||
}
|
||||
|
||||
func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.config.ClauseGroup, in)
|
||||
func (u *Update) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(ctx, u.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -72,20 +118,20 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
|
|||
}
|
||||
|
||||
req := &database.UpdateRequest{
|
||||
DatabaseInfoID: u.config.DatabaseInfoID,
|
||||
DatabaseInfoID: u.databaseInfoID,
|
||||
ConditionGroup: inventory.ConditionGroup,
|
||||
Fields: fields,
|
||||
IsDebugRun: isDebugExecute(ctx),
|
||||
UserID: getExecUserID(ctx),
|
||||
}
|
||||
|
||||
response, err := u.config.Updater.Update(ctx, req)
|
||||
response, err := u.updater.Update(ctx, req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := responseFormatted(u.config.OutputConfig, response)
|
||||
ret, err := responseFormatted(u.outputTypes, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -94,15 +140,15 @@ func (u *Update) Update(ctx context.Context, in map[string]any) (map[string]any,
|
|||
}
|
||||
|
||||
func (u *Update) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.config.ClauseGroup, in)
|
||||
inventory, err := convertClauseGroupToUpdateInventory(context.Background(), u.clauseGroup, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return u.toDatabaseUpdateCallbackInput(inventory)
|
||||
}
|
||||
|
||||
func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[string]any, error) {
|
||||
databaseID := u.config.DatabaseInfoID
|
||||
func (u *Update) toDatabaseUpdateCallbackInput(inventory *updateInventory) (map[string]any, error) {
|
||||
databaseID := u.databaseInfoID
|
||||
result := make(map[string]any)
|
||||
result["databaseInfoList"] = []string{fmt.Sprintf("%d", databaseID)}
|
||||
result["updateParam"] = map[string]any{}
|
||||
|
|
@ -128,6 +174,6 @@ func (u *Update) toDatabaseUpdateCallbackInput(inventory *UpdateInventory) (map[
|
|||
"condition": condition,
|
||||
"fieldInfo": fieldInfo,
|
||||
}
|
||||
return result, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package emitter
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
|
@ -26,28 +25,77 @@ import (
|
|||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type OutputEmitter struct {
|
||||
cfg *Config
|
||||
Template string
|
||||
FullSources map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Template string
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*OutputEmitter, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeOutputEmitter,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
content := n.Data.Inputs.Content
|
||||
streamingOutput := n.Data.Inputs.StreamingOutput
|
||||
|
||||
if streamingOutput {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: true,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
RequireStreamingInput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if content != nil {
|
||||
if content.Type != vo.VariableTypeString {
|
||||
return nil, fmt.Errorf("output emitter node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
|
||||
}
|
||||
|
||||
if content.Value.Type != vo.BlockInputValueTypeLiteral {
|
||||
return nil, fmt.Errorf("output emitter node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
if content.Value.Content == nil {
|
||||
c.Template = ""
|
||||
} else {
|
||||
template, ok := content.Value.Content.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("output emitter node's content value must be string, got %v", content.Value.Content)
|
||||
}
|
||||
|
||||
c.Template = template
|
||||
}
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
return &OutputEmitter{
|
||||
cfg: cfg,
|
||||
Template: c.Template,
|
||||
FullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -59,10 +107,10 @@ type cachedVal struct {
|
|||
|
||||
type cacheStore struct {
|
||||
store map[string]*cachedVal
|
||||
infos map[string]*nodes.SourceInfo
|
||||
infos map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
func newCacheStore(infos map[string]*nodes.SourceInfo) *cacheStore {
|
||||
func newCacheStore(infos map[string]*schema2.SourceInfo) *cacheStore {
|
||||
return &cacheStore{
|
||||
store: make(map[string]*cachedVal),
|
||||
infos: infos,
|
||||
|
|
@ -76,7 +124,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
|||
}
|
||||
|
||||
if !sInfo.IsIntermediate { // this is not an intermediate object container
|
||||
isStream := sInfo.FieldType == nodes.FieldIsStream
|
||||
isStream := sInfo.FieldType == schema2.FieldIsStream
|
||||
if !isStream {
|
||||
_, ok := c.store[k]
|
||||
if !ok {
|
||||
|
|
@ -159,7 +207,7 @@ func (c *cacheStore) put(k string, v any) (any, error) {
|
|||
func (c *cacheStore) finished(k string) bool {
|
||||
cached, ok := c.store[k]
|
||||
if !ok {
|
||||
return c.infos[k].FieldType == nodes.FieldSkipped
|
||||
return c.infos[k].FieldType == schema2.FieldSkipped
|
||||
}
|
||||
|
||||
if cached.finished {
|
||||
|
|
@ -182,7 +230,7 @@ func (c *cacheStore) finished(k string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *nodes.SourceInfo,
|
||||
func (c *cacheStore) find(part nodes.TemplatePart) (root any, subCache *cachedVal, sourceInfo *schema2.SourceInfo,
|
||||
actualPath []string,
|
||||
) {
|
||||
rootCached, ok := c.store[part.Root]
|
||||
|
|
@ -230,7 +278,7 @@ func (c *cacheStore) readyForPart(part nodes.TemplatePart, sw *schema.StreamWrit
|
|||
hasErr bool, partFinished bool) {
|
||||
cachedRoot, subCache, sourceInfo, _ := c.find(part)
|
||||
if cachedRoot != nil && subCache != nil {
|
||||
if subCache.finished || sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if subCache.finished || sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
hasErr = renderAndSend(part, part.Root, cachedRoot, sw)
|
||||
if hasErr {
|
||||
return true, false
|
||||
|
|
@ -315,14 +363,14 @@ func merge(a, b any) any {
|
|||
|
||||
const outputKey = "output"
|
||||
|
||||
func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Transform(ctx context.Context, in *schema.StreamReader[map[string]any]) (out *schema.StreamReader[map[string]any], err error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, sw := schema.Pipe[map[string]any](0)
|
||||
parts := nodes.ParseTemplate(e.cfg.Template)
|
||||
parts := nodes.ParseTemplate(e.Template)
|
||||
safego.Go(ctx, func() {
|
||||
hasErr := false
|
||||
defer func() {
|
||||
|
|
@ -454,7 +502,7 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
|||
shouldChangePart = true
|
||||
}
|
||||
} else {
|
||||
if sourceInfo.FieldType == nodes.FieldIsStream {
|
||||
if sourceInfo.FieldType == schema2.FieldIsStream {
|
||||
currentV := v
|
||||
for i := 0; i < len(actualPath)-1; i++ {
|
||||
currentM, ok := currentV.(map[string]any)
|
||||
|
|
@ -518,8 +566,8 @@ func (e *OutputEmitter) EmitStream(ctx context.Context, in *schema.StreamReader[
|
|||
return sr, nil
|
||||
}
|
||||
|
||||
func (e *OutputEmitter) Emit(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.cfg.Template, in, e.cfg.FullSources)
|
||||
func (e *OutputEmitter) Invoke(ctx context.Context, in map[string]any) (output map[string]any, err error) {
|
||||
s, err := nodes.Render(ctx, e.Template, in, e.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,41 +20,74 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DefaultValues map[string]any
|
||||
OutputTypes map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
cfg *Config
|
||||
defaultValues map[string]any
|
||||
}
|
||||
|
||||
func NewEntry(ctx context.Context, cfg *Config) (*Entry, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("entry node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
defaultValues, _, err := nodes.ConvertInputs(ctx, cfg.DefaultValues, cfg.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
|
||||
|
||||
if n.ID != entity.EntryNodeKey {
|
||||
return nil, fmt.Errorf("entry node id must be %s, got %s", entity.EntryNodeKey, n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Name: n.Data.Meta.Title,
|
||||
Type: entity.NodeTypeEntry,
|
||||
}
|
||||
|
||||
defaultValues := make(map[string]any, len(n.Data.Outputs))
|
||||
for _, v := range n.Data.Outputs {
|
||||
variable, err := vo.ParseVariable(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if variable.DefaultValue != nil {
|
||||
defaultValues[variable.Name] = variable.DefaultValue
|
||||
}
|
||||
}
|
||||
|
||||
c.DefaultValues = defaultValues
|
||||
ns.Configs = c
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
defaultValues, _, err := nodes.ConvertInputs(ctx, c.DefaultValues, ns.OutputTypes, nodes.FailFast(), nodes.SkipRequireCheck())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
cfg: cfg,
|
||||
defaultValues: defaultValues,
|
||||
outputTypes: ns.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
defaultValues map[string]any
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
func (e *Entry) Invoke(_ context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
|
||||
for k, v := range e.defaultValues {
|
||||
if val, ok := in[k]; ok {
|
||||
tInfo := e.cfg.OutputTypes[k]
|
||||
tInfo := e.outputTypes[k]
|
||||
switch tInfo.Type {
|
||||
case vo.DataTypeString:
|
||||
if len(val.(string)) == 0 {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package exit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/emitter"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Template string
|
||||
TerminatePlan vo.TerminatePlan
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("exit node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if n.ID != entity.ExitNodeKey {
|
||||
return nil, fmt.Errorf("exit node id must be %s, got %s", entity.ExitNodeKey, n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: entity.ExitNodeKey,
|
||||
Type: entity.NodeTypeExit,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
var (
|
||||
content *vo.BlockInput
|
||||
streamingOutput bool
|
||||
)
|
||||
if n.Data.Inputs.OutputEmitter != nil {
|
||||
content = n.Data.Inputs.Content
|
||||
streamingOutput = n.Data.Inputs.StreamingOutput
|
||||
}
|
||||
|
||||
if streamingOutput {
|
||||
ns.StreamConfigs = &schema.StreamConfig{
|
||||
RequireStreamingInput: true,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema.StreamConfig{
|
||||
RequireStreamingInput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if content != nil {
|
||||
if content.Type != vo.VariableTypeString {
|
||||
return nil, fmt.Errorf("exit node's content type must be %s, got %s", vo.VariableTypeString, content.Type)
|
||||
}
|
||||
|
||||
if content.Value.Type != vo.BlockInputValueTypeLiteral {
|
||||
return nil, fmt.Errorf("exit node's content value type must be %s, got %s", vo.BlockInputValueTypeLiteral, content.Value.Type)
|
||||
}
|
||||
|
||||
c.Template = content.Value.Content.(string)
|
||||
}
|
||||
|
||||
if n.Data.Inputs.TerminatePlan == nil {
|
||||
return nil, fmt.Errorf("exit node requires a TerminatePlan")
|
||||
}
|
||||
c.TerminatePlan = *n.Data.Inputs.TerminatePlan
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.TerminatePlan == vo.ReturnVariables {
|
||||
return &Exit{}, nil
|
||||
}
|
||||
|
||||
return &emitter.OutputEmitter{
|
||||
Template: c.Template,
|
||||
FullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Exit struct{}
|
||||
|
||||
func (e *Exit) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
if in == nil {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package httprequester
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
)
|
||||
|
||||
var extractBracesRegexp = regexp.MustCompile(`\{\{(.*?)}}`)
|
||||
|
||||
func extractBracesContent(s string) []string {
|
||||
matches := extractBracesRegexp.FindAllStringSubmatch(s, -1)
|
||||
var result []string
|
||||
for _, match := range matches {
|
||||
if len(match) >= 2 {
|
||||
result = append(result, match[1])
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type ImplicitNodeDependency struct {
|
||||
NodeID string
|
||||
FieldPath compose.FieldPath
|
||||
TypeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func extractImplicitDependency(node *vo.Node, canvas *vo.Canvas) ([]*ImplicitNodeDependency, error) {
|
||||
dependencies := make([]*ImplicitNodeDependency, 0, len(canvas.Nodes))
|
||||
url := node.Data.Inputs.APIInfo.URL
|
||||
urlVars := extractBracesContent(url)
|
||||
hasReferred := make(map[string]bool)
|
||||
extractDependenciesFromVars := func(vars []string) error {
|
||||
for _, v := range vars {
|
||||
if strings.HasPrefix(v, "block_output_") {
|
||||
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
|
||||
if len(paths) < 2 {
|
||||
return fmt.Errorf("invalid block_output_ variable: %s", v)
|
||||
}
|
||||
if hasReferred[v] {
|
||||
continue
|
||||
}
|
||||
hasReferred[v] = true
|
||||
dependencies = append(dependencies, &ImplicitNodeDependency{
|
||||
NodeID: paths[0],
|
||||
FieldPath: paths[1:],
|
||||
})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := extractDependenciesFromVars(urlVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if node.Data.Inputs.Body.BodyType == string(BodyTypeJSON) {
|
||||
jsonVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
|
||||
err = extractDependenciesFromVars(jsonVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if node.Data.Inputs.Body.BodyType == string(BodyTypeRawText) {
|
||||
rawTextVars := extractBracesContent(node.Data.Inputs.Body.BodyData.Json)
|
||||
err = extractDependenciesFromVars(rawTextVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var nodeFinder func(nodes []*vo.Node, nodeID string) *vo.Node
|
||||
nodeFinder = func(nodes []*vo.Node, nodeID string) *vo.Node {
|
||||
for i := range nodes {
|
||||
if nodes[i].ID == nodeID {
|
||||
return nodes[i]
|
||||
}
|
||||
if len(nodes[i].Blocks) > 0 {
|
||||
if n := nodeFinder(nodes[i].Blocks, nodeID); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
for _, ds := range dependencies {
|
||||
fNode := nodeFinder(canvas.Nodes, ds.NodeID)
|
||||
if fNode == nil {
|
||||
continue
|
||||
}
|
||||
tInfoMap := make(map[string]*vo.TypeInfo, len(node.Data.Outputs))
|
||||
for _, vAny := range fNode.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo, err := convert.CanvasVariableToTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfoMap[v.Name] = tInfo
|
||||
}
|
||||
tInfo, ok := getTypeInfoByPath(ds.FieldPath[0], ds.FieldPath[1:], tInfoMap)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find type info for dependency: %s", ds.FieldPath)
|
||||
}
|
||||
ds.TypeInfo = tInfo
|
||||
}
|
||||
|
||||
return dependencies, nil
|
||||
}
|
||||
|
||||
func getTypeInfoByPath(root string, properties []string, tInfoMap map[string]*vo.TypeInfo) (*vo.TypeInfo, bool) {
|
||||
if len(properties) == 0 {
|
||||
if tInfo, ok := tInfoMap[root]; ok {
|
||||
return tInfo, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
tInfo, ok := tInfoMap[root]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return getTypeInfoByPath(properties[0], properties[1:], tInfo.Properties)
|
||||
}
|
||||
|
||||
var globalVariableRegex = regexp.MustCompile(`global_variable_\w+\s*\["(.*?)"]`)
|
||||
|
||||
func setHttpRequesterInputsForNodeSchema(n *vo.Node, ns *schema.NodeSchema, implicitNodeDependencies []*ImplicitNodeDependency) (err error) {
|
||||
inputs := n.Data.Inputs
|
||||
implicitPathVars := make(map[string]bool)
|
||||
addImplicitVarsSources := func(prefix string, vars []string) error {
|
||||
for _, v := range vars {
|
||||
if strings.HasPrefix(v, "block_output_") {
|
||||
paths := strings.Split(strings.TrimPrefix(v, "block_output_"), ".")
|
||||
if len(paths) < 2 {
|
||||
return fmt.Errorf("invalid implicit var : %s", v)
|
||||
}
|
||||
for _, dep := range implicitNodeDependencies {
|
||||
if dep.NodeID == paths[0] && strings.Join(dep.FieldPath, ".") == strings.Join(paths[1:], ".") {
|
||||
pathValue := prefix + crypto.MD5HexValue(v)
|
||||
if _, visited := implicitPathVars[pathValue]; visited {
|
||||
continue
|
||||
}
|
||||
implicitPathVars[pathValue] = true
|
||||
ns.SetInputType(pathValue, dep.TypeInfo)
|
||||
ns.AddInputSource(&vo.FieldInfo{
|
||||
Path: []string{pathValue},
|
||||
Source: vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
FromNodeKey: vo.NodeKey(dep.NodeID),
|
||||
FromPath: dep.FieldPath,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(v, "global_variable_") {
|
||||
matches := globalVariableRegex.FindStringSubmatch(v)
|
||||
if len(matches) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var varType vo.GlobalVarType
|
||||
if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalApp)) {
|
||||
varType = vo.GlobalAPP
|
||||
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalUser)) {
|
||||
varType = vo.GlobalUser
|
||||
} else if strings.HasPrefix(v, string(vo.RefSourceTypeGlobalSystem)) {
|
||||
varType = vo.GlobalSystem
|
||||
} else {
|
||||
return fmt.Errorf("invalid global variable type: %s", v)
|
||||
}
|
||||
|
||||
source := vo.FieldSource{
|
||||
Ref: &vo.Reference{
|
||||
VariableType: &varType,
|
||||
FromPath: []string{matches[1]},
|
||||
},
|
||||
}
|
||||
|
||||
ns.AddInputSource(&vo.FieldInfo{
|
||||
Path: []string{prefix + crypto.MD5HexValue(v)},
|
||||
Source: source,
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
urlVars := extractBracesContent(inputs.APIInfo.URL)
|
||||
err = addImplicitVarsSources("__apiInfo_url_", urlVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = applyParamsToSchema(ns, "__headers_", inputs.Headers, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = applyParamsToSchema(ns, "__params_", inputs.Params, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if inputs.Auth != nil && inputs.Auth.AuthOpen {
|
||||
authData := inputs.Auth.AuthData
|
||||
const bearerTokenKey = "__auth_authData_bearerTokenData_token"
|
||||
if inputs.Auth.AuthType == "BEARER_AUTH" {
|
||||
bearTokenParam := authData.BearerTokenData[0]
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(bearTokenParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(bearerTokenKey, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(bearTokenParam.Input, compose.FieldPath{bearerTokenKey}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
|
||||
}
|
||||
|
||||
if inputs.Auth.AuthType == "CUSTOM_AUTH" {
|
||||
const (
|
||||
customDataDataKey = "__auth_authData_customData_data_Key"
|
||||
customDataDataValue = "__auth_authData_customData_data_Value"
|
||||
)
|
||||
dataParams := authData.CustomData.Data
|
||||
keyParam := dataParams[0]
|
||||
keyTypeInfo, err := convert.CanvasBlockInputToTypeInfo(keyParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(customDataDataKey, keyTypeInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(keyParam.Input, compose.FieldPath{customDataDataKey}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
|
||||
valueParam := dataParams[1]
|
||||
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(valueParam.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.SetInputType(customDataDataValue, valueTypeInfo)
|
||||
sources, err = convert.CanvasBlockInputToFieldInfo(valueParam.Input, compose.FieldPath{customDataDataValue}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
}
|
||||
|
||||
switch BodyType(inputs.Body.BodyType) {
|
||||
case BodyTypeFormData:
|
||||
err = applyParamsToSchema(ns, "__body_bodyData_formData_", inputs.Body.BodyData.FormData.Data, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeFormURLEncoded:
|
||||
err = applyParamsToSchema(ns, "__body_bodyData_formURLEncoded_", inputs.Body.BodyData.FormURLEncoded, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeBinary:
|
||||
const fileURLName = "__body_bodyData_binary_fileURL"
|
||||
fileURLInput := inputs.Body.BodyData.Binary.FileURL
|
||||
ns.SetInputType(fileURLName, &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
})
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(fileURLInput, compose.FieldPath{fileURLName}, n.Parent())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
case BodyTypeJSON:
|
||||
jsonVars := extractBracesContent(inputs.Body.BodyData.Json)
|
||||
err = addImplicitVarsSources("__body_bodyData_json_", jsonVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case BodyTypeRawText:
|
||||
rawTextVars := extractBracesContent(inputs.Body.BodyData.RawText)
|
||||
err = addImplicitVarsSources("__body_bodyData_rawText_", rawTextVars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyParamsToSchema(ns *schema.NodeSchema, prefix string, params []*vo.Param, parentNode *vo.Node) error {
|
||||
for i := range params {
|
||||
param := params[i]
|
||||
name := param.Name
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldName := prefix + crypto.MD5HexValue(name)
|
||||
ns.SetInputType(fieldName, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{fieldName}, parentNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -31,9 +31,14 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
|
|
@ -129,7 +134,7 @@ type Request struct {
|
|||
FileURL *string
|
||||
}
|
||||
|
||||
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"\]`)
|
||||
var globalVariableReplaceRegexp = regexp.MustCompile(`global_variable_(\w+)\["(\w+)"]`)
|
||||
|
||||
type MD5FieldMapping struct {
|
||||
HeaderMD5Mapping map[string]string `json:"header_md_5_mapping,omitempty"` // md5 vs key
|
||||
|
|
@ -184,49 +189,188 @@ type Config struct {
|
|||
Timeout time.Duration
|
||||
RetryTimes uint64
|
||||
|
||||
IgnoreException bool
|
||||
DefaultOutput map[string]any
|
||||
|
||||
MD5FieldMapping
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
options := nodes.GetAdaptOptions(opts...)
|
||||
if options.Canvas == nil {
|
||||
return nil, fmt.Errorf("canvas is requried when adapting HTTPRequester node")
|
||||
}
|
||||
|
||||
implicitDeps, err := extractImplicitDependency(n, options.Canvas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeHTTPRequester,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
md5FieldMapping := &MD5FieldMapping{}
|
||||
|
||||
method := inputs.APIInfo.Method
|
||||
c.Method = method
|
||||
reqURL := inputs.APIInfo.URL
|
||||
c.URLConfig = URLConfig{
|
||||
Tpl: strings.TrimSpace(reqURL),
|
||||
}
|
||||
|
||||
urlVars := extractBracesContent(reqURL)
|
||||
md5FieldMapping.SetURLFields(urlVars...)
|
||||
|
||||
md5FieldMapping.SetHeaderFields(slices.Transform(inputs.Headers, func(a *vo.Param) string {
|
||||
return a.Name
|
||||
})...)
|
||||
|
||||
md5FieldMapping.SetParamFields(slices.Transform(inputs.Params, func(a *vo.Param) string {
|
||||
return a.Name
|
||||
})...)
|
||||
|
||||
if inputs.Auth != nil && inputs.Auth.AuthOpen {
|
||||
auth := &AuthenticationConfig{}
|
||||
ty, err := convertAuthType(inputs.Auth.AuthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth.Type = ty
|
||||
location, err := convertLocation(inputs.Auth.AuthData.CustomData.AddTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth.Location = location
|
||||
|
||||
c.AuthConfig = auth
|
||||
}
|
||||
|
||||
bodyConfig := BodyConfig{}
|
||||
|
||||
bodyConfig.BodyType = BodyType(inputs.Body.BodyType)
|
||||
switch BodyType(inputs.Body.BodyType) {
|
||||
case BodyTypeJSON:
|
||||
jsonTpl := inputs.Body.BodyData.Json
|
||||
bodyConfig.TextJsonConfig = &TextJsonConfig{
|
||||
Tpl: jsonTpl,
|
||||
}
|
||||
jsonVars := extractBracesContent(jsonTpl)
|
||||
md5FieldMapping.SetBodyFields(jsonVars...)
|
||||
case BodyTypeFormData:
|
||||
bodyConfig.FormDataConfig = &FormDataConfig{
|
||||
FileTypeMapping: map[string]bool{},
|
||||
}
|
||||
formDataVars := make([]string, 0)
|
||||
for i := range inputs.Body.BodyData.FormData.Data {
|
||||
p := inputs.Body.BodyData.FormData.Data[i]
|
||||
formDataVars = append(formDataVars, p.Name)
|
||||
if p.Input.Type == vo.VariableTypeString && p.Input.AssistType > vo.AssistTypeNotSet && p.Input.AssistType < vo.AssistTypeTime {
|
||||
bodyConfig.FormDataConfig.FileTypeMapping[p.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
md5FieldMapping.SetBodyFields(formDataVars...)
|
||||
case BodyTypeRawText:
|
||||
TextTpl := inputs.Body.BodyData.RawText
|
||||
bodyConfig.TextPlainConfig = &TextPlainConfig{
|
||||
Tpl: TextTpl,
|
||||
}
|
||||
textPlainVars := extractBracesContent(TextTpl)
|
||||
md5FieldMapping.SetBodyFields(textPlainVars...)
|
||||
case BodyTypeFormURLEncoded:
|
||||
formURLEncodedVars := make([]string, 0)
|
||||
for _, p := range inputs.Body.BodyData.FormURLEncoded {
|
||||
formURLEncodedVars = append(formURLEncodedVars, p.Name)
|
||||
}
|
||||
md5FieldMapping.SetBodyFields(formURLEncodedVars...)
|
||||
}
|
||||
c.BodyConfig = bodyConfig
|
||||
c.MD5FieldMapping = *md5FieldMapping
|
||||
|
||||
if inputs.Setting != nil {
|
||||
c.Timeout = time.Duration(inputs.Setting.Timeout) * time.Second
|
||||
c.RetryTimes = uint64(inputs.Setting.RetryTimes)
|
||||
}
|
||||
|
||||
if err := setHttpRequesterInputsForNodeSchema(n, ns, implicitDeps); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func convertAuthType(auth string) (AuthType, error) {
|
||||
switch auth {
|
||||
case "CUSTOM_AUTH":
|
||||
return Custom, nil
|
||||
case "BEARER_AUTH":
|
||||
return BearToken, nil
|
||||
default:
|
||||
return AuthType(0), fmt.Errorf("invalid auth type")
|
||||
}
|
||||
}
|
||||
|
||||
func convertLocation(l string) (Location, error) {
|
||||
switch l {
|
||||
case "header":
|
||||
return Header, nil
|
||||
case "query":
|
||||
return QueryParam, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid location")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if len(c.Method) == 0 {
|
||||
return nil, fmt.Errorf("method is requried")
|
||||
}
|
||||
|
||||
hg := &HTTPRequester{
|
||||
urlConfig: c.URLConfig,
|
||||
method: c.Method,
|
||||
retryTimes: c.RetryTimes,
|
||||
authConfig: c.AuthConfig,
|
||||
bodyConfig: c.BodyConfig,
|
||||
md5FieldMapping: c.MD5FieldMapping,
|
||||
}
|
||||
client := http.DefaultClient
|
||||
if c.Timeout > 0 {
|
||||
client.Timeout = c.Timeout
|
||||
}
|
||||
|
||||
hg.client = client
|
||||
|
||||
return hg, nil
|
||||
}
|
||||
|
||||
type HTTPRequester struct {
|
||||
client *http.Client
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewHTTPRequester(_ context.Context, cfg *Config) (*HTTPRequester, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is requried")
|
||||
}
|
||||
|
||||
if len(cfg.Method) == 0 {
|
||||
return nil, fmt.Errorf("method is requried")
|
||||
}
|
||||
|
||||
hg := &HTTPRequester{}
|
||||
client := http.DefaultClient
|
||||
if cfg.Timeout > 0 {
|
||||
client.Timeout = cfg.Timeout
|
||||
}
|
||||
|
||||
hg.client = client
|
||||
hg.config = cfg
|
||||
|
||||
return hg, nil
|
||||
urlConfig URLConfig
|
||||
authConfig *AuthenticationConfig
|
||||
bodyConfig BodyConfig
|
||||
method string
|
||||
retryTimes uint64
|
||||
md5FieldMapping MD5FieldMapping
|
||||
}
|
||||
|
||||
func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (output map[string]any, err error) {
|
||||
var (
|
||||
req = &Request{}
|
||||
method = hg.config.Method
|
||||
retryTimes = hg.config.RetryTimes
|
||||
method = hg.method
|
||||
retryTimes = hg.retryTimes
|
||||
body io.ReadCloser
|
||||
contentType string
|
||||
response *http.Response
|
||||
)
|
||||
|
||||
req, err = hg.config.parserToRequest(input)
|
||||
req, err = hg.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -236,7 +380,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
|||
Header: http.Header{},
|
||||
}
|
||||
|
||||
httpURL, err := nodes.TemplateRender(hg.config.URLConfig.Tpl, req.URLVars)
|
||||
httpURL, err := nodes.TemplateRender(hg.urlConfig.Tpl, req.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -255,8 +399,8 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
|||
params.Set(key, value)
|
||||
}
|
||||
|
||||
if hg.config.AuthConfig != nil {
|
||||
httpRequest.Header, params, err = hg.config.AuthConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
|
||||
if hg.authConfig != nil {
|
||||
httpRequest.Header, params, err = hg.authConfig.addAuthentication(ctx, req.Authentication, httpRequest.Header, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -264,7 +408,7 @@ func (hg *HTTPRequester) Invoke(ctx context.Context, input map[string]any) (outp
|
|||
u.RawQuery = params.Encode()
|
||||
httpRequest.URL = u
|
||||
|
||||
body, contentType, err = hg.config.BodyConfig.getBodyAndContentType(ctx, req)
|
||||
body, contentType, err = hg.bodyConfig.getBodyAndContentType(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -479,18 +623,16 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) {
|
|||
}
|
||||
|
||||
func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
var (
|
||||
request = &Request{}
|
||||
config = hg.config
|
||||
)
|
||||
request, err := hg.config.parserToRequest(input)
|
||||
var request = &Request{}
|
||||
|
||||
request, err := hg.parserToRequest(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[string]any)
|
||||
result["method"] = config.Method
|
||||
result["method"] = hg.method
|
||||
|
||||
u, err := nodes.TemplateRender(config.URLConfig.Tpl, request.URLVars)
|
||||
u, err := nodes.TemplateRender(hg.urlConfig.Tpl, request.URLVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -508,13 +650,13 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
|||
}
|
||||
result["header"] = headers
|
||||
result["auth"] = nil
|
||||
if config.AuthConfig != nil {
|
||||
if config.AuthConfig.Type == Custom {
|
||||
if hg.authConfig != nil {
|
||||
if hg.authConfig.Type == Custom {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"Key": request.Authentication.Key,
|
||||
"Value": request.Authentication.Value,
|
||||
}
|
||||
} else if config.AuthConfig.Type == BearToken {
|
||||
} else if hg.authConfig.Type == BearToken {
|
||||
result["auth"] = map[string]interface{}{
|
||||
"token": request.Authentication.Token,
|
||||
}
|
||||
|
|
@ -522,9 +664,9 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
|||
}
|
||||
|
||||
result["body"] = nil
|
||||
switch config.BodyConfig.BodyType {
|
||||
switch hg.bodyConfig.BodyType {
|
||||
case BodyTypeJSON:
|
||||
js, err := nodes.TemplateRender(config.BodyConfig.TextJsonConfig.Tpl, request.JsonVars)
|
||||
js, err := nodes.TemplateRender(hg.bodyConfig.TextJsonConfig.Tpl, request.JsonVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -535,7 +677,7 @@ func (hg *HTTPRequester) ToCallbackInput(_ context.Context, input map[string]any
|
|||
}
|
||||
result["body"] = ret
|
||||
case BodyTypeRawText:
|
||||
tx, err := nodes.TemplateRender(config.BodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
|
||||
tx, err := nodes.TemplateRender(hg.bodyConfig.TextPlainConfig.Tpl, request.TextPlainVars)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
|
|
@ -569,7 +711,7 @@ const (
|
|||
bodyBinaryFileURLPrefix = "binary_fileURL"
|
||||
)
|
||||
|
||||
func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
||||
func (hg *HTTPRequester) parserToRequest(input map[string]any) (*Request, error) {
|
||||
request := &Request{
|
||||
URLVars: make(map[string]any),
|
||||
Headers: make(map[string]string),
|
||||
|
|
@ -583,7 +725,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
|||
for key, value := range input {
|
||||
if strings.HasPrefix(key, apiInfoURLPrefix) {
|
||||
urlMD5 := strings.TrimPrefix(key, apiInfoURLPrefix)
|
||||
if urlKey, ok := cfg.URLMD5Mapping[urlMD5]; ok {
|
||||
if urlKey, ok := hg.md5FieldMapping.URLMD5Mapping[urlMD5]; ok {
|
||||
if strings.HasPrefix(urlKey, "global_variable_") {
|
||||
urlKey = globalVariableReplaceRegexp.ReplaceAllString(urlKey, "global_variable_$1.$2")
|
||||
}
|
||||
|
|
@ -592,13 +734,13 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
|||
}
|
||||
if strings.HasPrefix(key, headersPrefix) {
|
||||
headerKeyMD5 := strings.TrimPrefix(key, headersPrefix)
|
||||
if headerKey, ok := cfg.HeaderMD5Mapping[headerKeyMD5]; ok {
|
||||
if headerKey, ok := hg.md5FieldMapping.HeaderMD5Mapping[headerKeyMD5]; ok {
|
||||
request.Headers[headerKey] = value.(string)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(key, paramsPrefix) {
|
||||
paramKeyMD5 := strings.TrimPrefix(key, paramsPrefix)
|
||||
if paramKey, ok := cfg.ParamMD5Mapping[paramKeyMD5]; ok {
|
||||
if paramKey, ok := hg.md5FieldMapping.ParamMD5Mapping[paramKeyMD5]; ok {
|
||||
request.Params[paramKey] = value.(string)
|
||||
}
|
||||
}
|
||||
|
|
@ -622,7 +764,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
|||
bodyKey := strings.TrimPrefix(key, bodyDataPrefix)
|
||||
if strings.HasPrefix(bodyKey, bodyJsonPrefix) {
|
||||
jsonMd5Key := strings.TrimPrefix(bodyKey, bodyJsonPrefix)
|
||||
if jsonKey, ok := cfg.BodyMD5Mapping[jsonMd5Key]; ok {
|
||||
if jsonKey, ok := hg.md5FieldMapping.BodyMD5Mapping[jsonMd5Key]; ok {
|
||||
if strings.HasPrefix(jsonKey, "global_variable_") {
|
||||
jsonKey = globalVariableReplaceRegexp.ReplaceAllString(jsonKey, "global_variable_$1.$2")
|
||||
}
|
||||
|
|
@ -632,7 +774,7 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
|||
}
|
||||
if strings.HasPrefix(bodyKey, bodyFormDataPrefix) {
|
||||
formDataMd5Key := strings.TrimPrefix(bodyKey, bodyFormDataPrefix)
|
||||
if formDataKey, ok := cfg.BodyMD5Mapping[formDataMd5Key]; ok {
|
||||
if formDataKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formDataMd5Key]; ok {
|
||||
request.FormDataVars[formDataKey] = value.(string)
|
||||
}
|
||||
|
||||
|
|
@ -640,14 +782,14 @@ func (cfg *Config) parserToRequest(input map[string]any) (*Request, error) {
|
|||
|
||||
if strings.HasPrefix(bodyKey, bodyFormURLEncodedPrefix) {
|
||||
formURLEncodeMd5Key := strings.TrimPrefix(bodyKey, bodyFormURLEncodedPrefix)
|
||||
if formURLEncodeKey, ok := cfg.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
|
||||
if formURLEncodeKey, ok := hg.md5FieldMapping.BodyMD5Mapping[formURLEncodeMd5Key]; ok {
|
||||
request.FormURLEncodedVars[formURLEncodeKey] = value.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(bodyKey, bodyRawTextPrefix) {
|
||||
rawTextMd5Key := strings.TrimPrefix(bodyKey, bodyRawTextPrefix)
|
||||
if rawTextKey, ok := cfg.BodyMD5Mapping[rawTextMd5Key]; ok {
|
||||
if rawTextKey, ok := hg.md5FieldMapping.BodyMD5Mapping[rawTextMd5Key]; ok {
|
||||
if strings.HasPrefix(rawTextKey, "global_variable_") {
|
||||
rawTextKey = globalVariableReplaceRegexp.ReplaceAllString(rawTextKey, "global_variable_$1.$2")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/crypto"
|
||||
)
|
||||
|
||||
|
|
@ -68,7 +69,7 @@ func TestInvoke(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
m := map[string]any{
|
||||
"__apiInfo_url_" + crypto.MD5HexValue("url_v1"): "v1",
|
||||
|
|
@ -78,7 +79,7 @@ func TestInvoke(t *testing.T) {
|
|||
"__params_" + crypto.MD5HexValue("p2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
|
@ -157,7 +158,7 @@ func TestInvoke(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
|
|
@ -171,7 +172,7 @@ func TestInvoke(t *testing.T) {
|
|||
"__body_bodyData_formData_" + crypto.MD5HexValue("fileURL"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
|
@ -228,7 +229,7 @@ func TestInvoke(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
|
|
@ -241,7 +242,7 @@ func TestInvoke(t *testing.T) {
|
|||
"__body_bodyData_rawText_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
|
@ -303,7 +304,7 @@ func TestInvoke(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
|
|
@ -316,7 +317,7 @@ func TestInvoke(t *testing.T) {
|
|||
"__body_bodyData_json_" + crypto.MD5HexValue("v2"): "v2",
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
|
@ -376,7 +377,7 @@ func TestInvoke(t *testing.T) {
|
|||
}
|
||||
|
||||
// Create an HTTPRequest instance
|
||||
hg, err := NewHTTPRequester(context.Background(), cfg)
|
||||
hg, err := cfg.Build(context.Background(), &schema.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := map[string]any{
|
||||
|
|
@ -388,7 +389,7 @@ func TestInvoke(t *testing.T) {
|
|||
"__body_bodyData_binary_fileURL" + crypto.MD5HexValue("v1"): fileServer.URL,
|
||||
}
|
||||
|
||||
result, err := hg.Invoke(context.Background(), m)
|
||||
result, err := hg.(*HTTPRequester).Invoke(context.Background(), m)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"message":"success"}`, result["body"])
|
||||
assert.Equal(t, int64(200), result["statusCode"])
|
||||
|
|
|
|||
|
|
@ -18,26 +18,167 @@ package intentdetector
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Intents []string
|
||||
SystemPrompt string
|
||||
IsFastMode bool
|
||||
ChatModel model.BaseChatModel
|
||||
LLMParams *model.LLMParams
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeIntentDetector,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
param := n.Data.Inputs.LLMParam
|
||||
if param == nil {
|
||||
return nil, fmt.Errorf("intent detector node's llmParam is nil")
|
||||
}
|
||||
|
||||
llmParam, ok := param.(vo.IntentDetectorLLMParam)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("llm node's llmParam must be LLMParam, got %v", llmParam)
|
||||
}
|
||||
|
||||
paramBytes, err := sonic.Marshal(param)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var intentDetectorConfig = &vo.IntentDetectorLLMConfig{}
|
||||
|
||||
err = sonic.Unmarshal(paramBytes, &intentDetectorConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelLLMParams := &model.LLMParams{}
|
||||
modelLLMParams.ModelType = int64(intentDetectorConfig.ModelType)
|
||||
modelLLMParams.ModelName = intentDetectorConfig.ModelName
|
||||
modelLLMParams.TopP = intentDetectorConfig.TopP
|
||||
modelLLMParams.Temperature = intentDetectorConfig.Temperature
|
||||
modelLLMParams.MaxTokens = intentDetectorConfig.MaxTokens
|
||||
modelLLMParams.ResponseFormat = model.ResponseFormat(intentDetectorConfig.ResponseFormat)
|
||||
modelLLMParams.SystemPrompt = intentDetectorConfig.SystemPrompt.Value.Content.(string)
|
||||
|
||||
c.LLMParams = modelLLMParams
|
||||
c.SystemPrompt = modelLLMParams.SystemPrompt
|
||||
|
||||
var intents = make([]string, 0, len(n.Data.Inputs.Intents))
|
||||
for _, it := range n.Data.Inputs.Intents {
|
||||
intents = append(intents, it.Name)
|
||||
}
|
||||
c.Intents = intents
|
||||
|
||||
if n.Data.Inputs.Mode == "top_speed" {
|
||||
c.IsFastMode = true
|
||||
}
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, _ *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if !c.IsFastMode && c.LLMParams == nil {
|
||||
return nil, errors.New("config chat model is required")
|
||||
}
|
||||
|
||||
if len(c.Intents) == 0 {
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
|
||||
m, _, err := model.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chain := compose.NewChain[map[string]any, *schema.Message]()
|
||||
|
||||
spt := ternary.IFElse[string](c.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
|
||||
|
||||
intents, err := toIntentString(c.Intents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
|
||||
"intents": intents,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts := prompt.FromMessages(schema.Jinja2,
|
||||
&schema.Message{Content: sptTemplate, Role: schema.System},
|
||||
&schema.Message{Content: "{{query}}", Role: schema.User})
|
||||
|
||||
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(m).Compile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &IntentDetector{
|
||||
isFastMode: c.IsFastMode,
|
||||
systemPrompt: c.SystemPrompt,
|
||||
runner: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
classificationId, ok := nodeOutput[classificationID]
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to take classification id from input map: %v", nodeOutput)
|
||||
}
|
||||
|
||||
cID64, ok := classificationId.(int64)
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("classificationID not of type int64, actual type: %T", classificationId)
|
||||
}
|
||||
|
||||
if cID64 == 0 {
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
return cID64 - 1, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) []string {
|
||||
expects := make([]string, len(n.Data.Inputs.Intents)+1)
|
||||
expects[0] = schema2.PortDefault
|
||||
for i := 0; i < len(n.Data.Inputs.Intents); i++ {
|
||||
expects[i+1] = fmt.Sprintf(schema2.PortBranchFormat, i)
|
||||
}
|
||||
return expects
|
||||
}
|
||||
|
||||
const SystemIntentPrompt = `
|
||||
|
|
@ -95,71 +236,39 @@ Note:
|
|||
##Limit
|
||||
- Please do not reply in text.`
|
||||
|
||||
const classificationID = "classificationId"
|
||||
|
||||
type IntentDetector struct {
|
||||
config *Config
|
||||
isFastMode bool
|
||||
systemPrompt string
|
||||
runner compose.Runnable[map[string]any, *schema.Message]
|
||||
}
|
||||
|
||||
func NewIntentDetector(ctx context.Context, cfg *Config) (*IntentDetector, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
}
|
||||
if !cfg.IsFastMode && cfg.ChatModel == nil {
|
||||
return nil, errors.New("config chat model is required")
|
||||
}
|
||||
|
||||
if len(cfg.Intents) == 0 {
|
||||
return nil, errors.New("config intents is required")
|
||||
}
|
||||
chain := compose.NewChain[map[string]any, *schema.Message]()
|
||||
|
||||
spt := ternary.IFElse[string](cfg.IsFastMode, FastModeSystemIntentPrompt, SystemIntentPrompt)
|
||||
|
||||
sptTemplate, err := nodes.TemplateRender(spt, map[string]interface{}{
|
||||
"intents": toIntentString(cfg.Intents),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prompts := prompt.FromMessages(schema.Jinja2,
|
||||
&schema.Message{Content: sptTemplate, Role: schema.System},
|
||||
&schema.Message{Content: "{{query}}", Role: schema.User})
|
||||
|
||||
r, err := chain.AppendChatTemplate(prompts).AppendChatModel(cfg.ChatModel).Compile(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &IntentDetector{
|
||||
config: cfg,
|
||||
runner: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (id *IntentDetector) parseToNodeOut(content string) (map[string]any, error) {
|
||||
nodeOutput := make(map[string]any)
|
||||
nodeOutput["classificationId"] = 0
|
||||
if content == "" {
|
||||
return nodeOutput, errors.New("content is empty")
|
||||
return nil, errors.New("intent detector's LLM output content is empty")
|
||||
}
|
||||
|
||||
if id.config.IsFastMode {
|
||||
if id.isFastMode {
|
||||
cid, err := strconv.ParseInt(content, 10, 64)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
return nil, err
|
||||
}
|
||||
nodeOutput["classificationId"] = cid
|
||||
return nodeOutput, nil
|
||||
return map[string]any{
|
||||
classificationID: cid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
leftIndex := strings.Index(content, "{")
|
||||
rightIndex := strings.Index(content, "}")
|
||||
if leftIndex == -1 || rightIndex == -1 {
|
||||
return nodeOutput, errors.New("content is invalid")
|
||||
return nil, fmt.Errorf("intent detector's LLM output content is invalid: %s", content)
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(content[leftIndex:rightIndex+1]), &nodeOutput)
|
||||
var nodeOutput map[string]any
|
||||
err := sonic.UnmarshalString(content[leftIndex:rightIndex+1], &nodeOutput)
|
||||
if err != nil {
|
||||
return nodeOutput, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodeOutput, nil
|
||||
|
|
@ -178,8 +287,8 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
|
|||
|
||||
vars := make(map[string]any)
|
||||
vars["query"] = queryStr
|
||||
if !id.config.IsFastMode {
|
||||
ad, err := nodes.TemplateRender(id.config.SystemPrompt, map[string]any{"query": query})
|
||||
if !id.isFastMode {
|
||||
ad, err := nodes.TemplateRender(id.systemPrompt, map[string]any{"query": query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -193,7 +302,7 @@ func (id *IntentDetector) Invoke(ctx context.Context, input map[string]any) (map
|
|||
return id.parseToNodeOut(o.Content)
|
||||
}
|
||||
|
||||
func toIntentString(its []string) string {
|
||||
func toIntentString(its []string) (string, error) {
|
||||
type IntentVariableItem struct {
|
||||
ClassificationID int64 `json:"classificationId"`
|
||||
Content string `json:"content"`
|
||||
|
|
@ -207,6 +316,6 @@ func toIntentString(its []string) string {
|
|||
Content: it,
|
||||
})
|
||||
}
|
||||
itsBytes, _ := json.Marshal(vs)
|
||||
return string(itsBytes)
|
||||
|
||||
return sonic.MarshalString(vs)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,88 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package intentdetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockChatModel struct {
|
||||
topSeed bool
|
||||
}
|
||||
|
||||
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
|
||||
if m.topSeed {
|
||||
return &schema.Message{
|
||||
Content: "1",
|
||||
}, nil
|
||||
}
|
||||
return &schema.Message{
|
||||
Content: `{"classificationId":1,"reason":"高兴"}`,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewIntentDetector(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("fast mode", func(t *testing.T) {
|
||||
dt, err := NewIntentDetector(ctx, &Config{
|
||||
Intents: []string{"高兴", "悲伤"},
|
||||
IsFastMode: true,
|
||||
ChatModel: &mockChatModel{topSeed: true},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ret, err := dt.Invoke(ctx, map[string]any{
|
||||
"query": "我考了100分",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, ret["classificationId"], int64(1))
|
||||
})
|
||||
|
||||
t.Run("full mode", func(t *testing.T) {
|
||||
|
||||
dt, err := NewIntentDetector(ctx, &Config{
|
||||
Intents: []string{"高兴", "悲伤"},
|
||||
IsFastMode: false,
|
||||
ChatModel: &mockChatModel{},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
ret, err := dt.Invoke(ctx, map[string]any{
|
||||
"query": "我考了100分",
|
||||
})
|
||||
fmt.Println(err)
|
||||
assert.Nil(t, err)
|
||||
fmt.Println(ret)
|
||||
assert.Equal(t, ret["classificationId"], float64(1))
|
||||
assert.Equal(t, ret["reason"], "高兴")
|
||||
})
|
||||
|
||||
}
|
||||
|
|
@ -20,8 +20,11 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
|
|
@ -34,32 +37,42 @@ const (
|
|||
warningsKey = "deserialization_warnings"
|
||||
)
|
||||
|
||||
type DeserializationConfig struct {
|
||||
OutputFields map[string]*vo.TypeInfo `json:"outputFields,omitempty"`
|
||||
type DeserializationConfig struct{}
|
||||
|
||||
func (d *DeserializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (
|
||||
*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeJsonDeserialization,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
type Deserializer struct {
|
||||
config *DeserializationConfig
|
||||
typeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func NewJsonDeserializer(_ context.Context, cfg *DeserializationConfig) (*Deserializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.OutputFields == nil {
|
||||
return nil, fmt.Errorf("OutputFields is required for deserialization")
|
||||
}
|
||||
typeInfo := cfg.OutputFields[OutputKeyDeserialization]
|
||||
if typeInfo == nil {
|
||||
func (d *DeserializationConfig) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
typeInfo, ok := ns.OutputTypes[OutputKeyDeserialization]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no output field specified in deserialization config")
|
||||
}
|
||||
return &Deserializer{
|
||||
config: cfg,
|
||||
typeInfo: typeInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Deserializer struct {
|
||||
typeInfo *vo.TypeInfo
|
||||
}
|
||||
|
||||
func (jd *Deserializer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
jsonStrValue := input[InputKeyDeserialization]
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
|
@ -31,19 +32,9 @@ import (
|
|||
func TestNewJsonDeserializer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil config
|
||||
_, err := NewJsonDeserializer(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "config required")
|
||||
|
||||
// Test with missing OutputFields config
|
||||
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "OutputFields is required")
|
||||
|
||||
// Test with missing output key in OutputFields
|
||||
_, err = NewJsonDeserializer(ctx, &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
_, err := (&DeserializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
"testKey": {Type: vo.DataTypeString},
|
||||
},
|
||||
})
|
||||
|
|
@ -51,12 +42,12 @@ func TestNewJsonDeserializer(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "no output field specified in deserialization config")
|
||||
|
||||
// Test with valid config
|
||||
validConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
validConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeString},
|
||||
},
|
||||
}
|
||||
processor, err := NewJsonDeserializer(ctx, validConfig)
|
||||
processor, err := (&DeserializationConfig{}).Build(ctx, validConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, processor)
|
||||
}
|
||||
|
|
@ -65,16 +56,16 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
// Base type test config
|
||||
baseTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeString},
|
||||
baseTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeString},
|
||||
},
|
||||
}
|
||||
|
||||
// Object type test config
|
||||
objectTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
objectTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"name": {Type: vo.DataTypeString, Required: true},
|
||||
|
|
@ -85,9 +76,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
}
|
||||
|
||||
// Array type test config
|
||||
arrayTypeConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
arrayTypeConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
|
|
@ -95,9 +86,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
}
|
||||
|
||||
// Nested array object test config
|
||||
nestedArrayConfig := &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
nestedArrayConfig := &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
|
|
@ -113,7 +104,7 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
config *DeserializationConfig
|
||||
config *schema.NodeSchema
|
||||
inputJSON string
|
||||
expectedOutput any
|
||||
expectErr bool
|
||||
|
|
@ -127,9 +118,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test integer deserialization",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123`,
|
||||
|
|
@ -138,9 +129,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test boolean deserialization",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeBoolean},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
|
||||
},
|
||||
},
|
||||
inputJSON: `true`,
|
||||
|
|
@ -180,9 +171,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test type mismatch warning",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"not a number"`,
|
||||
|
|
@ -198,9 +189,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"123"`,
|
||||
|
|
@ -209,9 +200,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test float to integer conversion (integer part)",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123.0`,
|
||||
|
|
@ -220,9 +211,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test float to integer conversion (non-integer part)",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `123.5`,
|
||||
|
|
@ -231,9 +222,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test boolean to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `true`,
|
||||
|
|
@ -242,9 +233,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test string to boolean conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeBoolean},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeBoolean},
|
||||
},
|
||||
},
|
||||
inputJSON: `"true"`,
|
||||
|
|
@ -253,9 +244,10 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion in nested object",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
inputJSON: `{"age":"456"}`,
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"age": {Type: vo.DataTypeInteger},
|
||||
|
|
@ -263,15 +255,14 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
inputJSON: `{"age":"456"}`,
|
||||
expectedOutput: map[string]any{"age": 456},
|
||||
expectErr: false,
|
||||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string to integer conversion for array elements",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
|
|
@ -283,9 +274,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 0,
|
||||
}, {
|
||||
name: "Test string with non-numeric characters to integer conversion",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {Type: vo.DataTypeInteger},
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {Type: vo.DataTypeInteger},
|
||||
},
|
||||
},
|
||||
inputJSON: `"123abc"`,
|
||||
|
|
@ -294,9 +285,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test type mismatch in nested object field",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
"score": {Type: vo.DataTypeInteger},
|
||||
|
|
@ -310,9 +301,9 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
expectWarnings: 1,
|
||||
}, {
|
||||
name: "Test partial conversion failure in array elements",
|
||||
config: &DeserializationConfig{
|
||||
OutputFields: map[string]*vo.TypeInfo{
|
||||
"output": {
|
||||
config: &schema.NodeSchema{
|
||||
OutputTypes: map[string]*vo.TypeInfo{
|
||||
OutputKeyDeserialization: {
|
||||
Type: vo.DataTypeArray,
|
||||
ElemTypeInfo: &vo.TypeInfo{Type: vo.DataTypeInteger},
|
||||
},
|
||||
|
|
@ -326,12 +317,12 @@ func TestJsonDeserializer_Invoke(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
processor, err := NewJsonDeserializer(ctx, tt.config)
|
||||
processor, err := (&DeserializationConfig{}).Build(ctx, tt.config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctxWithCache := ctxcache.Init(ctx)
|
||||
input := map[string]any{"input": tt.inputJSON}
|
||||
result, err := processor.Invoke(ctxWithCache, input)
|
||||
result, err := processor.(*Deserializer).Invoke(ctxWithCache, input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,11 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
|
|
@ -29,28 +33,57 @@ const (
|
|||
OutputKeySerialization = "output"
|
||||
)
|
||||
|
||||
// SerializationConfig is the Config type for NodeTypeJsonSerialization.
|
||||
// Each Node Type should have its own designated Config type,
|
||||
// which should implement NodeAdaptor and NodeBuilder.
|
||||
// NOTE: we didn't define any fields for this type,
|
||||
// because this node is simple, we doesn't need to extract any SPECIFIC piece of info
|
||||
// from frontend Node. In other cases we would need to do it, such as LLM's model configs.
|
||||
type SerializationConfig struct {
|
||||
InputTypes map[string]*vo.TypeInfo
|
||||
// you can define ANY number of fields here,
|
||||
// as long as these fields are SERIALIZABLE and EXPORTED.
|
||||
// to store specific info extracted from frontend node.
|
||||
// e.g.
|
||||
// - LLM model configs
|
||||
// - conditional expressions
|
||||
// - fixed input fields such as MaxBatchSize
|
||||
}
|
||||
|
||||
type JsonSerializer struct {
|
||||
config *SerializationConfig
|
||||
}
|
||||
|
||||
func NewJsonSerializer(_ context.Context, cfg *SerializationConfig) (*JsonSerializer, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config required")
|
||||
}
|
||||
if cfg.InputTypes == nil {
|
||||
return nil, fmt.Errorf("InputTypes is required for serialization")
|
||||
// Adapt provides conversion from Node to NodeSchema.
|
||||
// NOTE: in this specific case, we don't need AdaptOption.
|
||||
func (s *SerializationConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeJsonSerialization,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: s, // remember to set the Node's Config Type to NodeSchema as well
|
||||
}
|
||||
|
||||
return &JsonSerializer{
|
||||
config: cfg,
|
||||
}, nil
|
||||
// this sets input fields' type and mapping info
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this set output fields' type info
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (js *JsonSerializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
func (s *SerializationConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (
|
||||
any, error) {
|
||||
return &Serializer{}, nil
|
||||
}
|
||||
|
||||
// Serializer is the actual node implementation.
|
||||
type Serializer struct {
|
||||
// here can holds ANY data required for node execution
|
||||
}
|
||||
|
||||
// Invoke implements the InvokableNode interface.
|
||||
func (js *Serializer) Invoke(_ context.Context, input map[string]any) (map[string]any, error) {
|
||||
// Directly use the input map for serialization
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("input data for serialization cannot be nil")
|
||||
|
|
|
|||
|
|
@ -23,44 +23,34 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNewJsonSerialize(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil config
|
||||
_, err := NewJsonSerializer(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "config required")
|
||||
|
||||
// Test with missing InputTypes config
|
||||
_, err = NewJsonSerializer(ctx, &SerializationConfig{})
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "InputTypes is required")
|
||||
|
||||
// Test with valid config
|
||||
validConfig := &SerializationConfig{
|
||||
s, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"testKey": {Type: "string"},
|
||||
},
|
||||
}
|
||||
processor, err := NewJsonSerializer(ctx, validConfig)
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, processor)
|
||||
assert.NotNil(t, s)
|
||||
}
|
||||
|
||||
func TestJsonSerialize_Invoke(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
config := &SerializationConfig{
|
||||
|
||||
processor, err := (&SerializationConfig{}).Build(ctx, &schema.NodeSchema{
|
||||
InputTypes: map[string]*vo.TypeInfo{
|
||||
"stringKey": {Type: "string"},
|
||||
"intKey": {Type: "integer"},
|
||||
"boolKey": {Type: "boolean"},
|
||||
"objKey": {Type: "object"},
|
||||
},
|
||||
}
|
||||
|
||||
processor, err := NewJsonSerializer(ctx, config)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test cases
|
||||
|
|
@ -115,7 +105,7 @@ func TestJsonSerialize_Invoke(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := processor.Invoke(ctx, tt.input)
|
||||
result, err := processor.(*Serializer).Invoke(ctx, tt.input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
)
|
||||
|
||||
func convertParsingType(p string) (knowledge.ParseMode, error) {
|
||||
switch p {
|
||||
case "fast":
|
||||
return knowledge.FastParseMode, nil
|
||||
case "accurate":
|
||||
return knowledge.AccurateParseMode, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid parsingType: %s", p)
|
||||
}
|
||||
}
|
||||
|
||||
func convertChunkType(p string) (knowledge.ChunkType, error) {
|
||||
switch p {
|
||||
case "custom":
|
||||
return knowledge.ChunkTypeCustom, nil
|
||||
case "default":
|
||||
return knowledge.ChunkTypeDefault, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid ChunkType: %s", p)
|
||||
}
|
||||
}
|
||||
func convertRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid RetrievalSearchType %v", s)
|
||||
}
|
||||
}
|
||||
|
|
@ -21,27 +21,45 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type DeleterConfig struct {
|
||||
KnowledgeID int64
|
||||
KnowledgeDeleter knowledge.KnowledgeOperator
|
||||
}
|
||||
type DeleterConfig struct{}
|
||||
|
||||
type KnowledgeDeleter struct {
|
||||
config *DeleterConfig
|
||||
}
|
||||
|
||||
func NewKnowledgeDeleter(_ context.Context, cfg *DeleterConfig) (*KnowledgeDeleter, error) {
|
||||
if cfg.KnowledgeDeleter == nil {
|
||||
return nil, errors.New("knowledge deleter is required")
|
||||
func (d *DeleterConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeDeleter,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: d,
|
||||
}
|
||||
return &KnowledgeDeleter{
|
||||
config: cfg,
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (d *DeleterConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Deleter{
|
||||
knowledgeDeleter: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Deleter struct {
|
||||
knowledgeDeleter knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (k *Deleter) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
documentID, ok := input["documentID"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("documentID is required and must be a string")
|
||||
|
|
@ -51,7 +69,7 @@ func (k *KnowledgeDeleter) Delete(ctx context.Context, input map[string]any) (ma
|
|||
DocumentID: documentID,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeDeleter.Delete(ctx, req)
|
||||
response, err := k.knowledgeDeleter.Delete(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,14 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
|
||||
)
|
||||
|
||||
|
|
@ -32,30 +39,88 @@ type IndexerConfig struct {
|
|||
KnowledgeID int64
|
||||
ParsingStrategy *knowledge.ParsingStrategy
|
||||
ChunkingStrategy *knowledge.ChunkingStrategy
|
||||
KnowledgeIndexer knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
type KnowledgeIndexer struct {
|
||||
config *IndexerConfig
|
||||
func (i *IndexerConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeIndexer,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
datasetListInfoParam := inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
if len(datasetIDs) == 0 {
|
||||
return nil, fmt.Errorf("dataset ids is required")
|
||||
}
|
||||
knowledgeID, err := cast.ToInt64E(datasetIDs[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.KnowledgeID = knowledgeID
|
||||
ps := inputs.StrategyParam.ParsingStrategy
|
||||
parseMode, err := convertParsingType(ps.ParsingType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsingStrategy := &knowledge.ParsingStrategy{
|
||||
ParseMode: parseMode,
|
||||
ImageOCR: ps.ImageOcr,
|
||||
ExtractImage: ps.ImageExtraction,
|
||||
ExtractTable: ps.TableExtraction,
|
||||
}
|
||||
i.ParsingStrategy = parsingStrategy
|
||||
|
||||
cs := inputs.StrategyParam.ChunkStrategy
|
||||
chunkType, err := convertChunkType(cs.ChunkType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chunkingStrategy := &knowledge.ChunkingStrategy{
|
||||
ChunkType: chunkType,
|
||||
Separator: cs.Separator,
|
||||
ChunkSize: cs.MaxToken,
|
||||
Overlap: int64(cs.Overlap * float64(cs.MaxToken)),
|
||||
}
|
||||
i.ChunkingStrategy = chunkingStrategy
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewKnowledgeIndexer(_ context.Context, cfg *IndexerConfig) (*KnowledgeIndexer, error) {
|
||||
if cfg.ParsingStrategy == nil {
|
||||
func (i *IndexerConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if i.ParsingStrategy == nil {
|
||||
return nil, errors.New("parsing strategy is required")
|
||||
}
|
||||
if cfg.ChunkingStrategy == nil {
|
||||
if i.ChunkingStrategy == nil {
|
||||
return nil, errors.New("chunking strategy is required")
|
||||
}
|
||||
if cfg.KnowledgeIndexer == nil {
|
||||
return nil, errors.New("knowledge indexer is required")
|
||||
}
|
||||
return &KnowledgeIndexer{
|
||||
config: cfg,
|
||||
return &Indexer{
|
||||
knowledgeID: i.KnowledgeID,
|
||||
parsingStrategy: i.ParsingStrategy,
|
||||
chunkingStrategy: i.ChunkingStrategy,
|
||||
knowledgeIndexer: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Indexer struct {
|
||||
knowledgeID int64
|
||||
parsingStrategy *knowledge.ParsingStrategy
|
||||
chunkingStrategy *knowledge.ChunkingStrategy
|
||||
knowledgeIndexer knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (k *Indexer) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
fileURL, ok := input["knowledge"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("knowledge is required")
|
||||
|
|
@ -68,15 +133,15 @@ func (k *KnowledgeIndexer) Store(ctx context.Context, input map[string]any) (map
|
|||
}
|
||||
|
||||
req := &knowledge.CreateDocumentRequest{
|
||||
KnowledgeID: k.config.KnowledgeID,
|
||||
ParsingStrategy: k.config.ParsingStrategy,
|
||||
ChunkingStrategy: k.config.ChunkingStrategy,
|
||||
KnowledgeID: k.knowledgeID,
|
||||
ParsingStrategy: k.parsingStrategy,
|
||||
ChunkingStrategy: k.chunkingStrategy,
|
||||
FileURL: fileURL,
|
||||
FileName: fileName,
|
||||
FileExtension: ext,
|
||||
}
|
||||
|
||||
response, err := k.config.KnowledgeIndexer.Store(ctx, req)
|
||||
response, err := k.knowledgeIndexer.Store(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,14 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
|
|
@ -29,37 +36,136 @@ const outputList = "outputList"
|
|||
type RetrieveConfig struct {
|
||||
KnowledgeIDs []int64
|
||||
RetrievalStrategy *knowledge.RetrievalStrategy
|
||||
Retriever knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
type KnowledgeRetrieve struct {
|
||||
config *RetrieveConfig
|
||||
func (r *RetrieveConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeKnowledgeRetriever,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: r,
|
||||
}
|
||||
|
||||
inputs := n.Data.Inputs
|
||||
datasetListInfoParam := inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
knowledgeIDs := make([]int64, 0, len(datasetIDs))
|
||||
for _, id := range datasetIDs {
|
||||
k, err := cast.ToInt64E(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, k)
|
||||
}
|
||||
r.KnowledgeIDs = knowledgeIDs
|
||||
|
||||
retrievalStrategy := &knowledge.RetrievalStrategy{}
|
||||
|
||||
var getDesignatedParamContent = func(name string) (any, bool) {
|
||||
for _, param := range inputs.DatasetParam {
|
||||
if param.Name == name {
|
||||
return param.Input.Value.Content, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("topK"); ok {
|
||||
topK, err := cast.ToInt64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.TopK = &topK
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useRerank"); ok {
|
||||
useRerank, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableRerank = useRerank
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useRewrite"); ok {
|
||||
useRewrite, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableQueryRewrite = useRewrite
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("isPersonalOnly"); ok {
|
||||
isPersonalOnly, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.IsPersonalOnly = isPersonalOnly
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("useNl2sql"); ok {
|
||||
useNl2sql, err := cast.ToBoolE(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.EnableNL2SQL = useNl2sql
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("minScore"); ok {
|
||||
minScore, err := cast.ToFloat64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.MinScore = &minScore
|
||||
}
|
||||
|
||||
if content, ok := getDesignatedParamContent("strategy"); ok {
|
||||
strategy, err := cast.ToInt64E(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchType, err := convertRetrievalSearchType(strategy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retrievalStrategy.SearchType = searchType
|
||||
}
|
||||
|
||||
r.RetrievalStrategy = retrievalStrategy
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewKnowledgeRetrieve(_ context.Context, cfg *RetrieveConfig) (*KnowledgeRetrieve, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("cfg is required")
|
||||
func (r *RetrieveConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if len(r.KnowledgeIDs) == 0 {
|
||||
return nil, errors.New("knowledge ids are required")
|
||||
}
|
||||
|
||||
if cfg.Retriever == nil {
|
||||
return nil, errors.New("retriever is required")
|
||||
}
|
||||
|
||||
if len(cfg.KnowledgeIDs) == 0 {
|
||||
return nil, errors.New("knowledgeI ids is required")
|
||||
}
|
||||
|
||||
if cfg.RetrievalStrategy == nil {
|
||||
if r.RetrievalStrategy == nil {
|
||||
return nil, errors.New("retrieval strategy is required")
|
||||
}
|
||||
|
||||
return &KnowledgeRetrieve{
|
||||
config: cfg,
|
||||
return &Retrieve{
|
||||
knowledgeIDs: r.KnowledgeIDs,
|
||||
retrievalStrategy: r.RetrievalStrategy,
|
||||
retriever: knowledge.GetKnowledgeOperator(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
type Retrieve struct {
|
||||
knowledgeIDs []int64
|
||||
retrievalStrategy *knowledge.RetrievalStrategy
|
||||
retriever knowledge.KnowledgeOperator
|
||||
}
|
||||
|
||||
func (kr *Retrieve) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
query, ok := input["Query"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("capital query key is required")
|
||||
|
|
@ -67,11 +173,11 @@ func (kr *KnowledgeRetrieve) Retrieve(ctx context.Context, input map[string]any)
|
|||
|
||||
req := &knowledge.RetrieveRequest{
|
||||
Query: query,
|
||||
KnowledgeIDs: kr.config.KnowledgeIDs,
|
||||
RetrievalStrategy: kr.config.RetrievalStrategy,
|
||||
KnowledgeIDs: kr.knowledgeIDs,
|
||||
RetrievalStrategy: kr.retrievalStrategy,
|
||||
}
|
||||
|
||||
response, err := kr.config.Retriever.Retrieve(ctx, req)
|
||||
response, err := kr.retriever.Retrieve(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,13 +34,20 @@ import (
|
|||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
|
|
@ -143,126 +150,408 @@ const (
|
|||
)
|
||||
|
||||
type RetrievalStrategy struct {
|
||||
RetrievalStrategy *crossknowledge.RetrievalStrategy
|
||||
RetrievalStrategy *knowledge.RetrievalStrategy
|
||||
NoReCallReplyMode NoReCallReplyMode
|
||||
NoReCallReplyCustomizePrompt string
|
||||
}
|
||||
|
||||
type KnowledgeRecallConfig struct {
|
||||
ChatModel model.BaseChatModel
|
||||
Retriever crossknowledge.KnowledgeOperator
|
||||
Retriever knowledge.KnowledgeOperator
|
||||
RetrievalStrategy *RetrievalStrategy
|
||||
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
|
||||
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ChatModel ModelWithInfo
|
||||
Tools []tool.BaseTool
|
||||
SystemPrompt string
|
||||
UserPrompt string
|
||||
OutputFormat Format
|
||||
InputFields map[string]*vo.TypeInfo
|
||||
OutputFields map[string]*vo.TypeInfo
|
||||
ToolsReturnDirectly map[string]bool
|
||||
KnowledgeRecallConfig *KnowledgeRecallConfig
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
LLMParams *crossmodel.LLMParams
|
||||
FCParam *vo.FCParam
|
||||
BackupLLMParams *crossmodel.LLMParams
|
||||
}
|
||||
|
||||
type LLM struct {
|
||||
r compose.Runnable[map[string]any, map[string]any]
|
||||
outputFormat Format
|
||||
outputFields map[string]*vo.TypeInfo
|
||||
canStream bool
|
||||
requireCheckpoint bool
|
||||
fullSources map[string]*nodes.SourceInfo
|
||||
}
|
||||
|
||||
const (
|
||||
rawOutputKey = "llm_raw_output_%s"
|
||||
warningKey = "llm_warning_%s"
|
||||
)
|
||||
|
||||
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
data = nodes.ExtractJSONString(data)
|
||||
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
|
||||
ctxcache.Store(ctx, rawOutputK, data)
|
||||
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
|
||||
return map[string]any{}, nil
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeLLM,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
param := n.Data.Inputs.LLMParam
|
||||
if param == nil {
|
||||
return nil, fmt.Errorf("llm node's llmParam is nil")
|
||||
}
|
||||
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
convertedLLMParam, err := llmParamsToLLMParam(llmParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
|
||||
c.LLMParams = convertedLLMParam
|
||||
c.SystemPrompt = convertedLLMParam.SystemPrompt
|
||||
c.UserPrompt = convertedLLMParam.Prompt
|
||||
|
||||
var resFormat Format
|
||||
switch convertedLLMParam.ResponseFormat {
|
||||
case crossmodel.ResponseFormatText:
|
||||
resFormat = FormatText
|
||||
case crossmodel.ResponseFormatMarkdown:
|
||||
resFormat = FormatMarkdown
|
||||
case crossmodel.ResponseFormatJSON:
|
||||
resFormat = FormatJSON
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
|
||||
}
|
||||
|
||||
c.OutputFormat = resFormat
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resFormat == FormatJSON {
|
||||
if len(ns.OutputTypes) == 1 {
|
||||
for _, v := range ns.OutputTypes {
|
||||
if v.Type == vo.DataTypeString {
|
||||
resFormat = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if len(ns.OutputTypes) == 2 {
|
||||
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
|
||||
resFormat = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resFormat == FormatJSON {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
CanGeneratesStream: false,
|
||||
}
|
||||
} else {
|
||||
ns.StreamConfigs = &schema2.StreamConfig{
|
||||
CanGeneratesStream: true,
|
||||
}
|
||||
}
|
||||
|
||||
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
|
||||
c.FCParam = n.Data.Inputs.FCParam
|
||||
}
|
||||
|
||||
if se := n.Data.Inputs.SettingOnError; se != nil {
|
||||
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
|
||||
var backupLLMParam vo.SimpleLLMParam
|
||||
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
|
||||
return nil, err
|
||||
}
|
||||
c.BackupLLMParams = backupModel
|
||||
}
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
for _, param := range params {
|
||||
switch param.Name {
|
||||
case "temperature":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
floatVal, err := strconv.ParseFloat(strVal, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Temperature = &floatVal
|
||||
case "maxTokens":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
intVal, err := strconv.Atoi(strVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.MaxTokens = intVal
|
||||
case "responseFormat":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
|
||||
case "modleName":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.ModelName = strVal
|
||||
case "modelType":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ModelType = int64Val
|
||||
case "prompt":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.Prompt = strVal
|
||||
case "enableChatHistory":
|
||||
boolVar := param.Input.Value.Content.(bool)
|
||||
p.EnableChatHistory = boolVar
|
||||
case "systemPrompt":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
p.SystemPrompt = strVal
|
||||
case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
|
||||
// do nothing
|
||||
case "topP":
|
||||
strVal := param.Input.Value.Content.(string)
|
||||
floatVar, err := strconv.ParseFloat(strVal, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.TopP = &floatVar
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
p.ModelName = params.ModelName
|
||||
p.ModelType = params.ModelType
|
||||
p.Temperature = ¶ms.Temperature
|
||||
p.MaxTokens = params.MaxTokens
|
||||
p.TopP = ¶ms.TopP
|
||||
p.ResponseFormat = params.ResponseFormat
|
||||
p.SystemPrompt = params.SystemPrompt
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getReasoningContent(message *schema.Message) string {
|
||||
return message.ReasoningContent
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
nested []nodes.NestedWorkflowOption
|
||||
toolWorkflowSW *schema.StreamWriter[*entity.Message]
|
||||
}
|
||||
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
var (
|
||||
err error
|
||||
chatModel, fallbackM model.BaseChatModel
|
||||
info, fallbackI *modelmgr.Model
|
||||
modelWithInfo ModelWithInfo
|
||||
tools []tool.BaseTool
|
||||
toolsReturnDirectly map[string]bool
|
||||
knowledgeRecallConfig *KnowledgeRecallConfig
|
||||
)
|
||||
|
||||
type Option func(o *Options)
|
||||
|
||||
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
|
||||
return func(o *Options) {
|
||||
o.nested = append(o.nested, nested...)
|
||||
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
|
||||
return func(o *Options) {
|
||||
o.toolWorkflowSW = sw
|
||||
exceptionConf := ns.ExceptionConfigs
|
||||
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
|
||||
backupModelParams := c.BackupLLMParams
|
||||
if backupModelParams != nil {
|
||||
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type llmState = map[string]any
|
||||
if fallbackM == nil {
|
||||
modelWithInfo = NewModel(chatModel, info)
|
||||
} else {
|
||||
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
|
||||
}
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
fcParams := c.FCParam
|
||||
if fcParams != nil {
|
||||
if fcParams.WorkflowFCParam != nil {
|
||||
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
|
||||
wfIDStr := wf.WorkflowID
|
||||
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
|
||||
}
|
||||
|
||||
workflowToolConfig := vo.WorkflowToolConfig{}
|
||||
if wf.FCSetting != nil {
|
||||
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
|
||||
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
locator := vo.FromDraft
|
||||
if wf.WorkflowVersion != "" {
|
||||
locator = vo.FromSpecificVersion
|
||||
}
|
||||
|
||||
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
|
||||
ID: wfID,
|
||||
QType: locator,
|
||||
Version: wf.WorkflowVersion,
|
||||
}, workflowToolConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools = append(tools, wfTool)
|
||||
if wfTool.TerminatePlan() == vo.UseAnswerContent {
|
||||
if toolsReturnDirectly == nil {
|
||||
toolsReturnDirectly = make(map[string]bool)
|
||||
}
|
||||
toolInfo, err := wfTool.Info(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolsReturnDirectly[toolInfo.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.PluginFCParam != nil {
|
||||
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
|
||||
for _, p := range fcParams.PluginFCParam.PluginList {
|
||||
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
|
||||
}
|
||||
|
||||
var (
|
||||
requestParameters []*workflow3.APIParameter
|
||||
responseParameters []*workflow3.APIParameter
|
||||
)
|
||||
if p.FCSetting != nil {
|
||||
requestParameters = p.FCSetting.RequestParameters
|
||||
responseParameters = p.FCSetting.ResponseParameters
|
||||
}
|
||||
|
||||
if req, ok := pluginToolsInvokableReq[pid]; ok {
|
||||
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
}
|
||||
} else {
|
||||
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
|
||||
PluginEntity: plugin.Entity{
|
||||
PluginID: pid,
|
||||
PluginVersion: ptr.Of(p.PluginVersion),
|
||||
},
|
||||
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
|
||||
toolID: {
|
||||
ToolID: toolID,
|
||||
RequestAPIParametersConfig: requestParameters,
|
||||
ResponseAPIParametersConfig: responseParameters,
|
||||
},
|
||||
},
|
||||
IsDraft: p.IsDraft,
|
||||
}
|
||||
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
|
||||
}
|
||||
}
|
||||
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
|
||||
for _, req := range pluginToolsInvokableReq {
|
||||
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, t := range toolMap {
|
||||
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
|
||||
}
|
||||
}
|
||||
if len(inInvokableTools) > 0 {
|
||||
tools = append(tools, inInvokableTools...)
|
||||
}
|
||||
}
|
||||
|
||||
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
|
||||
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
|
||||
if kwChatModel == nil {
|
||||
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
|
||||
}
|
||||
|
||||
knowledgeOperator := knowledge.GetKnowledgeOperator()
|
||||
setting := fcParams.KnowledgeFCParam.GlobalSetting
|
||||
knowledgeRecallConfig = &KnowledgeRecallConfig{
|
||||
ChatModel: kwChatModel,
|
||||
Retriever: knowledgeOperator,
|
||||
}
|
||||
searchType, err := toRetrievalSearchType(setting.SearchMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
|
||||
RetrievalStrategy: &knowledge.RetrievalStrategy{
|
||||
TopK: ptr.Of(setting.TopK),
|
||||
MinScore: ptr.Of(setting.MinScore),
|
||||
SearchType: searchType,
|
||||
EnableNL2SQL: setting.UseNL2SQL,
|
||||
EnableQueryRewrite: setting.UseRewrite,
|
||||
EnableRerank: setting.UseRerank,
|
||||
},
|
||||
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
|
||||
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
|
||||
}
|
||||
|
||||
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
|
||||
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
|
||||
kid, err := strconv.ParseInt(kw.ID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeIDs = append(knowledgeIDs, kid)
|
||||
}
|
||||
|
||||
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
|
||||
&knowledge.ListKnowledgeDetailRequest{
|
||||
KnowledgeIDs: knowledgeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
|
||||
}
|
||||
}
|
||||
|
||||
func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
||||
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
|
||||
return llmState{}
|
||||
}))
|
||||
|
||||
var (
|
||||
hasReasoning bool
|
||||
canStream = true
|
||||
)
|
||||
var hasReasoning bool
|
||||
|
||||
format := cfg.OutputFormat
|
||||
format := c.OutputFormat
|
||||
if format == FormatJSON {
|
||||
if len(cfg.OutputFields) == 1 {
|
||||
for _, v := range cfg.OutputFields {
|
||||
if len(ns.OutputTypes) == 1 {
|
||||
for _, v := range ns.OutputTypes {
|
||||
if v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if len(cfg.OutputFields) == 2 {
|
||||
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
|
||||
for k, v := range cfg.OutputFields {
|
||||
} else if len(ns.OutputTypes) == 2 {
|
||||
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
|
||||
format = FormatText
|
||||
break
|
||||
|
|
@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
}
|
||||
}
|
||||
|
||||
userPrompt := cfg.UserPrompt
|
||||
userPrompt := c.UserPrompt
|
||||
switch format {
|
||||
case FormatJSON:
|
||||
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
|
||||
jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
case FormatText:
|
||||
}
|
||||
|
||||
if cfg.KnowledgeRecallConfig != nil {
|
||||
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
|
||||
if knowledgeRecallConfig != nil {
|
||||
err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
|
||||
|
||||
inputs := maps.Clone(cfg.InputFields)
|
||||
inputs := maps.Clone(ns.InputTypes)
|
||||
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
|
||||
Type: vo.DataTypeString,
|
||||
}
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
|
||||
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
template := newPrompts(sp, up, modelWithInfo)
|
||||
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template,
|
||||
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
|
||||
|
|
@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
|
||||
|
||||
} else {
|
||||
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
|
||||
template := newPrompts(sp, up, cfg.ChatModel)
|
||||
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
|
||||
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
|
||||
template := newPrompts(sp, up, modelWithInfo)
|
||||
_ = g.AddChatTemplateNode(templateNodeKey, template)
|
||||
|
||||
_ = g.AddEdge(compose.START, templateNodeKey)
|
||||
}
|
||||
|
||||
if len(cfg.Tools) > 0 {
|
||||
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
|
||||
if len(tools) > 0 {
|
||||
m, ok := modelWithInfo.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
reactConfig := react.AgentConfig{
|
||||
ToolCallingModel: m,
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
|
||||
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
|
||||
ModelNodeName: agentModelName,
|
||||
}
|
||||
|
||||
if len(cfg.ToolsReturnDirectly) > 0 {
|
||||
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
|
||||
for k := range cfg.ToolsReturnDirectly {
|
||||
if len(toolsReturnDirectly) > 0 {
|
||||
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
|
||||
for k := range toolsReturnDirectly {
|
||||
reactConfig.ToolReturnDirectly[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
|
@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
|
||||
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
|
||||
} else {
|
||||
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
|
||||
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
|
||||
}
|
||||
|
||||
_ = g.AddEdge(templateNodeKey, llmNodeKey)
|
||||
|
||||
if format == FormatJSON {
|
||||
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
|
||||
return jsonParse(ctx, msg.Content, cfg.OutputFields)
|
||||
return jsonParse(ctx, msg.Content, ns.OutputTypes)
|
||||
}
|
||||
|
||||
convertNode := compose.InvokableLambda(iConvert)
|
||||
|
||||
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
|
||||
|
||||
canStream = false
|
||||
} else {
|
||||
var outputKey string
|
||||
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
|
||||
if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
|
||||
panic("impossible")
|
||||
}
|
||||
|
||||
for k, v := range cfg.OutputFields {
|
||||
for k, v := range ns.OutputTypes {
|
||||
if v.Type != vo.DataTypeString {
|
||||
panic("impossible")
|
||||
}
|
||||
|
|
@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
_ = g.AddEdge(outputConvertNodeKey, compose.END)
|
||||
|
||||
requireCheckpoint := false
|
||||
if len(cfg.Tools) > 0 {
|
||||
if len(tools) > 0 {
|
||||
requireCheckpoint = true
|
||||
}
|
||||
|
||||
var opts []compose.GraphCompileOption
|
||||
var compileOpts []compose.GraphCompileOption
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
|
||||
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
|
||||
}
|
||||
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
|
||||
compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
|
||||
|
||||
r, err := g.Compile(ctx, opts...)
|
||||
r, err := g.Compile(ctx, compileOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
|
|||
llm := &LLM{
|
||||
r: r,
|
||||
outputFormat: format,
|
||||
canStream: canStream,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
fullSources: cfg.FullSources,
|
||||
fullSources: ns.FullSources,
|
||||
}
|
||||
|
||||
return llm, nil
|
||||
}
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
if c.FCParam != nil {
|
||||
if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) != 1 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
outputs := ns.OutputTypes
|
||||
if len(outputs) != 1 && len(outputs) != 2 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
var outputKey string
|
||||
for key, output := range outputs {
|
||||
if output.Type != vo.DataTypeString {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if key != ReasoningOutputKey {
|
||||
if len(outputKey) > 0 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
outputKey = key
|
||||
}
|
||||
}
|
||||
|
||||
field := path[0]
|
||||
if field == ReasoningOutputKey || field == outputKey {
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
|
||||
switch s {
|
||||
case 0:
|
||||
return knowledge.SearchTypeSemantic, nil
|
||||
case 1:
|
||||
return knowledge.SearchTypeHybrid, nil
|
||||
case 20:
|
||||
return knowledge.SearchTypeFullText, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid retrieval search type %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
type LLM struct {
|
||||
r compose.Runnable[map[string]any, map[string]any]
|
||||
outputFormat Format
|
||||
requireCheckpoint bool
|
||||
fullSources map[string]*schema2.SourceInfo
|
||||
}
|
||||
|
||||
const (
|
||||
rawOutputKey = "llm_raw_output_%s"
|
||||
warningKey = "llm_warning_%s"
|
||||
)
|
||||
|
||||
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
|
||||
data = nodes.ExtractJSONString(data)
|
||||
|
||||
var result map[string]any
|
||||
|
||||
err := sonic.UnmarshalString(data, &result)
|
||||
if err != nil {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
|
||||
ctxcache.Store(ctx, rawOutputK, data)
|
||||
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
|
||||
}
|
||||
|
||||
if ws != nil {
|
||||
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type llmOptions struct {
|
||||
toolWorkflowSW *schema.StreamWriter[*entity.Message]
|
||||
}
|
||||
|
||||
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
|
||||
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
|
||||
o.toolWorkflowSW = sw
|
||||
})
|
||||
}
|
||||
|
||||
type llmState = map[string]any
|
||||
|
||||
const agentModelName = "agent_model"
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
resumingEvent = c.NodeCtx.ResumingEvent
|
||||
|
|
@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
|
|||
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
|
||||
}
|
||||
|
||||
llmOpts := &Options{}
|
||||
for _, opt := range opts {
|
||||
opt(llmOpts)
|
||||
}
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
nestedOpts := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range llmOpts.nested {
|
||||
opt(nestedOpts)
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
|
||||
composeOpts = append(composeOpts, options.GetOptsForNested()...)
|
||||
|
||||
if resumingEvent != nil {
|
||||
var (
|
||||
|
|
@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
|
|||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
|
||||
}
|
||||
|
||||
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
|
||||
if llmOpts.toolWorkflowSW != nil {
|
||||
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
|
||||
composeOpts = append(composeOpts, toolMsgOpt)
|
||||
|
|
@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
|
|||
return compose.NewInterruptAndRerunErr(ie)
|
||||
}
|
||||
|
||||
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
|
||||
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
|
||||
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
|
|||
|
||||
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
|
||||
modelPredictionIDs := strings.Split(input.Content, ",")
|
||||
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
|
||||
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
|
||||
return strconv.Itoa(int(e.ID)), e.ID
|
||||
})
|
||||
recallKnowledgeIDs := make([]int64, 0)
|
||||
|
|
@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
|
|||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
|
||||
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
|
||||
Query: userPrompt,
|
||||
KnowledgeIDs: recallKnowledgeIDs,
|
||||
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
|
|
@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
|
|||
}
|
||||
|
||||
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
|
||||
sources map[string]*nodes.SourceInfo,
|
||||
sources map[string]*schema2.SourceInfo,
|
||||
supportedModals map[modelmgr.Modal]bool,
|
||||
) (*schema.Message, error) {
|
||||
if !pl.hasMultiModal || len(supportedModals) == 0 {
|
||||
|
|
@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
|
|||
}
|
||||
sk := fmt.Sprintf(sourceKey, nodeKey)
|
||||
|
||||
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
|
||||
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package loop
|
||||
package _break
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
|
@ -22,21 +22,36 @@ import (
|
|||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Break struct {
|
||||
parentIntermediateStore variable.Store
|
||||
}
|
||||
|
||||
func NewBreak(_ context.Context, store variable.Store) (*Break, error) {
|
||||
type Config struct{}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
return &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeBreak,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Break{
|
||||
parentIntermediateStore: store,
|
||||
parentIntermediateStore: &nodes.ParentIntermediateStore{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
const BreakKey = "$break"
|
||||
|
||||
func (b *Break) DoBreak(ctx context.Context, _ map[string]any) (map[string]any, error) {
|
||||
func (b *Break) Invoke(ctx context.Context, _ map[string]any) (map[string]any, error) {
|
||||
err := b.parentIntermediateStore.Set(ctx, compose.FieldPath{BreakKey}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package _continue
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Continue struct{}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
return &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeContinue,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Continue{}, nil
|
||||
}
|
||||
|
||||
func (co *Continue) Invoke(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
return in, nil
|
||||
}
|
||||
|
|
@ -27,53 +27,150 @@ import (
|
|||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
_break "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/loop/break"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type Loop struct {
|
||||
config *Config
|
||||
outputs map[string]*vo.FieldSource
|
||||
outputVars map[string]string
|
||||
inner compose.Runnable[map[string]any, map[string]any]
|
||||
nodeKey vo.NodeKey
|
||||
|
||||
loopType Type
|
||||
inputArrays []string
|
||||
intermediateVars map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
LoopNodeKey vo.NodeKey
|
||||
LoopType Type
|
||||
InputArrays []string
|
||||
IntermediateVars map[string]*vo.TypeInfo
|
||||
Outputs []*vo.FieldInfo
|
||||
|
||||
Inner compose.Runnable[map[string]any, map[string]any]
|
||||
}
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
ByArray Type = "by_array"
|
||||
ByIteration Type = "by_iteration"
|
||||
Infinite Type = "infinite"
|
||||
)
|
||||
|
||||
func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() != nil {
|
||||
return nil, fmt.Errorf("loop node cannot have parent: %s", n.Parent().ID)
|
||||
}
|
||||
|
||||
if conf.LoopType == ByArray {
|
||||
if len(conf.InputArrays) == 0 {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeLoop,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
loopType, err := toLoopType(n.Data.Inputs.LoopType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.LoopType = loopType
|
||||
|
||||
intermediateVars := make(map[string]*vo.TypeInfo)
|
||||
for _, param := range n.Data.Inputs.VariableParameters {
|
||||
tInfo, err := convert.CanvasBlockInputToTypeInfo(param.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
intermediateVars[param.Name] = tInfo
|
||||
|
||||
ns.SetInputType(param.Name, tInfo)
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(param.Input, compose.FieldPath{param.Name}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
c.IntermediateVars = intermediateVars
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, fieldInfo := range ns.OutputSources {
|
||||
if fieldInfo.Source.Ref != nil {
|
||||
if len(fieldInfo.Source.Ref.FromPath) == 1 {
|
||||
if _, ok := intermediateVars[fieldInfo.Source.Ref.FromPath[0]]; ok {
|
||||
fieldInfo.Source.Ref.VariableType = ptr.Of(vo.ParentIntermediate)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loopCount := n.Data.Inputs.LoopCount
|
||||
if loopCount != nil {
|
||||
typeInfo, err := convert.CanvasBlockInputToTypeInfo(loopCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(Count, typeInfo)
|
||||
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(loopCount, compose.FieldPath{Count}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
for key, tInfo := range ns.InputTypes {
|
||||
if tInfo.Type != vo.DataTypeArray {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := intermediateVars[key]; ok { // exclude arrays in intermediate vars
|
||||
continue
|
||||
}
|
||||
|
||||
c.InputArrays = append(c.InputArrays, key)
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func toLoopType(l vo.LoopType) (Type, error) {
|
||||
switch l {
|
||||
case vo.LoopTypeArray:
|
||||
return ByArray, nil
|
||||
case vo.LoopTypeCount:
|
||||
return ByIteration, nil
|
||||
case vo.LoopTypeInfinite:
|
||||
return Infinite, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported loop type: %s", l)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
if c.LoopType == ByArray {
|
||||
if len(c.InputArrays) == 0 {
|
||||
return nil, errors.New("input arrays is empty when loop type is ByArray")
|
||||
}
|
||||
}
|
||||
|
||||
loop := &Loop{
|
||||
config: conf,
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
outputVars: make(map[string]string),
|
||||
options := schema.GetBuildOptions(opts...)
|
||||
if options.Inner == nil {
|
||||
return nil, errors.New("inner workflow is required for Loop Node")
|
||||
}
|
||||
|
||||
for _, info := range conf.Outputs {
|
||||
loop := &Loop{
|
||||
outputs: make(map[string]*vo.FieldSource),
|
||||
outputVars: make(map[string]string),
|
||||
inputArrays: c.InputArrays,
|
||||
nodeKey: ns.Key,
|
||||
intermediateVars: c.IntermediateVars,
|
||||
inner: options.Inner,
|
||||
loopType: c.LoopType,
|
||||
}
|
||||
|
||||
for _, info := range ns.OutputSources {
|
||||
if len(info.Path) != 1 {
|
||||
return nil, fmt.Errorf("invalid output path: %s", info.Path)
|
||||
}
|
||||
|
|
@ -87,7 +184,7 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
|||
return nil, fmt.Errorf("loop output refers to intermediate variable, but path length > 1: %v", fromPath)
|
||||
}
|
||||
|
||||
if _, ok := conf.IntermediateVars[fromPath[0]]; !ok {
|
||||
if _, ok := c.IntermediateVars[fromPath[0]]; !ok {
|
||||
return nil, fmt.Errorf("loop output refers to intermediate variable, but not found in intermediate vars: %v", fromPath)
|
||||
}
|
||||
|
||||
|
|
@ -102,18 +199,27 @@ func NewLoop(_ context.Context, conf *Config) (*Loop, error) {
|
|||
return loop, nil
|
||||
}
|
||||
|
||||
type Type string
|
||||
|
||||
const (
|
||||
ByArray Type = "by_array"
|
||||
ByIteration Type = "by_iteration"
|
||||
Infinite Type = "infinite"
|
||||
)
|
||||
|
||||
const (
|
||||
Count = "loopCount"
|
||||
)
|
||||
|
||||
func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (out map[string]any, err error) {
|
||||
func (l *Loop) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (
|
||||
out map[string]any, err error) {
|
||||
maxIter, err := l.getMaxIter(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrays := make(map[string][]any, len(l.config.InputArrays))
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
arrays := make(map[string][]any, len(l.inputArrays))
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
|
|
@ -121,10 +227,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
arrays[arrayKey] = a.([]any)
|
||||
}
|
||||
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
var (
|
||||
existingCState *nodes.NestedWorkflowState
|
||||
|
|
@ -134,7 +237,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
)
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, getter nodes.NestedWorkflowAware) error {
|
||||
var e error
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(l.config.LoopNodeKey)
|
||||
existingCState, _, e = getter.GetNestedWorkflowState(l.nodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
|
@ -150,15 +253,15 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
for k := range existingCState.IntermediateVars {
|
||||
intermediateVars[k] = ptr.Of(existingCState.IntermediateVars[k])
|
||||
}
|
||||
intermediateVars[BreakKey] = &hasBreak
|
||||
intermediateVars[_break.BreakKey] = &hasBreak
|
||||
} else {
|
||||
output = make(map[string]any, len(l.outputs))
|
||||
for k := range l.outputs {
|
||||
output[k] = make([]any, 0)
|
||||
}
|
||||
|
||||
intermediateVars = make(map[string]*any, len(l.config.IntermediateVars))
|
||||
for varKey := range l.config.IntermediateVars {
|
||||
intermediateVars = make(map[string]*any, len(l.intermediateVars))
|
||||
for varKey := range l.intermediateVars {
|
||||
v, ok := nodes.TakeMapValue(in, compose.FieldPath{varKey})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("incoming intermediate variable not present in input: %s", varKey)
|
||||
|
|
@ -166,10 +269,10 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
|
||||
intermediateVars[varKey] = &v
|
||||
}
|
||||
intermediateVars[BreakKey] = &hasBreak
|
||||
intermediateVars[_break.BreakKey] = &hasBreak
|
||||
}
|
||||
|
||||
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.config.IntermediateVars)
|
||||
ctx = nodes.InitIntermediateVars(ctx, intermediateVars, l.intermediateVars)
|
||||
|
||||
getIthInput := func(i int) (map[string]any, map[string]any, error) {
|
||||
input := make(map[string]any)
|
||||
|
|
@ -190,13 +293,13 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
input[k] = v
|
||||
}
|
||||
|
||||
input[string(l.config.LoopNodeKey)+"#index"] = int64(i)
|
||||
input[string(l.nodeKey)+"#index"] = int64(i)
|
||||
|
||||
items := make(map[string]any)
|
||||
for arrayKey := range arrays {
|
||||
ele := arrays[arrayKey][i]
|
||||
items[arrayKey] = ele
|
||||
currentKey := string(l.config.LoopNodeKey) + "#" + arrayKey
|
||||
currentKey := string(l.nodeKey) + "#" + arrayKey
|
||||
|
||||
// Recursively expand map[string]any elements
|
||||
var expand func(prefix string, val interface{})
|
||||
|
|
@ -276,7 +379,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
}
|
||||
}
|
||||
|
||||
taskOutput, err := l.config.Inner.Invoke(subCtx, input, ithOpts...)
|
||||
taskOutput, err := l.inner.Invoke(subCtx, input, ithOpts...)
|
||||
if err != nil {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
|
|
@ -322,29 +425,26 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
|
||||
if len(index2InterruptInfo) > 0 { // this invocation of batch.Execute has new interruptions
|
||||
iEvent := &entity.InterruptEvent{
|
||||
NodeKey: l.config.LoopNodeKey,
|
||||
NodeKey: l.nodeKey,
|
||||
NodeType: entity.NodeTypeLoop,
|
||||
NestedInterruptInfo: index2InterruptInfo, // only emit the newly generated interruptInfo
|
||||
}
|
||||
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
if e := setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState); e != nil {
|
||||
if e := setter.SaveNestedWorkflowState(l.nodeKey, compState); e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return setter.SetInterruptEvent(l.config.LoopNodeKey, iEvent)
|
||||
return setter.SetInterruptEvent(l.nodeKey, iEvent)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println("save interruptEvent in state within loop: ", iEvent)
|
||||
fmt.Println("save composite info in state within loop: ", compState)
|
||||
|
||||
return nil, compose.InterruptAndRerun
|
||||
} else {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, setter nodes.NestedWorkflowAware) error {
|
||||
return setter.SaveNestedWorkflowState(l.config.LoopNodeKey, compState)
|
||||
return setter.SaveNestedWorkflowState(l.nodeKey, compState)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -354,8 +454,7 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
}
|
||||
|
||||
if existingCState != nil && len(existingCState.Index2InterruptInfo) > 0 {
|
||||
fmt.Println("no interrupt thrown this round, but has historical interrupt events: ", existingCState.Index2InterruptInfo)
|
||||
panic("impossible")
|
||||
panic(fmt.Sprintf("no interrupt thrown this round, but has historical interrupt events: %v", existingCState.Index2InterruptInfo))
|
||||
}
|
||||
|
||||
for outputVarKey, intermediateVarKey := range l.outputVars {
|
||||
|
|
@ -368,9 +467,9 @@ func (l *Loop) Execute(ctx context.Context, in map[string]any, opts ...nodes.Nes
|
|||
func (l *Loop) getMaxIter(in map[string]any) (int, error) {
|
||||
maxIter := math.MaxInt
|
||||
|
||||
switch l.config.LoopType {
|
||||
switch l.loopType {
|
||||
case ByArray:
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
a, ok := nodes.TakeMapValue(in, compose.FieldPath{arrayKey})
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("incoming array not present in input: %s", arrayKey)
|
||||
|
|
@ -394,7 +493,7 @@ func (l *Loop) getMaxIter(in map[string]any) (int, error) {
|
|||
maxIter = int(iter.(int64))
|
||||
case Infinite:
|
||||
default:
|
||||
return 0, fmt.Errorf("loop type not supported: %v", l.config.LoopType)
|
||||
return 0, fmt.Errorf("loop type not supported: %v", l.loopType)
|
||||
}
|
||||
|
||||
return maxIter, nil
|
||||
|
|
@ -409,8 +508,8 @@ func convertIntermediateVars(vars map[string]*any) map[string]any {
|
|||
}
|
||||
|
||||
func (l *Loop) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
trimmed := make(map[string]any, len(l.config.InputArrays))
|
||||
for _, arrayKey := range l.config.InputArrays {
|
||||
trimmed := make(map[string]any, len(l.inputArrays))
|
||||
for _, arrayKey := range l.inputArrays {
|
||||
if v, ok := in[arrayKey]; ok {
|
||||
trimmed[arrayKey] = v
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,90 +0,0 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type NestedWorkflowOptions struct {
|
||||
optsForNested []compose.Option
|
||||
toResumeIndexes map[int]compose.StateModifier
|
||||
optsForIndexed map[int][]compose.Option
|
||||
}
|
||||
|
||||
type NestedWorkflowOption func(*NestedWorkflowOptions)
|
||||
|
||||
func WithOptsForNested(opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
o.optsForNested = append(o.optsForNested, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetOptsForNested() []compose.Option {
|
||||
return c.optsForNested
|
||||
}
|
||||
|
||||
func WithResumeIndex(i int, m compose.StateModifier) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.toResumeIndexes == nil {
|
||||
o.toResumeIndexes = map[int]compose.StateModifier{}
|
||||
}
|
||||
|
||||
o.toResumeIndexes[i] = m
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetResumeIndexes() map[int]compose.StateModifier {
|
||||
return c.toResumeIndexes
|
||||
}
|
||||
|
||||
func WithOptsForIndexed(index int, opts ...compose.Option) NestedWorkflowOption {
|
||||
return func(o *NestedWorkflowOptions) {
|
||||
if o.optsForIndexed == nil {
|
||||
o.optsForIndexed = map[int][]compose.Option{}
|
||||
}
|
||||
o.optsForIndexed[index] = opts
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowOptions) GetOptsForIndexed(index int) []compose.Option {
|
||||
if c.optsForIndexed == nil {
|
||||
return nil
|
||||
}
|
||||
return c.optsForIndexed[index]
|
||||
}
|
||||
|
||||
type NestedWorkflowState struct {
|
||||
Index2Done map[int]bool `json:"index_2_done,omitempty"`
|
||||
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
|
||||
FullOutput map[string]any `json:"full_output,omitempty"`
|
||||
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowState) String() string {
|
||||
s, _ := sonic.MarshalIndent(c, "", " ")
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type NestedWorkflowAware interface {
|
||||
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
|
||||
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
|
||||
InterruptEventStore
|
||||
}
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
einoschema "github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
// InvokableNode is a basic workflow node that can Invoke.
|
||||
// Invoke accepts non-streaming input and returns non-streaming output.
|
||||
// It does not accept any options.
|
||||
// Most nodes implement this, such as NodeTypePlugin.
|
||||
type InvokableNode interface {
|
||||
Invoke(ctx context.Context, input map[string]any) (
|
||||
output map[string]any, err error)
|
||||
}
|
||||
|
||||
// InvokableNodeWOpt is a workflow node that can Invoke.
|
||||
// Invoke accepts non-streaming input and returns non-streaming output.
|
||||
// It can accept NodeOption.
|
||||
// e.g. NodeTypeLLM, NodeTypeSubWorkflow implement this.
|
||||
type InvokableNodeWOpt interface {
|
||||
Invoke(ctx context.Context, in map[string]any, opts ...NodeOption) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// StreamableNode is a workflow node that can Stream.
|
||||
// Stream accepts non-streaming input and returns streaming output.
|
||||
// It does not accept and options
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be streamable plugin for NodeTypePlugin.
|
||||
type StreamableNode interface {
|
||||
Stream(ctx context.Context, in map[string]any) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// StreamableNodeWOpt is a workflow node that can Stream.
|
||||
// Stream accepts non-streaming input and returns streaming output.
|
||||
// It can accept NodeOption.
|
||||
// e.g. NodeTypeLLM implement this.
|
||||
type StreamableNodeWOpt interface {
|
||||
Stream(ctx context.Context, in map[string]any, opts ...NodeOption) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// CollectableNode is a workflow node that can Collect.
|
||||
// Collect accepts streaming input and returns non-streaming output.
|
||||
// It does not accept and options
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be a new condition node that makes decisions
|
||||
// based on streaming input.
|
||||
type CollectableNode interface {
|
||||
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// CollectableNodeWOpt is a workflow node that can Collect.
|
||||
// Collect accepts streaming input and returns non-streaming output.
|
||||
// It accepts NodeOption.
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be a new batch node that accepts streaming input,
|
||||
// process them, and finally returns non-stream aggregation of results.
|
||||
type CollectableNodeWOpt interface {
|
||||
Collect(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
|
||||
map[string]any, error)
|
||||
}
|
||||
|
||||
// TransformableNode is a workflow node that can Transform.
|
||||
// Transform accepts streaming input and returns streaming output.
|
||||
// It does not accept and options
|
||||
// e.g.
|
||||
// NodeTypeVariableAggregator implements TransformableNode.
|
||||
type TransformableNode interface {
|
||||
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any]) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// TransformableNodeWOpt is a workflow node that can Transform.
|
||||
// Transform accepts streaming input and returns streaming output.
|
||||
// It accepts NodeOption.
|
||||
// Currently NO Node implement this.
|
||||
// A potential example would be an audio processing node that
|
||||
// transforms input audio clips, but within the node is a graph
|
||||
// composed by Eino, and the audio processing node needs to carry
|
||||
// options for this inner graph.
|
||||
type TransformableNodeWOpt interface {
|
||||
Transform(ctx context.Context, in *einoschema.StreamReader[map[string]any], opts ...NodeOption) (
|
||||
*einoschema.StreamReader[map[string]any], error)
|
||||
}
|
||||
|
||||
// CallbackInputConverted converts node input to a form better suited for UI.
|
||||
// The converted input will be displayed on canvas when test run,
|
||||
// and will be returned when querying the node's input through OpenAPI.
|
||||
type CallbackInputConverted interface {
|
||||
ToCallbackInput(ctx context.Context, in map[string]any) (map[string]any, error)
|
||||
}
|
||||
|
||||
// CallbackOutputConverted converts node input to a form better suited for UI.
|
||||
// The converted output will be displayed on canvas when test run,
|
||||
// and will be returned when querying the node's output through OpenAPI.
|
||||
type CallbackOutputConverted interface {
|
||||
ToCallbackOutput(ctx context.Context, out map[string]any) (*StructuredCallbackOutput, error)
|
||||
}
|
||||
|
||||
type Initializer interface {
|
||||
Init(ctx context.Context) (context.Context, error)
|
||||
}
|
||||
|
||||
type AdaptOptions struct {
|
||||
Canvas *vo.Canvas
|
||||
}
|
||||
|
||||
type AdaptOption func(*AdaptOptions)
|
||||
|
||||
func WithCanvas(canvas *vo.Canvas) AdaptOption {
|
||||
return func(opts *AdaptOptions) {
|
||||
opts.Canvas = canvas
|
||||
}
|
||||
}
|
||||
|
||||
func GetAdaptOptions(opts ...AdaptOption) *AdaptOptions {
|
||||
options := &AdaptOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
// NodeAdaptor provides conversion from frontend Node to backend NodeSchema.
|
||||
type NodeAdaptor interface {
|
||||
Adapt(ctx context.Context, n *vo.Node, opts ...AdaptOption) (
|
||||
*schema.NodeSchema, error)
|
||||
}
|
||||
|
||||
// BranchAdaptor provides validation and conversion from frontend port to backend port.
|
||||
type BranchAdaptor interface {
|
||||
ExpectPorts(ctx context.Context, n *vo.Node) []string
|
||||
}
|
||||
|
||||
var (
|
||||
nodeAdaptors = map[entity.NodeType]func() NodeAdaptor{}
|
||||
branchAdaptors = map[entity.NodeType]func() BranchAdaptor{}
|
||||
)
|
||||
|
||||
func RegisterNodeAdaptor(et entity.NodeType, f func() NodeAdaptor) {
|
||||
nodeAdaptors[et] = f
|
||||
}
|
||||
|
||||
func GetNodeAdaptor(et entity.NodeType) (NodeAdaptor, bool) {
|
||||
na, ok := nodeAdaptors[et]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("node type %s not registered", et))
|
||||
}
|
||||
return na(), ok
|
||||
}
|
||||
|
||||
func RegisterBranchAdaptor(et entity.NodeType, f func() BranchAdaptor) {
|
||||
branchAdaptors[et] = f
|
||||
}
|
||||
|
||||
func GetBranchAdaptor(et entity.NodeType) (BranchAdaptor, bool) {
|
||||
na, ok := branchAdaptors[et]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return na(), ok
|
||||
}
|
||||
|
||||
type StreamGenerator interface {
|
||||
FieldStreamType(path compose.FieldPath, ns *schema.NodeSchema,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error)
|
||||
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type NodeOptions struct {
|
||||
Nested *NestedWorkflowOptions
|
||||
}
|
||||
|
||||
type NestedWorkflowOptions struct {
|
||||
optsForNested []compose.Option
|
||||
toResumeIndexes map[int]compose.StateModifier
|
||||
optsForIndexed map[int][]compose.Option
|
||||
}
|
||||
|
||||
type NodeOption struct {
|
||||
apply func(opts *NodeOptions)
|
||||
|
||||
implSpecificOptFn any
|
||||
}
|
||||
|
||||
type NestedWorkflowOption func(*NestedWorkflowOptions)
|
||||
|
||||
func WithOptsForNested(opts ...compose.Option) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
options.Nested.optsForNested = append(options.Nested.optsForNested, opts...)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetOptsForNested() []compose.Option {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.optsForNested
|
||||
}
|
||||
|
||||
func WithResumeIndex(i int, m compose.StateModifier) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
if options.Nested.toResumeIndexes == nil {
|
||||
options.Nested.toResumeIndexes = map[int]compose.StateModifier{}
|
||||
}
|
||||
|
||||
options.Nested.toResumeIndexes[i] = m
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetResumeIndexes() map[int]compose.StateModifier {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.toResumeIndexes
|
||||
}
|
||||
|
||||
func WithOptsForIndexed(index int, opts ...compose.Option) NodeOption {
|
||||
return NodeOption{
|
||||
apply: func(options *NodeOptions) {
|
||||
if options.Nested == nil {
|
||||
options.Nested = &NestedWorkflowOptions{}
|
||||
}
|
||||
if options.Nested.optsForIndexed == nil {
|
||||
options.Nested.optsForIndexed = map[int][]compose.Option{}
|
||||
}
|
||||
options.Nested.optsForIndexed[index] = opts
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *NodeOptions) GetOptsForIndexed(index int) []compose.Option {
|
||||
if c.Nested == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Nested.optsForIndexed[index]
|
||||
}
|
||||
|
||||
// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
|
||||
func WrapImplSpecificOptFn[T any](optFn func(*T)) NodeOption {
|
||||
return NodeOption{
|
||||
implSpecificOptFn: optFn,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values.
|
||||
func GetCommonOptions(base *NodeOptions, opts ...NodeOption) *NodeOptions {
|
||||
if base == nil {
|
||||
base = &NodeOptions{}
|
||||
}
|
||||
|
||||
for i := range opts {
|
||||
opt := opts[i]
|
||||
if opt.apply != nil {
|
||||
opt.apply(base)
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values.
|
||||
// e.g.
|
||||
//
|
||||
// myOption := &MyOption{
|
||||
// Field1: "default_value",
|
||||
// }
|
||||
//
|
||||
// myOption := model.GetImplSpecificOptions(myOption, opts...)
|
||||
func GetImplSpecificOptions[T any](base *T, opts ...NodeOption) *T {
|
||||
if base == nil {
|
||||
base = new(T)
|
||||
}
|
||||
|
||||
for i := range opts {
|
||||
opt := opts[i]
|
||||
if opt.implSpecificOptFn != nil {
|
||||
optFn, ok := opt.implSpecificOptFn.(func(*T))
|
||||
if ok {
|
||||
optFn(base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
type NestedWorkflowState struct {
|
||||
Index2Done map[int]bool `json:"index_2_done,omitempty"`
|
||||
Index2InterruptInfo map[int]*compose.InterruptInfo `json:"index_2_interrupt_info,omitempty"`
|
||||
FullOutput map[string]any `json:"full_output,omitempty"`
|
||||
IntermediateVars map[string]any `json:"intermediate_vars,omitempty"`
|
||||
}
|
||||
|
||||
func (c *NestedWorkflowState) String() string {
|
||||
s, _ := sonic.MarshalIndent(c, "", " ")
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type NestedWorkflowAware interface {
|
||||
SaveNestedWorkflowState(key vo.NodeKey, state *NestedWorkflowState) error
|
||||
GetNestedWorkflowState(key vo.NodeKey) (*NestedWorkflowState, bool, error)
|
||||
InterruptEventStore
|
||||
}
|
||||
|
|
@ -18,16 +18,21 @@ package plugin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
|
|
@ -35,29 +40,76 @@ type Config struct {
|
|||
PluginID int64
|
||||
ToolID int64
|
||||
PluginVersion string
|
||||
}
|
||||
|
||||
PluginService plugin.Service
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypePlugin,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
apiParams := slices.ToMap(inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
|
||||
return e.Name, e
|
||||
})
|
||||
|
||||
ps, ok := apiParams["pluginID"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin id param is not found")
|
||||
}
|
||||
|
||||
pID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
|
||||
|
||||
c.PluginID = pID
|
||||
|
||||
ps, ok = apiParams["apiID"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin id param is not found")
|
||||
}
|
||||
|
||||
tID, err := strconv.ParseInt(ps.Input.Value.Content.(string), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.ToolID = tID
|
||||
|
||||
ps, ok = apiParams["pluginVersion"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin version param is not found")
|
||||
}
|
||||
version := ps.Input.Value.Content.(string)
|
||||
|
||||
c.PluginVersion = version
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &Plugin{
|
||||
pluginID: c.PluginID,
|
||||
toolID: c.ToolID,
|
||||
pluginVersion: c.PluginVersion,
|
||||
pluginService: plugin.GetPluginService(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Plugin struct {
|
||||
config *Config
|
||||
}
|
||||
pluginID int64
|
||||
toolID int64
|
||||
pluginVersion string
|
||||
|
||||
func NewPlugin(_ context.Context, cfg *Config) (*Plugin, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
if cfg.PluginID == 0 {
|
||||
return nil, errors.New("plugin id is required")
|
||||
}
|
||||
if cfg.ToolID == 0 {
|
||||
return nil, errors.New("tool id is required")
|
||||
}
|
||||
if cfg.PluginService == nil {
|
||||
return nil, errors.New("tool service is required")
|
||||
}
|
||||
|
||||
return &Plugin{config: cfg}, nil
|
||||
pluginService plugin.Service
|
||||
}
|
||||
|
||||
func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map[string]any, err error) {
|
||||
|
|
@ -65,10 +117,10 @@ func (p *Plugin) Invoke(ctx context.Context, parameters map[string]any) (ret map
|
|||
if ctxExeCfg := execute.GetExeCtx(ctx); ctxExeCfg != nil {
|
||||
exeCfg = ctxExeCfg.ExeCfg
|
||||
}
|
||||
result, err := p.config.PluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
|
||||
PluginID: p.config.PluginID,
|
||||
PluginVersion: ptr.Of(p.config.PluginVersion),
|
||||
}, p.config.ToolID, exeCfg)
|
||||
result, err := p.pluginService.ExecutePlugin(ctx, parameters, &vo.PluginEntity{
|
||||
PluginID: p.pluginID,
|
||||
PluginVersion: ptr.Of(p.pluginVersion),
|
||||
}, p.toolID, exeCfg)
|
||||
if err != nil {
|
||||
if extra, ok := compose.IsInterruptRerunError(err); ok {
|
||||
// TODO: temporarily replace interrupt with real error, because frontend cannot handle interrupt for now
|
||||
|
|
|
|||
|
|
@ -29,9 +29,12 @@ import (
|
|||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
|
|
@ -39,8 +42,21 @@ import (
|
|||
)
|
||||
|
||||
type QuestionAnswer struct {
|
||||
config *Config
|
||||
model model.BaseChatModel
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
|
||||
questionTpl string
|
||||
answerType AnswerType
|
||||
|
||||
choiceType ChoiceType
|
||||
fixedChoices []string
|
||||
|
||||
needExtractFromAnswer bool
|
||||
additionalSystemPromptTpl string
|
||||
maxAnswerCount int
|
||||
|
||||
nodeKey vo.NodeKey
|
||||
outputFields map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
|
@ -51,15 +67,249 @@ type Config struct {
|
|||
FixedChoices []string
|
||||
|
||||
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
|
||||
Model model.BaseChatModel
|
||||
LLMParams *crossmodel.LLMParams
|
||||
|
||||
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
|
||||
ExtractFromAnswer bool
|
||||
AdditionalSystemPromptTpl string
|
||||
MaxAnswerCount int
|
||||
OutputFields map[string]*vo.TypeInfo
|
||||
}
|
||||
|
||||
NodeKey vo.NodeKey
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeQuestionAnswer,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
qaConf := n.Data.Inputs.QA
|
||||
if qaConf == nil {
|
||||
return nil, fmt.Errorf("qa config is nil")
|
||||
}
|
||||
c.QuestionTpl = qaConf.Question
|
||||
|
||||
var llmParams *crossmodel.LLMParams
|
||||
if n.Data.Inputs.LLMParam != nil {
|
||||
llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var qaLLMParams vo.SimpleLLMParam
|
||||
err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
llmParams, err = convertLLMParams(qaLLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.LLMParams = llmParams
|
||||
}
|
||||
|
||||
answerType, err := convertAnswerType(qaConf.AnswerType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.AnswerType = answerType
|
||||
|
||||
var choiceType ChoiceType
|
||||
if len(qaConf.OptionType) > 0 {
|
||||
choiceType, err = convertChoiceType(qaConf.OptionType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ChoiceType = choiceType
|
||||
}
|
||||
|
||||
if answerType == AnswerByChoices {
|
||||
switch choiceType {
|
||||
case FixedChoices:
|
||||
var options []string
|
||||
for _, option := range qaConf.Options {
|
||||
options = append(options, option.Name)
|
||||
}
|
||||
c.FixedChoices = options
|
||||
case DynamicChoices:
|
||||
inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(inputSources...)
|
||||
|
||||
inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.SetInputType(DynamicChoicesKey, inputTypes)
|
||||
default:
|
||||
return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
|
||||
}
|
||||
} else if answerType == AnswerDirectly {
|
||||
c.ExtractFromAnswer = qaConf.ExtractOutput
|
||||
if qaConf.ExtractOutput {
|
||||
if llmParams == nil {
|
||||
return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
|
||||
}
|
||||
c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
|
||||
c.MaxAnswerCount = qaConf.Limit
|
||||
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
|
||||
p := &crossmodel.LLMParams{}
|
||||
p.ModelName = params.ModelName
|
||||
p.ModelType = params.ModelType
|
||||
p.Temperature = ¶ms.Temperature
|
||||
p.MaxTokens = params.MaxTokens
|
||||
p.TopP = ¶ms.TopP
|
||||
p.ResponseFormat = params.ResponseFormat
|
||||
p.SystemPrompt = params.SystemPrompt
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
|
||||
switch t {
|
||||
case vo.QAAnswerTypeOption:
|
||||
return AnswerByChoices, nil
|
||||
case vo.QAAnswerTypeText:
|
||||
return AnswerDirectly, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid QAAnswerType: %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
|
||||
switch t {
|
||||
case vo.QAOptionTypeStatic:
|
||||
return FixedChoices, nil
|
||||
case vo.QAOptionTypeDynamic:
|
||||
return DynamicChoices, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid QAOptionType: %s", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if c.AnswerType == AnswerDirectly {
|
||||
if c.ExtractFromAnswer {
|
||||
if c.LLMParams == nil {
|
||||
return nil, errors.New("model is required when extract from answer")
|
||||
}
|
||||
if len(ns.OutputTypes) == 0 {
|
||||
return nil, errors.New("output fields is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else if c.AnswerType == AnswerByChoices {
|
||||
if c.ChoiceType == FixedChoices {
|
||||
if len(c.FixedChoices) == 0 {
|
||||
return nil, errors.New("fixed choices is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
|
||||
}
|
||||
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for question answer")
|
||||
}
|
||||
|
||||
var (
|
||||
m model.BaseChatModel
|
||||
err error
|
||||
)
|
||||
if c.LLMParams != nil {
|
||||
m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &QuestionAnswer{
|
||||
model: m,
|
||||
nodeMeta: *nodeMeta,
|
||||
questionTpl: c.QuestionTpl,
|
||||
answerType: c.AnswerType,
|
||||
choiceType: c.ChoiceType,
|
||||
fixedChoices: c.FixedChoices,
|
||||
needExtractFromAnswer: c.ExtractFromAnswer,
|
||||
additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
|
||||
maxAnswerCount: c.MaxAnswerCount,
|
||||
nodeKey: ns.Key,
|
||||
outputFields: ns.OutputTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
if c.AnswerType != AnswerByChoices {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
optionID, ok := nodeOutput[OptionIDKey]
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
|
||||
}
|
||||
|
||||
if c.ChoiceType == DynamicChoices {
|
||||
if optionID.(string) == "other" {
|
||||
return -1, true, nil
|
||||
} else {
|
||||
return 0, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if optionID.(string) == "other" {
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
optionIDInt, ok := AlphabetToInt(optionID.(string))
|
||||
if !ok {
|
||||
return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID)
|
||||
}
|
||||
|
||||
return optionIDInt, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
|
||||
if n.Data.Inputs.QA.AnswerType != vo.QAAnswerTypeOption {
|
||||
return expects
|
||||
}
|
||||
|
||||
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
|
||||
for index := range n.Data.Inputs.QA.Options {
|
||||
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
|
||||
}
|
||||
|
||||
expects = append(expects, schema2.PortDefault)
|
||||
return expects
|
||||
}
|
||||
|
||||
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
|
||||
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
|
||||
expects = append(expects, schema2.PortDefault)
|
||||
}
|
||||
|
||||
return expects
|
||||
}
|
||||
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type AnswerType string
|
||||
|
|
@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
|
|||
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
|
||||
)
|
||||
|
||||
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if conf.AnswerType == AnswerDirectly {
|
||||
if conf.ExtractFromAnswer {
|
||||
if conf.Model == nil {
|
||||
return nil, errors.New("model is required when extract from answer")
|
||||
}
|
||||
if len(conf.OutputFields) == 0 {
|
||||
return nil, errors.New("output fields is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else if conf.AnswerType == AnswerByChoices {
|
||||
if conf.ChoiceType == FixedChoices {
|
||||
if len(conf.FixedChoices) == 0 {
|
||||
return nil, errors.New("fixed choices is required when extract from answer")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
|
||||
}
|
||||
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for question answer")
|
||||
}
|
||||
|
||||
return &QuestionAnswer{
|
||||
config: conf,
|
||||
nodeMeta: *nodeMeta,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Question struct {
|
||||
Question string
|
||||
Choices []string
|
||||
|
|
@ -182,10 +397,10 @@ type message struct {
|
|||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
|
||||
// Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
|
||||
// input: the references by input fields, as well as the dynamic choices array if needed.
|
||||
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
|
||||
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
var (
|
||||
questions []*Question
|
||||
answers []string
|
||||
|
|
@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
|||
out[QuestionsKey] = questions
|
||||
out[AnswersKey] = answers
|
||||
|
||||
switch q.config.AnswerType {
|
||||
switch q.answerType {
|
||||
case AnswerDirectly:
|
||||
if isFirst { // first execution, ask the question
|
||||
// format the question. Which is common to all use cases
|
||||
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
|
||||
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
|||
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
|
||||
}
|
||||
|
||||
if q.config.ExtractFromAnswer {
|
||||
if q.needExtractFromAnswer {
|
||||
return q.extractFromAnswer(ctx, in, questions, answers)
|
||||
}
|
||||
|
||||
|
|
@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
|||
}
|
||||
|
||||
// format the question. Which is common to all use cases
|
||||
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
|
||||
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var formattedChoices []string
|
||||
switch q.config.ChoiceType {
|
||||
switch q.choiceType {
|
||||
case FixedChoices:
|
||||
for _, choice := range q.config.FixedChoices {
|
||||
for _, choice := range q.fixedChoices {
|
||||
formattedChoice, err := nodes.TemplateRender(choice, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
|
|||
formattedChoices = append(formattedChoices, c)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
|
||||
return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
|
||||
}
|
||||
|
||||
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
|
||||
return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
|
||||
fieldInfo := "FieldInfo"
|
||||
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
|
||||
s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
|||
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
|
||||
|
||||
var requiredFields []string
|
||||
for fName, tInfo := range q.config.OutputFields {
|
||||
for fName, tInfo := range q.outputFields {
|
||||
if tInfo.Required {
|
||||
requiredFields = append(requiredFields, fName)
|
||||
}
|
||||
}
|
||||
|
||||
var formattedAdditionalPrompt string
|
||||
if len(q.config.AdditionalSystemPromptTpl) > 0 {
|
||||
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
|
||||
if len(q.additionalSystemPromptTpl) > 0 {
|
||||
additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
|||
messages = append(messages, schema.UserMessage(answer))
|
||||
}
|
||||
|
||||
out, err := q.config.Model.Generate(ctx, messages)
|
||||
out, err := q.model.Generate(ctx, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
|||
if ok {
|
||||
nextQuestionStr, ok := nextQuestion.(string)
|
||||
if ok && len(nextQuestionStr) > 0 {
|
||||
if len(answers) >= q.config.MaxAnswerCount {
|
||||
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
|
||||
if len(answers) >= q.maxAnswerCount {
|
||||
return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
|
||||
}
|
||||
|
||||
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
|
||||
|
|
@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
|
|||
return nil, fmt.Errorf("field %s not found", fieldInfo)
|
||||
}
|
||||
|
||||
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
|
||||
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
|
|||
schema.UserMessage(answer),
|
||||
}
|
||||
|
||||
out, err := q.config.Model.Generate(ctx, messages)
|
||||
out, err := q.model.Generate(ctx, messages)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
|
@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
|
|||
|
||||
event := &entity.InterruptEvent{
|
||||
ID: eventID,
|
||||
NodeKey: q.config.NodeKey,
|
||||
NodeKey: q.nodeKey,
|
||||
NodeType: entity.NodeTypeQuestionAnswer,
|
||||
NodeTitle: q.nodeMeta.Name,
|
||||
NodeIcon: q.nodeMeta.IconURL,
|
||||
|
|
@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
|
|||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
|
||||
setter.AddQuestion(q.config.NodeKey, &Question{
|
||||
setter.AddQuestion(q.nodeKey, &Question{
|
||||
Question: newQuestion,
|
||||
Choices: choices,
|
||||
})
|
||||
|
|
@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func AlphabetToInt(str string) (int, bool) {
|
||||
func AlphabetToInt(str string) (int64, bool) {
|
||||
if len(str) != 1 {
|
||||
return 0, false
|
||||
}
|
||||
char := rune(str[0])
|
||||
char = unicode.ToUpper(char)
|
||||
if char >= 'A' && char <= 'Z' {
|
||||
return int(char - 'A'), true
|
||||
return int64(char - 'A'), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
|
@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
|||
for i := 0; i < len(oldQuestions); i++ {
|
||||
oldQuestion := oldQuestions[i]
|
||||
oldAnswer := oldAnswers[i]
|
||||
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
|
||||
contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
|
||||
questionMsg := &message{
|
||||
Type: "question",
|
||||
ContentType: contentType,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
|
||||
}
|
||||
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
if q.answerType == AnswerByChoices {
|
||||
questionMsg.Content = optionContent{
|
||||
Options: conv(oldQuestion.Choices),
|
||||
Question: oldQuestion.Question,
|
||||
|
|
@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
|||
Type: "answer",
|
||||
ContentType: contentType,
|
||||
Content: oldAnswer,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
|
||||
}
|
||||
|
||||
history = append(history, questionMsg, answerMsg)
|
||||
}
|
||||
|
||||
if newQuestion != nil {
|
||||
if q.config.AnswerType == AnswerByChoices {
|
||||
if q.answerType == AnswerByChoices {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "option",
|
||||
|
|
@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
|
|||
Options: conv(choices),
|
||||
Question: *newQuestion,
|
||||
},
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
} else {
|
||||
history = append(history, &message{
|
||||
Type: "question",
|
||||
ContentType: "text",
|
||||
Content: *newQuestion,
|
||||
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
|
||||
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,8 +27,10 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
|
|
@ -37,19 +39,27 @@ import (
|
|||
)
|
||||
|
||||
type Config struct {
|
||||
OutputTypes map[string]*vo.TypeInfo
|
||||
NodeKey vo.NodeKey
|
||||
OutputSchema string
|
||||
}
|
||||
|
||||
type InputReceiver struct {
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
interruptData string
|
||||
nodeKey vo.NodeKey
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
c.OutputSchema = n.Data.Inputs.OutputSchema
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeInputReceiver,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeInputReceiver)
|
||||
if nodeMeta == nil {
|
||||
return nil, errors.New("node meta not found for input receiver")
|
||||
|
|
@ -57,7 +67,7 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
|||
|
||||
interruptData := map[string]string{
|
||||
"content_type": "form_schema",
|
||||
"content": cfg.OutputSchema,
|
||||
"content": c.OutputSchema,
|
||||
}
|
||||
|
||||
interruptDataStr, err := sonic.ConfigStd.MarshalToString(interruptData) // keep the order of the keys
|
||||
|
|
@ -66,13 +76,24 @@ func New(_ context.Context, cfg *Config) (*InputReceiver, error) {
|
|||
}
|
||||
|
||||
return &InputReceiver{
|
||||
outputTypes: cfg.OutputTypes,
|
||||
outputTypes: ns.OutputTypes, // so the node can refer to its output types during execution
|
||||
nodeMeta: *nodeMeta,
|
||||
nodeKey: cfg.NodeKey,
|
||||
nodeKey: ns.Key,
|
||||
interruptData: interruptDataStr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) RequireCheckpoint() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type InputReceiver struct {
|
||||
outputTypes map[string]*vo.TypeInfo
|
||||
interruptData string
|
||||
nodeKey vo.NodeKey
|
||||
nodeMeta entity.NodeTypeMeta
|
||||
}
|
||||
|
||||
const (
|
||||
ReceivedDataKey = "$received_data"
|
||||
receiverWarningKey = "receiver_warning_%d_%s"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,190 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
)
|
||||
|
||||
type selectorCallbackField struct {
|
||||
Key string `json:"key"`
|
||||
Type vo.DataType `json:"type"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type selectorCondition struct {
|
||||
Left selectorCallbackField `json:"left"`
|
||||
Operator vo.OperatorType `json:"operator"`
|
||||
Right *selectorCallbackField `json:"right,omitempty"`
|
||||
}
|
||||
|
||||
type selectorBranch struct {
|
||||
Conditions []*selectorCondition `json:"conditions"`
|
||||
Logic vo.LogicType `json:"logic"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *Selector) ToCallbackInput(_ context.Context, in map[string]any) (map[string]any, error) {
|
||||
count := len(s.clauses)
|
||||
|
||||
output := make([]*selectorBranch, count)
|
||||
|
||||
for _, source := range s.ns.InputSources {
|
||||
targetPath := source.Path
|
||||
if len(targetPath) == 2 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: []*selectorCondition{
|
||||
{
|
||||
Operator: s.clauses[index].Single.ToCanvasOperatorType(),
|
||||
},
|
||||
},
|
||||
Logic: ClauseRelationAND.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[1] == LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
|
||||
}
|
||||
parentNode := s.ws.GetNode(parentNodeKey)
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[0].Left = selectorCallbackField{
|
||||
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[1] == RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[0].Right = &selectorCallbackField{
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
} else if len(targetPath) == 3 {
|
||||
indexStr := targetPath[0]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multi := s.clauses[index].Multi
|
||||
|
||||
branch := output[index]
|
||||
if branch == nil {
|
||||
output[index] = &selectorBranch{
|
||||
Conditions: make([]*selectorCondition, len(multi.Clauses)),
|
||||
Logic: multi.Relation.ToVOLogicType(),
|
||||
}
|
||||
}
|
||||
|
||||
clauseIndexStr := targetPath[1]
|
||||
clauseIndex, err := strconv.Atoi(clauseIndexStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clause := multi.Clauses[clauseIndex]
|
||||
|
||||
if output[index].Conditions[clauseIndex] == nil {
|
||||
output[index].Conditions[clauseIndex] = &selectorCondition{
|
||||
Operator: clause.ToCanvasOperatorType(),
|
||||
}
|
||||
}
|
||||
|
||||
if targetPath[2] == LeftKey {
|
||||
leftV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take left value of %s", targetPath)
|
||||
}
|
||||
if source.Source.Ref.VariableType != nil {
|
||||
if *source.Source.Ref.VariableType == vo.ParentIntermediate {
|
||||
parentNodeKey, ok := s.ws.Hierarchy[s.ns.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to find parent node key of %s", s.ns.Key)
|
||||
}
|
||||
parentNode := s.ws.GetNode(parentNodeKey)
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: parentNode.Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: "",
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output[index].Conditions[clauseIndex].Left = selectorCallbackField{
|
||||
Key: s.ws.GetNode(source.Source.Ref.FromNodeKey).Name + "." + strings.Join(source.Source.Ref.FromPath, "."),
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: leftV,
|
||||
}
|
||||
}
|
||||
} else if targetPath[2] == RightKey {
|
||||
rightV, ok := nodes.TakeMapValue(in, targetPath)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to take right value of %s", targetPath)
|
||||
}
|
||||
output[index].Conditions[clauseIndex].Right = &selectorCallbackField{
|
||||
Type: s.ns.InputTypes[targetPath[0]].Properties[targetPath[1]].Properties[targetPath[2]].Type,
|
||||
Value: rightV,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"branches": output}, nil
|
||||
}
|
||||
|
|
@ -180,3 +180,48 @@ func (o *Operator) ToCanvasOperatorType() vo.OperatorType {
|
|||
panic(fmt.Sprintf("unknown operator: %+v", o))
|
||||
}
|
||||
}
|
||||
|
||||
func ToSelectorOperator(o vo.OperatorType, leftType *vo.TypeInfo) (Operator, error) {
|
||||
switch o {
|
||||
case vo.Equal:
|
||||
return OperatorEqual, nil
|
||||
case vo.NotEqual:
|
||||
return OperatorNotEqual, nil
|
||||
case vo.LengthGreaterThan:
|
||||
return OperatorLengthGreater, nil
|
||||
case vo.LengthGreaterThanEqual:
|
||||
return OperatorLengthGreaterOrEqual, nil
|
||||
case vo.LengthLessThan:
|
||||
return OperatorLengthLesser, nil
|
||||
case vo.LengthLessThanEqual:
|
||||
return OperatorLengthLesserOrEqual, nil
|
||||
case vo.Contain:
|
||||
if leftType.Type == vo.DataTypeObject {
|
||||
return OperatorContainKey, nil
|
||||
}
|
||||
return OperatorContain, nil
|
||||
case vo.NotContain:
|
||||
if leftType.Type == vo.DataTypeObject {
|
||||
return OperatorNotContainKey, nil
|
||||
}
|
||||
return OperatorNotContain, nil
|
||||
case vo.Empty:
|
||||
return OperatorEmpty, nil
|
||||
case vo.NotEmpty:
|
||||
return OperatorNotEmpty, nil
|
||||
case vo.True:
|
||||
return OperatorIsTrue, nil
|
||||
case vo.False:
|
||||
return OperatorIsFalse, nil
|
||||
case vo.GreaterThan:
|
||||
return OperatorGreater, nil
|
||||
case vo.GreaterThanEqual:
|
||||
return OperatorGreaterOrEqual, nil
|
||||
case vo.LessThan:
|
||||
return OperatorLesser, nil
|
||||
case vo.LessThanEqual:
|
||||
return OperatorLesserOrEqual, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported operator type: %d", o)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,9 +17,16 @@
|
|||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type ClauseRelation string
|
||||
|
|
@ -29,10 +36,6 @@ const (
|
|||
ClauseRelationOR ClauseRelation = "or"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Clauses []*OneClauseSchema `json:"clauses"`
|
||||
}
|
||||
|
||||
type OneClauseSchema struct {
|
||||
Single *Operator `json:"single,omitempty"`
|
||||
Multi *MultiClauseSchema `json:"multi,omitempty"`
|
||||
|
|
@ -52,3 +55,140 @@ func (c ClauseRelation) ToVOLogicType() vo.LogicType {
|
|||
|
||||
panic(fmt.Sprintf("unknown clause relation: %s", c))
|
||||
}
|
||||
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
clauses := make([]*OneClauseSchema, 0)
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Name: n.Data.Meta.Title,
|
||||
Type: entity.NodeTypeSelector,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
for i, branchCond := range n.Data.Inputs.Branches {
|
||||
inputType := &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{},
|
||||
}
|
||||
|
||||
if len(branchCond.Condition.Conditions) == 1 { // single condition
|
||||
cond := branchCond.Condition.Conditions[0]
|
||||
|
||||
left := cond.Left
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("operator left is nil")
|
||||
}
|
||||
|
||||
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), LeftKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[LeftKey] = leftType
|
||||
|
||||
ns.AddInputSource(leftSources...)
|
||||
|
||||
op, err := ToSelectorOperator(cond.Operator, leftType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cond.Right != nil {
|
||||
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), RightKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[RightKey] = rightType
|
||||
ns.AddInputSource(rightSources...)
|
||||
}
|
||||
|
||||
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
|
||||
|
||||
clauses = append(clauses, &OneClauseSchema{
|
||||
Single: &op,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var relation ClauseRelation
|
||||
logic := branchCond.Condition.Logic
|
||||
if logic == vo.OR {
|
||||
relation = ClauseRelationOR
|
||||
} else if logic == vo.AND {
|
||||
relation = ClauseRelationAND
|
||||
}
|
||||
|
||||
var ops []*Operator
|
||||
for j, cond := range branchCond.Condition.Conditions {
|
||||
left := cond.Left
|
||||
if left == nil {
|
||||
return nil, fmt.Errorf("operator left is nil")
|
||||
}
|
||||
|
||||
leftType, err := convert.CanvasBlockInputToTypeInfo(left.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(left.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), LeftKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[fmt.Sprintf("%d", j)] = &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: map[string]*vo.TypeInfo{
|
||||
LeftKey: leftType,
|
||||
},
|
||||
}
|
||||
|
||||
ns.AddInputSource(leftSources...)
|
||||
|
||||
op, err := ToSelectorOperator(cond.Operator, leftType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ops = append(ops, &op)
|
||||
|
||||
if cond.Right != nil {
|
||||
rightType, err := convert.CanvasBlockInputToTypeInfo(cond.Right.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(cond.Right.Input, einoCompose.FieldPath{fmt.Sprintf("%d", i), fmt.Sprintf("%d", j), RightKey}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputType.Properties[fmt.Sprintf("%d", j)].Properties[RightKey] = rightType
|
||||
ns.AddInputSource(rightSources...)
|
||||
}
|
||||
}
|
||||
|
||||
ns.SetInputType(fmt.Sprintf("%d", i), inputType)
|
||||
|
||||
clauses = append(clauses, &OneClauseSchema{
|
||||
Multi: &MultiClauseSchema{
|
||||
Clauses: ops,
|
||||
Relation: relation,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.Clauses = clauses
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,23 +23,32 @@ import (
|
|||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Selector struct {
|
||||
config *Config
|
||||
clauses []*OneClauseSchema
|
||||
ns *schema.NodeSchema
|
||||
ws *schema.WorkflowSchema
|
||||
}
|
||||
|
||||
func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
type Config struct {
|
||||
Clauses []*OneClauseSchema `json:"clauses"`
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, opts ...schema.BuildOption) (any, error) {
|
||||
ws := schema.GetBuildOptions(opts...).WS
|
||||
if ws == nil {
|
||||
return nil, fmt.Errorf("workflow schema is required")
|
||||
}
|
||||
|
||||
if len(config.Clauses) == 0 {
|
||||
if len(c.Clauses) == 0 {
|
||||
return nil, fmt.Errorf("config clauses are empty")
|
||||
}
|
||||
|
||||
for _, clause := range config.Clauses {
|
||||
for _, clause := range c.Clauses {
|
||||
if clause.Single == nil && clause.Multi == nil {
|
||||
return nil, fmt.Errorf("single clause and multi clause are both nil")
|
||||
}
|
||||
|
|
@ -60,10 +69,42 @@ func NewSelector(_ context.Context, config *Config) (*Selector, error) {
|
|||
}
|
||||
|
||||
return &Selector{
|
||||
config: config,
|
||||
clauses: c.Clauses,
|
||||
ns: ns,
|
||||
ws: ws,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) BuildBranch(_ context.Context) (
|
||||
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
|
||||
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
|
||||
choice := nodeOutput[SelectKey].(int64)
|
||||
if choice < 0 || choice > int64(len(c.Clauses)+1) {
|
||||
return -1, false, fmt.Errorf("selector choice out of range: %d", choice)
|
||||
}
|
||||
|
||||
if choice == int64(len(c.Clauses)) { // default
|
||||
return -1, true, nil
|
||||
}
|
||||
|
||||
return choice, false, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *Config) ExpectPorts(_ context.Context, n *vo.Node) []string {
|
||||
expects := make([]string, len(n.Data.Inputs.Branches)+1)
|
||||
expects[0] = "false" // default branch
|
||||
if len(n.Data.Inputs.Branches) > 0 {
|
||||
expects[1] = "true" // first condition
|
||||
}
|
||||
|
||||
for i := 1; i < len(n.Data.Inputs.Branches); i++ { // other conditions
|
||||
expects[i+1] = "true_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
return expects
|
||||
}
|
||||
|
||||
type Operants struct {
|
||||
Left any
|
||||
Right any
|
||||
|
|
@ -76,14 +117,14 @@ const (
|
|||
SelectKey = "selected"
|
||||
)
|
||||
|
||||
func (s *Selector) Select(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.SelectorInputConverter(input)
|
||||
func (s *Selector) Invoke(_ context.Context, input map[string]any) (out map[string]any, err error) {
|
||||
in, err := s.selectorInputConverter(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
predicates := make([]Predicate, 0, len(s.config.Clauses))
|
||||
for i, oneConf := range s.config.Clauses {
|
||||
predicates := make([]Predicate, 0, len(s.clauses))
|
||||
for i, oneConf := range s.clauses {
|
||||
if oneConf.Single != nil {
|
||||
left := in[i].Left
|
||||
right := in[i].Right
|
||||
|
|
@ -132,23 +173,15 @@ func (s *Selector) Select(_ context.Context, input map[string]any) (out map[stri
|
|||
}
|
||||
|
||||
if isTrue {
|
||||
return map[string]any{SelectKey: i}, nil
|
||||
return map[string]any{SelectKey: int64(i)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{SelectKey: len(in)}, nil // default choice
|
||||
return map[string]any{SelectKey: int64(len(in))}, nil // default choice
|
||||
}
|
||||
|
||||
func (s *Selector) GetType() string {
|
||||
return "Selector"
|
||||
}
|
||||
|
||||
func (s *Selector) ConditionCount() int {
|
||||
return len(s.config.Clauses)
|
||||
}
|
||||
|
||||
func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.config.Clauses
|
||||
func (s *Selector) selectorInputConverter(in map[string]any) (out []Operants, err error) {
|
||||
conf := s.clauses
|
||||
|
||||
for i, oneConf := range conf {
|
||||
if oneConf.Single != nil {
|
||||
|
|
@ -187,8 +220,8 @@ func (s *Selector) SelectorInputConverter(in map[string]any) (out []Operants, er
|
|||
}
|
||||
|
||||
func (s *Selector) ToCallbackOutput(_ context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
count := len(s.config.Clauses)
|
||||
out := output[SelectKey].(int)
|
||||
count := int64(len(s.clauses))
|
||||
out := output[SelectKey].(int64)
|
||||
if out == count {
|
||||
cOutput := map[string]any{"result": "pass to else branch"}
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
|
|
|
|||
|
|
@ -22,57 +22,27 @@ import (
|
|||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
var KeyIsFinished = "\x1FKey is finished\x1F"
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
Streaming Mode = "streaming"
|
||||
NonStreaming Mode = "non-streaming"
|
||||
)
|
||||
|
||||
type FieldStreamType string
|
||||
|
||||
const (
|
||||
FieldIsStream FieldStreamType = "yes" // absolutely a stream
|
||||
FieldNotStream FieldStreamType = "no" // absolutely not a stream
|
||||
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
|
||||
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
|
||||
)
|
||||
|
||||
// SourceInfo contains stream type for a input field source of a node.
|
||||
type SourceInfo struct {
|
||||
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
|
||||
IsIntermediate bool
|
||||
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
|
||||
FieldType FieldStreamType
|
||||
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
|
||||
FromNodeKey vo.NodeKey
|
||||
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
|
||||
FromPath compose.FieldPath
|
||||
TypeInfo *vo.TypeInfo
|
||||
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
|
||||
SubSources map[string]*SourceInfo
|
||||
}
|
||||
|
||||
type DynamicStreamContainer interface {
|
||||
SaveDynamicChoice(nodeKey vo.NodeKey, groupToChoice map[string]int)
|
||||
GetDynamicChoice(nodeKey vo.NodeKey) map[string]int
|
||||
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (FieldStreamType, error)
|
||||
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]FieldStreamType, error)
|
||||
GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema.FieldStreamType, error)
|
||||
GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema.FieldStreamType, error)
|
||||
}
|
||||
|
||||
// ResolveStreamSources resolves incoming field sources for a node, deciding their stream type.
|
||||
func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (map[string]*SourceInfo, error) {
|
||||
resolved := make(map[string]*SourceInfo, len(sources))
|
||||
func ResolveStreamSources(ctx context.Context, sources map[string]*schema.SourceInfo) (map[string]*schema.SourceInfo, error) {
|
||||
resolved := make(map[string]*schema.SourceInfo, len(sources))
|
||||
|
||||
nodeKey2Skipped := make(map[vo.NodeKey]bool)
|
||||
|
||||
var resolver func(path string, sInfo *SourceInfo) (*SourceInfo, error)
|
||||
resolver = func(path string, sInfo *SourceInfo) (*SourceInfo, error) {
|
||||
resolvedNode := &SourceInfo{
|
||||
var resolver func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error)
|
||||
resolver = func(path string, sInfo *schema.SourceInfo) (*schema.SourceInfo, error) {
|
||||
resolvedNode := &schema.SourceInfo{
|
||||
IsIntermediate: sInfo.IsIntermediate,
|
||||
FieldType: sInfo.FieldType,
|
||||
FromNodeKey: sInfo.FromNodeKey,
|
||||
|
|
@ -81,7 +51,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
|||
}
|
||||
|
||||
if len(sInfo.SubSources) > 0 {
|
||||
resolvedNode.SubSources = make(map[string]*SourceInfo, len(sInfo.SubSources))
|
||||
resolvedNode.SubSources = make(map[string]*schema.SourceInfo, len(sInfo.SubSources))
|
||||
|
||||
for k, subInfo := range sInfo.SubSources {
|
||||
resolvedSub, err := resolver(k, subInfo)
|
||||
|
|
@ -109,16 +79,16 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
|||
}
|
||||
|
||||
if skipped {
|
||||
resolvedNode.FieldType = FieldSkipped
|
||||
resolvedNode.FieldType = schema.FieldSkipped
|
||||
return resolvedNode, nil
|
||||
}
|
||||
|
||||
if sInfo.FieldType == FieldMaybeStream {
|
||||
if sInfo.FieldType == schema.FieldMaybeStream {
|
||||
if len(sInfo.SubSources) > 0 {
|
||||
panic("a maybe stream field should not have sub sources")
|
||||
}
|
||||
|
||||
var streamType FieldStreamType
|
||||
var streamType schema.FieldStreamType
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, state DynamicStreamContainer) error {
|
||||
var e error
|
||||
streamType, e = state.GetDynamicStreamType(sInfo.FromNodeKey, sInfo.FromPath[0])
|
||||
|
|
@ -128,7 +98,7 @@ func ResolveStreamSources(ctx context.Context, sources map[string]*SourceInfo) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &SourceInfo{
|
||||
return &schema.SourceInfo{
|
||||
IsIntermediate: sInfo.IsIntermediate,
|
||||
FieldType: streamType,
|
||||
FromNodeKey: sInfo.FromNodeKey,
|
||||
|
|
@ -156,30 +126,12 @@ type NodeExecuteStatusAware interface {
|
|||
NodeExecuted(key vo.NodeKey) bool
|
||||
}
|
||||
|
||||
func (s *SourceInfo) Skipped() bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FieldType == FieldSkipped
|
||||
func IsStreamingField(s *schema.NodeSchema, path compose.FieldPath,
|
||||
sc *schema.WorkflowSchema) (schema.FieldStreamType, error) {
|
||||
sg, ok := s.Configs.(StreamGenerator)
|
||||
if !ok {
|
||||
return schema.FieldNotStream, nil
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if !sub.Skipped() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FromNodeKey == nodeKey
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if sub.FromNode(nodeKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return sg.FieldStreamType(path, s, sc)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ package subworkflow
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
|
|
@ -29,35 +28,56 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Runner compose.Runnable[map[string]any, map[string]any]
|
||||
WorkflowID int64
|
||||
WorkflowVersion string
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
innerWF := ns.SubWorkflowSchema
|
||||
|
||||
if !innerWF.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
innerExit := innerWF.GetNode(entity.ExitNodeKey)
|
||||
|
||||
if innerExit.Configs.(*exit.Config).TerminatePlan == vo.ReturnVariables {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if !innerExit.StreamConfigs.RequireStreamingInput {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
if len(path) > 1 || path[0] != "output" {
|
||||
return schema2.FieldNotStream, fmt.Errorf(
|
||||
"streaming answering sub-workflow node can only have out field 'output'")
|
||||
}
|
||||
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
type SubWorkflow struct {
|
||||
cfg *Config
|
||||
Runner compose.Runnable[map[string]any, map[string]any]
|
||||
}
|
||||
|
||||
func NewSubWorkflow(_ context.Context, cfg *Config) (*SubWorkflow, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is nil")
|
||||
}
|
||||
|
||||
if cfg.Runner == nil {
|
||||
return nil, errors.New("runnable is nil")
|
||||
}
|
||||
|
||||
return &SubWorkflow{cfg: cfg}, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (map[string]any, error) {
|
||||
func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (map[string]any, error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Invoke(ctx, in, nestedOpts...)
|
||||
out, err := s.Runner.Invoke(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
|
|
@ -82,13 +102,13 @@ func (s *SubWorkflow) Invoke(ctx context.Context, in map[string]any, opts ...nod
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NestedWorkflowOption) (*schema.StreamReader[map[string]any], error) {
|
||||
func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (*schema.StreamReader[map[string]any], error) {
|
||||
nestedOpts, nodeKey, err := prepareOptions(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err := s.cfg.Runner.Stream(ctx, in, nestedOpts...)
|
||||
out, err := s.Runner.Stream(ctx, in, nestedOpts...)
|
||||
if err != nil {
|
||||
interruptInfo, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
|
|
@ -114,11 +134,8 @@ func (s *SubWorkflow) Stream(ctx context.Context, in map[string]any, opts ...nod
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func prepareOptions(ctx context.Context, opts ...nodes.NestedWorkflowOption) ([]compose.Option, vo.NodeKey, error) {
|
||||
options := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
func prepareOptions(ctx context.Context, opts ...nodes.NodeOption) ([]compose.Option, vo.NodeKey, error) {
|
||||
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
|
||||
|
||||
nestedOpts := options.GetOptsForNested()
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/bytedance/sonic/ast"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
|
@ -156,7 +157,7 @@ func removeSlice(s string) string {
|
|||
|
||||
type renderOptions struct {
|
||||
type2CustomRenderer map[reflect.Type]func(any) (string, error)
|
||||
reservedKey map[string]struct{}
|
||||
reservedKey map[string]struct{} // a reservedKey will always render, won't check for node skipping
|
||||
nilRenderer func() (string, error)
|
||||
}
|
||||
|
||||
|
|
@ -300,7 +301,7 @@ func (tp TemplatePart) Render(m []byte, opts ...RenderOption) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped bool, invalid bool) {
|
||||
func (tp TemplatePart) Skipped(resolvedSources map[string]*schema.SourceInfo) (skipped bool, invalid bool) {
|
||||
if len(resolvedSources) == 0 { // no information available, maybe outside the scope of a workflow
|
||||
return false, false
|
||||
}
|
||||
|
|
@ -316,7 +317,7 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
|
|||
}
|
||||
|
||||
if !matchingSource.IsIntermediate {
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
for _, subPath := range tp.SubPathsBeforeSlice {
|
||||
|
|
@ -325,20 +326,20 @@ func (tp TemplatePart) Skipped(resolvedSources map[string]*SourceInfo) (skipped
|
|||
if matchingSource.IsIntermediate { // the user specified a non-existing source, just skip it
|
||||
return false, true
|
||||
}
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
matchingSource = subSource
|
||||
}
|
||||
|
||||
if !matchingSource.IsIntermediate {
|
||||
return matchingSource.FieldType == FieldSkipped, false
|
||||
return matchingSource.FieldType == schema.FieldSkipped, false
|
||||
}
|
||||
|
||||
var checkSourceSkipped func(sInfo *SourceInfo) bool
|
||||
checkSourceSkipped = func(sInfo *SourceInfo) bool {
|
||||
var checkSourceSkipped func(sInfo *schema.SourceInfo) bool
|
||||
checkSourceSkipped = func(sInfo *schema.SourceInfo) bool {
|
||||
if !sInfo.IsIntermediate {
|
||||
return sInfo.FieldType == FieldSkipped
|
||||
return sInfo.FieldType == schema.FieldSkipped
|
||||
}
|
||||
for _, subSource := range sInfo.SubSources {
|
||||
if !checkSourceSkipped(subSource) {
|
||||
|
|
@ -373,7 +374,7 @@ func (tp TemplatePart) TypeInfo(types map[string]*vo.TypeInfo) *vo.TypeInfo {
|
|||
return currentType
|
||||
}
|
||||
|
||||
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*SourceInfo, opts ...RenderOption) (string, error) {
|
||||
func Render(ctx context.Context, tpl string, input map[string]any, sources map[string]*schema.SourceInfo, opts ...RenderOption) (string, error) {
|
||||
mi, err := sonic.Marshal(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
|||
|
|
@ -22,7 +22,11 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
|
|
@ -38,38 +42,88 @@ type Config struct {
|
|||
Tpl string `json:"tpl"`
|
||||
ConcatChar string `json:"concatChar"`
|
||||
Separators []string `json:"separator"`
|
||||
FullSources map[string]*nodes.SourceInfo `json:"fullSources"`
|
||||
}
|
||||
|
||||
type TextProcessor struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewTextProcessor(_ context.Context, cfg *Config) (*TextProcessor, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config requried")
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeTextProcessor,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
if cfg.Type == ConcatText && len(cfg.Tpl) == 0 {
|
||||
|
||||
if n.Data.Inputs.Method == vo.Concat {
|
||||
c.Type = ConcatText
|
||||
params := n.Data.Inputs.ConcatParams
|
||||
for _, param := range params {
|
||||
if param.Name == "concatResult" {
|
||||
c.Tpl = param.Input.Value.Content.(string)
|
||||
} else if param.Name == "arrayItemConcatChar" {
|
||||
c.ConcatChar = param.Input.Value.Content.(string)
|
||||
}
|
||||
}
|
||||
} else if n.Data.Inputs.Method == vo.Split {
|
||||
c.Type = SplitText
|
||||
params := n.Data.Inputs.SplitParams
|
||||
separators := make([]string, 0, len(params))
|
||||
for _, param := range params {
|
||||
if param.Name == "delimiters" {
|
||||
delimiters := param.Input.Value.Content.([]any)
|
||||
for _, d := range delimiters {
|
||||
separators = append(separators, d.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Separators = separators
|
||||
|
||||
} else {
|
||||
return nil, fmt.Errorf("not supported method: %s", n.Data.Inputs.Method)
|
||||
}
|
||||
|
||||
if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (c *Config) Build(_ context.Context, ns *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
if c.Type == ConcatText && len(c.Tpl) == 0 {
|
||||
return nil, fmt.Errorf("config tpl requried")
|
||||
}
|
||||
|
||||
return &TextProcessor{
|
||||
config: cfg,
|
||||
typ: c.Type,
|
||||
tpl: c.Tpl,
|
||||
concatChar: c.ConcatChar,
|
||||
separators: c.Separators,
|
||||
fullSources: ns.FullSources,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type TextProcessor struct {
|
||||
typ Type
|
||||
tpl string
|
||||
concatChar string
|
||||
separators []string
|
||||
fullSources map[string]*schema.SourceInfo
|
||||
}
|
||||
|
||||
const OutputKey = "output"
|
||||
|
||||
func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
switch t.config.Type {
|
||||
switch t.typ {
|
||||
case ConcatText:
|
||||
arrayRenderer := func(i any) (string, error) {
|
||||
vs := i.([]any)
|
||||
return join(vs, t.config.ConcatChar)
|
||||
return join(vs, t.concatChar)
|
||||
}
|
||||
|
||||
result, err := nodes.Render(ctx, t.config.Tpl, input, t.config.FullSources,
|
||||
result, err := nodes.Render(ctx, t.tpl, input, t.fullSources,
|
||||
nodes.WithCustomRender(reflect.TypeOf([]any{}), arrayRenderer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -86,9 +140,9 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
|
|||
if !ok {
|
||||
return nil, fmt.Errorf("input string field must string type but got %T", valueString)
|
||||
}
|
||||
values := strings.Split(valueString, t.config.Separators[0])
|
||||
values := strings.Split(valueString, t.separators[0])
|
||||
// Iterate over each delimiter
|
||||
for _, sep := range t.config.Separators[1:] {
|
||||
for _, sep := range t.separators[1:] {
|
||||
var tempParts []string
|
||||
for _, part := range values {
|
||||
tempParts = append(tempParts, strings.Split(part, sep)...)
|
||||
|
|
@ -102,7 +156,7 @@ func (t *TextProcessor) Invoke(ctx context.Context, input map[string]any) (map[s
|
|||
|
||||
return map[string]any{OutputKey: anyValues}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("not support type %s", t.config.Type)
|
||||
return nil, fmt.Errorf("not support type %s", t.typ)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
||||
|
|
@ -30,10 +32,10 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
|||
Type: SplitText,
|
||||
Separators: []string{",", "|", "."},
|
||||
}
|
||||
p, err := NewTextProcessor(ctx, cfg)
|
||||
p, err := cfg.Build(ctx, &schema2.NodeSchema{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := p.Invoke(ctx, map[string]any{
|
||||
result, err := p.(*TextProcessor).Invoke(ctx, map[string]any{
|
||||
"String": "a,b|c.d,e|f|g",
|
||||
})
|
||||
|
||||
|
|
@ -60,9 +62,9 @@ func TestNewTextProcessorNodeGenerator(t *testing.T) {
|
|||
ConcatChar: `\t`,
|
||||
Tpl: "fx{{a}}=={{b.b1}}=={{b.b2[1]}}=={{c}}",
|
||||
}
|
||||
p, err := NewTextProcessor(context.Background(), cfg)
|
||||
p, err := cfg.Build(context.Background(), &schema2.NodeSchema{})
|
||||
|
||||
result, err := p.Invoke(ctx, in)
|
||||
result, err := p.(*TextProcessor).Invoke(ctx, in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result["output"], `fx1\t{"1":1}\t3==1\t2\t3==2=={"c1":"1"}`)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -32,8 +32,11 @@ import (
|
|||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
|
|
@ -48,24 +51,147 @@ const (
|
|||
type Config struct {
|
||||
MergeStrategy MergeStrategy
|
||||
GroupLen map[string]int
|
||||
FullSources map[string]*nodes.SourceInfo
|
||||
NodeKey vo.NodeKey
|
||||
InputSources []*vo.FieldInfo
|
||||
GroupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
config *Config
|
||||
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
|
||||
ns := &schema2.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAggregator,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
c.MergeStrategy = FirstNotNullValue
|
||||
inputs := n.Data.Inputs
|
||||
|
||||
groupToLen := make(map[string]int, len(inputs.VariableAggregator.MergeGroups))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
tInfo := &vo.TypeInfo{
|
||||
Type: vo.DataTypeObject,
|
||||
Properties: make(map[string]*vo.TypeInfo),
|
||||
}
|
||||
ns.SetInputType(group.Name, tInfo)
|
||||
for ii, v := range group.Variables {
|
||||
name := strconv.Itoa(ii)
|
||||
valueTypeInfo, err := convert.CanvasBlockInputToTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tInfo.Properties[name] = valueTypeInfo
|
||||
sources, err := convert.CanvasBlockInputToFieldInfo(v, compose.FieldPath{group.Name, name}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(sources...)
|
||||
}
|
||||
|
||||
length := len(group.Variables)
|
||||
groupToLen[group.Name] = length
|
||||
}
|
||||
|
||||
groupOrder := make([]string, 0, len(groupToLen))
|
||||
for i := range inputs.VariableAggregator.MergeGroups {
|
||||
group := inputs.VariableAggregator.MergeGroups[i]
|
||||
groupOrder = append(groupOrder, group.Name)
|
||||
}
|
||||
|
||||
c.GroupLen = groupToLen
|
||||
c.GroupOrder = groupOrder
|
||||
|
||||
if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewVariableAggregator(_ context.Context, cfg *Config) (*VariableAggregator, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("config is required")
|
||||
func (c *Config) Build(_ context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
|
||||
if c.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", c.MergeStrategy)
|
||||
}
|
||||
if cfg.MergeStrategy != FirstNotNullValue {
|
||||
return nil, fmt.Errorf("merge strategy not supported: %v", cfg.MergeStrategy)
|
||||
|
||||
return &VariableAggregator{
|
||||
groupLen: c.GroupLen,
|
||||
fullSources: ns.FullSources,
|
||||
nodeKey: ns.Key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
|
||||
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
|
||||
if !sc.RequireStreaming() {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
return &VariableAggregator{config: cfg}, nil
|
||||
|
||||
if len(path) == 2 { // asking about a specific index within a group
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if len(fInfo.Path) == len(path) {
|
||||
equal := true
|
||||
for i := range fInfo.Path {
|
||||
if fInfo.Path[i] != path[i] {
|
||||
equal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if equal {
|
||||
if fInfo.Source.Ref == nil || fInfo.Source.Ref.FromNodeKey == "" {
|
||||
return schema2.FieldNotStream, nil // variables or static values
|
||||
}
|
||||
fromNodeKey := fInfo.Source.Ref.FromNodeKey
|
||||
fromNode := sc.GetNode(fromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fromNodeKey)
|
||||
}
|
||||
return nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(path) == 1 { // asking about the entire group
|
||||
var streamCount, notStreamCount int
|
||||
for _, fInfo := range ns.InputSources {
|
||||
if fInfo.Path[0] == path[0] { // belong to the group
|
||||
if fInfo.Source.Ref != nil && len(fInfo.Source.Ref.FromNodeKey) > 0 {
|
||||
fromNode := sc.GetNode(fInfo.Source.Ref.FromNodeKey)
|
||||
if fromNode == nil {
|
||||
return schema2.FieldNotStream, fmt.Errorf("node %s not found", fInfo.Source.Ref.FromNodeKey)
|
||||
}
|
||||
subStreamType, err := nodes.IsStreamingField(fromNode, fInfo.Source.Ref.FromPath, sc)
|
||||
if err != nil {
|
||||
return schema2.FieldNotStream, err
|
||||
}
|
||||
|
||||
if subStreamType == schema2.FieldMaybeStream {
|
||||
return schema2.FieldMaybeStream, nil
|
||||
} else if subStreamType == schema2.FieldIsStream {
|
||||
streamCount++
|
||||
} else {
|
||||
notStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if streamCount > 0 && notStreamCount == 0 {
|
||||
return schema2.FieldIsStream, nil
|
||||
}
|
||||
|
||||
if streamCount == 0 && notStreamCount > 0 {
|
||||
return schema2.FieldNotStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldMaybeStream, nil
|
||||
}
|
||||
|
||||
return schema2.FieldNotStream, fmt.Errorf("variable aggregator output path max len = 2, actual: %v", path)
|
||||
}
|
||||
|
||||
type VariableAggregator struct {
|
||||
groupLen map[string]int
|
||||
fullSources map[string]*schema2.SourceInfo
|
||||
nodeKey vo.NodeKey
|
||||
groupOrder []string // the order the groups are declared in frontend canvas
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (_ map[string]any, err error) {
|
||||
|
|
@ -76,7 +202,7 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
|||
|
||||
result := make(map[string]any)
|
||||
groupToChoice := make(map[string]int)
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
for i := 0; i < length; i++ {
|
||||
if value, ok := in[group][i]; ok {
|
||||
if value != nil {
|
||||
|
|
@ -93,14 +219,14 @@ func (v *VariableAggregator) Invoke(ctx context.Context, input map[string]any) (
|
|||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]nodes.FieldStreamType{}) // none of the choices are stream
|
||||
ctxcache.Store(ctx, groupChoiceTypeCacheKey, map[string]schema2.FieldStreamType{}) // none of the choices are stream
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
|
|
@ -125,7 +251,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
_ *schema.StreamReader[map[string]any], err error) {
|
||||
inStream := streamInputConverter(input)
|
||||
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolvesSources from ctx cache.")
|
||||
}
|
||||
|
|
@ -138,18 +264,18 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
groupChoiceToStreamType := map[string]nodes.FieldStreamType{}
|
||||
groupChoiceToStreamType := map[string]schema2.FieldStreamType{}
|
||||
for group, choice := range groupToChoice {
|
||||
if choice != -1 {
|
||||
item := groupToItems[group][choice]
|
||||
if _, ok := item.(stream); ok {
|
||||
groupChoiceToStreamType[group] = nodes.FieldIsStream
|
||||
groupChoiceToStreamType[group] = schema2.FieldIsStream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groupChoices := make([]any, 0, len(v.config.GroupOrder))
|
||||
for _, group := range v.config.GroupOrder {
|
||||
groupChoices := make([]any, 0, len(v.groupOrder))
|
||||
for _, group := range v.groupOrder {
|
||||
choice := groupToChoice[group]
|
||||
if choice == -1 {
|
||||
groupChoices = append(groupChoices, nil)
|
||||
|
|
@ -174,16 +300,16 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
// - if an element is not stream, actually receive from the stream to check if it's non-nil
|
||||
|
||||
groupToCurrentIndex := make(map[string]int) // the currently known smallest index that is non-nil for each group
|
||||
for group, length := range v.config.GroupLen {
|
||||
for group, length := range v.groupLen {
|
||||
groupToItems[group] = make([]any, length)
|
||||
groupToCurrentIndex[group] = math.MaxInt
|
||||
for i := 0; i < length; i++ {
|
||||
fType := resolvedSources[group].SubSources[strconv.Itoa(i)].FieldType
|
||||
if fType == nodes.FieldSkipped {
|
||||
if fType == schema2.FieldSkipped {
|
||||
groupToItems[group][i] = skipped{}
|
||||
continue
|
||||
}
|
||||
if fType == nodes.FieldIsStream {
|
||||
if fType == schema2.FieldIsStream {
|
||||
groupToItems[group][i] = stream{}
|
||||
if ci, _ := groupToCurrentIndex[group]; i < ci {
|
||||
groupToCurrentIndex[group] = i
|
||||
|
|
@ -211,7 +337,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
}
|
||||
|
||||
allDone := func() bool {
|
||||
for group := range v.config.GroupLen {
|
||||
for group := range v.groupLen {
|
||||
_, ok := groupToChoice[group]
|
||||
if !ok {
|
||||
return false
|
||||
|
|
@ -223,7 +349,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
|
||||
alreadyDone := allDone()
|
||||
if alreadyDone { // all groups have made their choices, no need to actually read input streams
|
||||
result := make(map[string]any, len(v.config.GroupLen))
|
||||
result := make(map[string]any, len(v.groupLen))
|
||||
allSkip := true
|
||||
for group := range groupToChoice {
|
||||
choice := groupToChoice[group]
|
||||
|
|
@ -237,7 +363,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
|
||||
if allSkip { // no need to convert input streams for the output, because all groups are skipped
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
return schema.StreamReaderFromArray([]map[string]any{result}), nil
|
||||
|
|
@ -336,7 +462,7 @@ func (v *VariableAggregator) Transform(ctx context.Context, input *schema.Stream
|
|||
}
|
||||
|
||||
_ = compose.ProcessState(ctx, func(ctx context.Context, state nodes.DynamicStreamContainer) error {
|
||||
state.SaveDynamicChoice(v.config.NodeKey, groupToChoice)
|
||||
state.SaveDynamicChoice(v.nodeKey, groupToChoice)
|
||||
return nil
|
||||
})
|
||||
|
||||
|
|
@ -416,26 +542,12 @@ type vaCallbackInput struct {
|
|||
Variables []any `json:"variables"`
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
ctx = ctxcache.Init(ctx)
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.config.FullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
type streamMarkerType string
|
||||
|
||||
const streamMarker streamMarkerType = "<Stream Data...>"
|
||||
|
||||
func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
resolvedSources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, resolvedSourcesCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get resolved_sources from ctx cache")
|
||||
}
|
||||
|
|
@ -447,14 +559,14 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
|||
|
||||
merged := make([]vaCallbackInput, 0, len(in))
|
||||
|
||||
groupLen := v.config.GroupLen
|
||||
groupLen := v.groupLen
|
||||
|
||||
for groupName, vars := range in {
|
||||
orderedVars := make([]any, groupLen[groupName])
|
||||
for index := range vars {
|
||||
orderedVars[index] = vars[index]
|
||||
if len(resolvedSources) > 0 {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == nodes.FieldIsStream {
|
||||
if resolvedSources[groupName].SubSources[strconv.Itoa(index)].FieldType == schema2.FieldIsStream {
|
||||
// replace the streams with streamMarker,
|
||||
// because we won't read, save to execution history, or display these streams to user
|
||||
orderedVars[index] = streamMarker
|
||||
|
|
@ -479,7 +591,7 @@ func (v *VariableAggregator) ToCallbackInput(ctx context.Context, input map[stri
|
|||
}
|
||||
|
||||
func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]nodes.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
dynamicStreamType, ok := ctxcache.Get[map[string]schema2.FieldStreamType](ctx, groupChoiceTypeCacheKey)
|
||||
if !ok {
|
||||
panic("unable to get dynamic stream types from ctx cache")
|
||||
}
|
||||
|
|
@ -501,7 +613,7 @@ func (v *VariableAggregator) ToCallbackOutput(ctx context.Context, output map[st
|
|||
|
||||
newOut := maps.Clone(output)
|
||||
for k := range output {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == nodes.FieldIsStream {
|
||||
if t, ok := dynamicStreamType[k]; ok && t == schema2.FieldIsStream {
|
||||
newOut[k] = streamMarker
|
||||
}
|
||||
}
|
||||
|
|
@ -594,3 +706,15 @@ func init() {
|
|||
nodes.RegisterStreamChunkConcatFunc(concatVACallbackInputs)
|
||||
nodes.RegisterStreamChunkConcatFunc(concatStreamMarkers)
|
||||
}
|
||||
|
||||
func (v *VariableAggregator) Init(ctx context.Context) (context.Context, error) {
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, v.fullSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need this info for callbacks.OnStart, so we put it in cache within Init()
|
||||
ctxcache.Store(ctx, resolvedSourcesCacheKey, resolvedSources)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,29 +25,75 @@ import (
|
|||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type VariableAssigner struct {
|
||||
config *Config
|
||||
pairs []*Pair
|
||||
handler *variable.Handler
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Pairs []*Pair
|
||||
Handler *variable.Handler
|
||||
}
|
||||
|
||||
type Pair struct {
|
||||
Left vo.Reference
|
||||
Right compose.FieldPath
|
||||
func (c *Config) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAssigner,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: c,
|
||||
}
|
||||
|
||||
var pairs = make([]*Pair, 0, len(n.Data.Inputs.InputParameters))
|
||||
for i, param := range n.Data.Inputs.InputParameters {
|
||||
if param.Left == nil || param.Input == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left or input is nil")
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, compose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left source ref is nil")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref.VariableType == nil {
|
||||
return nil, fmt.Errorf("variable assigner node's param left source ref's variable type is nil")
|
||||
}
|
||||
|
||||
if *leftSources[0].Source.Ref.VariableType == vo.GlobalSystem {
|
||||
return nil, fmt.Errorf("variable assigner node's param left's ref's variable type cannot be variable.GlobalSystem")
|
||||
}
|
||||
|
||||
inputSource, err := convert.CanvasBlockInputToFieldInfo(param.Input, leftSources[0].Source.Ref.FromPath, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns.AddInputSource(inputSource...)
|
||||
pair := &Pair{
|
||||
Left: *leftSources[0].Source.Ref,
|
||||
Right: inputSource[0].Path,
|
||||
}
|
||||
pairs = append(pairs, pair)
|
||||
}
|
||||
|
||||
c.Pairs = pairs
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, error) {
|
||||
for _, pair := range conf.Pairs {
|
||||
func (c *Config) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
for _, pair := range c.Pairs {
|
||||
if pair.Left.VariableType == nil {
|
||||
return nil, fmt.Errorf("cannot assign to output of nodes in VariableAssigner, ref: %v", pair.Left)
|
||||
}
|
||||
|
|
@ -63,12 +109,18 @@ func NewVariableAssigner(_ context.Context, conf *Config) (*VariableAssigner, er
|
|||
}
|
||||
|
||||
return &VariableAssigner{
|
||||
config: conf,
|
||||
pairs: c.Pairs,
|
||||
handler: variable.GetVariableHandler(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
for _, pair := range v.config.Pairs {
|
||||
type Pair struct {
|
||||
Left vo.Reference
|
||||
Right compose.FieldPath
|
||||
}
|
||||
|
||||
func (v *VariableAssigner) Invoke(ctx context.Context, in map[string]any) (map[string]any, error) {
|
||||
for _, pair := range v.pairs {
|
||||
right, ok := nodes.TakeMapValue(in, pair.Right)
|
||||
if !ok {
|
||||
return nil, vo.NewError(errno.ErrInputFieldMissing, errorx.KV("name", strings.Join(pair.Right, ".")))
|
||||
|
|
@ -98,7 +150,7 @@ func (v *VariableAssigner) Assign(ctx context.Context, in map[string]any) (map[s
|
|||
ConnectorUID: exeCfg.ConnectorUID,
|
||||
}))
|
||||
}
|
||||
err := v.config.Handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
|
||||
err := v.handler.Set(ctx, *pair.Left.VariableType, pair.Left.FromPath, right, opts...)
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrVariablesAPIFail, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,25 +20,93 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
)
|
||||
|
||||
type InLoop struct {
|
||||
config *Config
|
||||
intermediateVarStore variable.Store
|
||||
type InLoopConfig struct {
|
||||
Pairs []*Pair
|
||||
}
|
||||
|
||||
func NewVariableAssignerInLoop(_ context.Context, conf *Config) (*InLoop, error) {
|
||||
func (i *InLoopConfig) Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema.NodeSchema, error) {
|
||||
if n.Parent() == nil {
|
||||
return nil, fmt.Errorf("loop set variable node must have parent: %s", n.ID)
|
||||
}
|
||||
|
||||
ns := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeVariableAssignerWithinLoop,
|
||||
Name: n.Data.Meta.Title,
|
||||
Configs: i,
|
||||
}
|
||||
|
||||
var pairs []*Pair
|
||||
for i, param := range n.Data.Inputs.InputParameters {
|
||||
if param.Left == nil || param.Right == nil {
|
||||
return nil, fmt.Errorf("loop set variable node's param left or right is nil")
|
||||
}
|
||||
|
||||
leftSources, err := convert.CanvasBlockInputToFieldInfo(param.Left, einoCompose.FieldPath{fmt.Sprintf("left_%d", i)}, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(leftSources) != 1 {
|
||||
return nil, fmt.Errorf("loop set variable node's param left is not a single source")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref == nil {
|
||||
return nil, fmt.Errorf("loop set variable node's param left's ref is nil")
|
||||
}
|
||||
|
||||
if leftSources[0].Source.Ref.VariableType == nil || *leftSources[0].Source.Ref.VariableType != vo.ParentIntermediate {
|
||||
return nil, fmt.Errorf("loop set variable node's param left's ref's variable type is not variable.ParentIntermediate")
|
||||
}
|
||||
|
||||
rightSources, err := convert.CanvasBlockInputToFieldInfo(param.Right, leftSources[0].Source.Ref.FromPath, n.Parent())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ns.AddInputSource(rightSources...)
|
||||
|
||||
if len(rightSources) != 1 {
|
||||
return nil, fmt.Errorf("loop set variable node's param right is not a single source")
|
||||
}
|
||||
|
||||
pair := &Pair{
|
||||
Left: *leftSources[0].Source.Ref,
|
||||
Right: rightSources[0].Path,
|
||||
}
|
||||
|
||||
pairs = append(pairs, pair)
|
||||
}
|
||||
|
||||
i.Pairs = pairs
|
||||
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func (i *InLoopConfig) Build(_ context.Context, _ *schema.NodeSchema, _ ...schema.BuildOption) (any, error) {
|
||||
return &InLoop{
|
||||
config: conf,
|
||||
pairs: i.Pairs,
|
||||
intermediateVarStore: &nodes.ParentIntermediateStore{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *InLoop) Assign(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
for _, pair := range v.config.Pairs {
|
||||
type InLoop struct {
|
||||
pairs []*Pair
|
||||
intermediateVarStore variable.Store
|
||||
}
|
||||
|
||||
func (v *InLoop) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
|
||||
for _, pair := range v.pairs {
|
||||
if pair.Left.VariableType == nil || *pair.Left.VariableType != vo.ParentIntermediate {
|
||||
panic(fmt.Errorf("dest is %+v in VariableAssignerInloop, invalid", pair.Left))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,8 +37,7 @@ func TestVariableAssigner(t *testing.T) {
|
|||
arrVar := any([]any{1, "2"})
|
||||
|
||||
va := &InLoop{
|
||||
config: &Config{
|
||||
Pairs: []*Pair{
|
||||
pairs: []*Pair{
|
||||
{
|
||||
Left: vo.Reference{
|
||||
FromPath: compose.FieldPath{"int_var_s"},
|
||||
|
|
@ -68,7 +67,6 @@ func TestVariableAssigner(t *testing.T) {
|
|||
Right: compose.FieldPath{"arr_var_t"},
|
||||
},
|
||||
},
|
||||
},
|
||||
intermediateVarStore: &nodes.ParentIntermediateStore{},
|
||||
}
|
||||
|
||||
|
|
@ -79,7 +77,7 @@ func TestVariableAssigner(t *testing.T) {
|
|||
"arr_var_s": &arrVar,
|
||||
}, nil)
|
||||
|
||||
_, err := va.Assign(ctx, map[string]any{
|
||||
_, err := va.Invoke(ctx, map[string]any{
|
||||
"int_var_t": 2,
|
||||
"str_var_t": "str2",
|
||||
"obj_var_t": map[string]any{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,196 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// Port type constants
|
||||
const (
|
||||
PortDefault = "default"
|
||||
PortBranchError = "branch_error"
|
||||
PortBranchFormat = "branch_%d"
|
||||
)
|
||||
|
||||
// BranchSchema defines the schema for workflow branches.
|
||||
type BranchSchema struct {
|
||||
From vo.NodeKey `json:"from_node"`
|
||||
DefaultMapping map[string]bool `json:"default_mapping,omitempty"`
|
||||
ExceptionMapping map[string]bool `json:"exception_mapping,omitempty"`
|
||||
Mappings map[int64]map[string]bool `json:"mappings,omitempty"`
|
||||
}
|
||||
|
||||
// BuildBranches builds branch schemas from connections.
|
||||
func BuildBranches(connections []*Connection) (map[vo.NodeKey]*BranchSchema, error) {
|
||||
var branchMap map[vo.NodeKey]*BranchSchema
|
||||
|
||||
for _, conn := range connections {
|
||||
if conn.FromPort == nil || len(*conn.FromPort) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
port := *conn.FromPort
|
||||
sourceNodeKey := conn.FromNode
|
||||
|
||||
if branchMap == nil {
|
||||
branchMap = map[vo.NodeKey]*BranchSchema{}
|
||||
}
|
||||
|
||||
// Get or create branch schema for source node
|
||||
branch, exists := branchMap[sourceNodeKey]
|
||||
if !exists {
|
||||
branch = &BranchSchema{
|
||||
From: sourceNodeKey,
|
||||
}
|
||||
branchMap[sourceNodeKey] = branch
|
||||
}
|
||||
|
||||
// Classify port type and add to appropriate mapping
|
||||
switch {
|
||||
case port == PortDefault:
|
||||
if branch.DefaultMapping == nil {
|
||||
branch.DefaultMapping = map[string]bool{}
|
||||
}
|
||||
branch.DefaultMapping[string(conn.ToNode)] = true
|
||||
case port == PortBranchError:
|
||||
if branch.ExceptionMapping == nil {
|
||||
branch.ExceptionMapping = map[string]bool{}
|
||||
}
|
||||
branch.ExceptionMapping[string(conn.ToNode)] = true
|
||||
default:
|
||||
var branchNum int64
|
||||
_, err := fmt.Sscanf(port, PortBranchFormat, &branchNum)
|
||||
if err != nil || branchNum < 0 {
|
||||
return nil, fmt.Errorf("invalid port format '%s' for connection %+v", port, conn)
|
||||
}
|
||||
if branch.Mappings == nil {
|
||||
branch.Mappings = map[int64]map[string]bool{}
|
||||
}
|
||||
if _, exists := branch.Mappings[branchNum]; !exists {
|
||||
branch.Mappings[branchNum] = make(map[string]bool)
|
||||
}
|
||||
branch.Mappings[branchNum][string(conn.ToNode)] = true
|
||||
}
|
||||
}
|
||||
|
||||
return branchMap, nil
|
||||
}
|
||||
|
||||
func (bs *BranchSchema) OnlyException() bool {
|
||||
return len(bs.Mappings) == 0 && len(bs.ExceptionMapping) > 0 && len(bs.DefaultMapping) > 0
|
||||
}
|
||||
|
||||
func (bs *BranchSchema) GetExceptionBranch() *compose.GraphBranch {
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bs.ExceptionMapping, nil
|
||||
}
|
||||
|
||||
return bs.DefaultMapping, nil
|
||||
}
|
||||
|
||||
// Combine ExceptionMapping and DefaultMapping into a new map
|
||||
endNodes := make(map[string]bool)
|
||||
for node := range bs.ExceptionMapping {
|
||||
endNodes[node] = true
|
||||
}
|
||||
for node := range bs.DefaultMapping {
|
||||
endNodes[node] = true
|
||||
}
|
||||
|
||||
return compose.NewGraphMultiBranch(condition, endNodes)
|
||||
}
|
||||
|
||||
func (bs *BranchSchema) GetFullBranch(ctx context.Context, bb BranchBuilder) (*compose.GraphBranch, error) {
|
||||
extractor, hasBranch := bb.BuildBranch(ctx)
|
||||
if !hasBranch {
|
||||
return nil, fmt.Errorf("branch expected but BranchBuilder thinks not. BranchSchema: %v", bs)
|
||||
}
|
||||
|
||||
if len(bs.ExceptionMapping) == 0 { // no exception, it's a normal branch
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
index, isDefault, err := extractor(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isDefault {
|
||||
return bs.DefaultMapping, nil
|
||||
}
|
||||
|
||||
if _, ok := bs.Mappings[index]; !ok {
|
||||
return nil, fmt.Errorf("chosen index= %d, out of range", index)
|
||||
}
|
||||
|
||||
return bs.Mappings[index], nil
|
||||
}
|
||||
|
||||
// Combine DefaultMapping and normal mappings into a new map
|
||||
endNodes := make(map[string]bool)
|
||||
for node := range bs.DefaultMapping {
|
||||
endNodes[node] = true
|
||||
}
|
||||
for _, ms := range bs.Mappings {
|
||||
for node := range ms {
|
||||
endNodes[node] = true
|
||||
}
|
||||
}
|
||||
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
|
||||
condition := func(ctx context.Context, in map[string]any) (map[string]bool, error) {
|
||||
isSuccess, ok := in["isSuccess"]
|
||||
if ok && isSuccess != nil && !isSuccess.(bool) {
|
||||
return bs.ExceptionMapping, nil
|
||||
}
|
||||
|
||||
index, isDefault, err := extractor(ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isDefault {
|
||||
return bs.DefaultMapping, nil
|
||||
}
|
||||
|
||||
return bs.Mappings[index], nil
|
||||
}
|
||||
|
||||
// Combine ALL mappings into a new map
|
||||
endNodes := make(map[string]bool)
|
||||
for node := range bs.ExceptionMapping {
|
||||
endNodes[node] = true
|
||||
}
|
||||
for node := range bs.DefaultMapping {
|
||||
endNodes[node] = true
|
||||
}
|
||||
for _, ms := range bs.Mappings {
|
||||
for node := range ms {
|
||||
endNodes[node] = true
|
||||
}
|
||||
}
|
||||
|
||||
return compose.NewGraphMultiBranch(condition, endNodes), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
type BuildOptions struct {
|
||||
WS *WorkflowSchema
|
||||
Inner compose.Runnable[map[string]any, map[string]any]
|
||||
}
|
||||
|
||||
func GetBuildOptions(opts ...BuildOption) *BuildOptions {
|
||||
bo := &BuildOptions{}
|
||||
for _, o := range opts {
|
||||
o(bo)
|
||||
}
|
||||
return bo
|
||||
}
|
||||
|
||||
type BuildOption func(options *BuildOptions)
|
||||
|
||||
func WithWorkflowSchema(ws *WorkflowSchema) BuildOption {
|
||||
return func(options *BuildOptions) {
|
||||
options.WS = ws
|
||||
}
|
||||
}
|
||||
|
||||
func WithInnerWorkflow(inner compose.Runnable[map[string]any, map[string]any]) BuildOption {
|
||||
return func(options *BuildOptions) {
|
||||
options.Inner = inner
|
||||
}
|
||||
}
|
||||
|
||||
// NodeBuilder takes a NodeSchema and several BuildOption to build an executable node instance.
|
||||
// The result 'executable' MUST implement at least one of the execute interfaces:
|
||||
// - nodes.InvokableNode
|
||||
// - nodes.StreamableNode
|
||||
// - nodes.CollectableNode
|
||||
// - nodes.TransformableNode
|
||||
// - nodes.InvokableNodeWOpt
|
||||
// - nodes.StreamableNodeWOpt
|
||||
// - nodes.CollectableNodeWOpt
|
||||
// - nodes.TransformableNodeWOpt
|
||||
// NOTE: the 'normal' version does not take NodeOption, while the 'WOpt' versions take NodeOption.
|
||||
// NOTE: a node should either implement the 'normal' versions, or the 'WOpt' versions, not mix them up.
|
||||
type NodeBuilder interface {
|
||||
Build(ctx context.Context, ns *NodeSchema, opts ...BuildOption) (
|
||||
executable any, err error)
|
||||
}
|
||||
|
||||
// BranchBuilder builds the extractor function that maps node output to port index.
|
||||
type BranchBuilder interface {
|
||||
BuildBranch(ctx context.Context) (extractor func(ctx context.Context,
|
||||
nodeOutput map[string]any) (int64, bool /*if is default branch*/, error), hasBranch bool)
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
// NodeSchema is the universal description and configuration for a workflow Node.
|
||||
// It should contain EVERYTHING a node needs to instantiate.
|
||||
type NodeSchema struct {
|
||||
// Key is the node key within the Eino graph.
|
||||
// A node may need this information during execution,
|
||||
// e.g.
|
||||
// - using this Key to query workflow State for data belonging to current node.
|
||||
Key vo.NodeKey `json:"key"`
|
||||
|
||||
// Name is the name for this node as specified on Canvas.
|
||||
// A node may show this name on Canvas as part of this node's input/output.
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type is the NodeType for the node.
|
||||
Type entity.NodeType `json:"type"`
|
||||
|
||||
// Configs are node specific configurations, with actual struct type defined by each Node Type.
|
||||
// Will not hold information relating to field mappings, nor as node's static values.
|
||||
// In a word, these Configs are INTERNAL to node's implementation, NOT related to workflow orchestration.
|
||||
// Actual type of these Configs should implement two interfaces:
|
||||
// - NodeAdaptor: to provide conversion from vo.Node to NodeSchema
|
||||
// - NodeBuilder: to provide instantiation from NodeSchema to actual node instance.
|
||||
Configs any `json:"configs,omitempty"`
|
||||
|
||||
// InputTypes are type information about the node's input fields.
|
||||
InputTypes map[string]*vo.TypeInfo `json:"input_types,omitempty"`
|
||||
// InputSources are field mapping information about the node's input fields.
|
||||
InputSources []*vo.FieldInfo `json:"input_sources,omitempty"`
|
||||
|
||||
// OutputTypes are type information about the node's output fields.
|
||||
OutputTypes map[string]*vo.TypeInfo `json:"output_types,omitempty"`
|
||||
// OutputSources are field mapping information about the node's output fields.
|
||||
// NOTE: only applicable to composite nodes such as NodeTypeBatch or NodeTypeLoop.
|
||||
OutputSources []*vo.FieldInfo `json:"output_sources,omitempty"`
|
||||
|
||||
// ExceptionConfigs are about exception handling strategy of the node.
|
||||
ExceptionConfigs *ExceptionConfig `json:"exception_configs,omitempty"`
|
||||
// StreamConfigs are streaming characteristics of the node.
|
||||
StreamConfigs *StreamConfig `json:"stream_configs,omitempty"`
|
||||
|
||||
// SubWorkflowBasic is basic information of the sub workflow if this node is NodeTypeSubWorkflow.
|
||||
SubWorkflowBasic *entity.WorkflowBasic `json:"sub_workflow_basic,omitempty"`
|
||||
// SubWorkflowSchema is WorkflowSchema of the sub workflow if this node is NodeTypeSubWorkflow.
|
||||
SubWorkflowSchema *WorkflowSchema `json:"sub_workflow_schema,omitempty"`
|
||||
|
||||
// FullSources contains more complete information about a node's input fields' mapping sources,
|
||||
// such as whether a field's source is a 'streaming field',
|
||||
// or whether the field is an object that contains sub-fields with real mappings.
|
||||
// Used for those nodes that need to process streaming input.
|
||||
// Set InputSourceAware = true in NodeMeta to enable.
|
||||
FullSources map[string]*SourceInfo
|
||||
|
||||
// Lambda directly sets the node to be an Eino Lambda.
|
||||
// NOTE: not serializable, used ONLY for internal test.
|
||||
Lambda *compose.Lambda
|
||||
}
|
||||
|
||||
type RequireCheckpoint interface {
|
||||
RequireCheckpoint() bool
|
||||
}
|
||||
|
||||
type ExceptionConfig struct {
|
||||
TimeoutMS int64 `json:"timeout_ms,omitempty"` // timeout in milliseconds, 0 means no timeout
|
||||
MaxRetry int64 `json:"max_retry,omitempty"` // max retry times, 0 means no retry
|
||||
ProcessType *vo.ErrorProcessType `json:"process_type,omitempty"` // error process type, 0 means throw error
|
||||
DataOnErr string `json:"data_on_err,omitempty"` // data to return when error, effective when ProcessType==Default occurs
|
||||
}
|
||||
|
||||
type StreamConfig struct {
|
||||
// whether this node has the ability to produce genuine streaming output.
|
||||
// not include nodes that only passes stream down as they receives them
|
||||
CanGeneratesStream bool `json:"can_generates_stream,omitempty"`
|
||||
// whether this node prioritize streaming input over none-streaming input.
|
||||
// not include nodes that can accept both and does not have preference.
|
||||
RequireStreamingInput bool `json:"can_process_stream,omitempty"`
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetConfigKV(key string, value any) {
|
||||
if s.Configs == nil {
|
||||
s.Configs = make(map[string]any)
|
||||
}
|
||||
|
||||
s.Configs.(map[string]any)[key] = value
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetInputType(key string, t *vo.TypeInfo) {
|
||||
if s.InputTypes == nil {
|
||||
s.InputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.InputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddInputSource(info ...*vo.FieldInfo) {
|
||||
s.InputSources = append(s.InputSources, info...)
|
||||
}
|
||||
|
||||
func (s *NodeSchema) SetOutputType(key string, t *vo.TypeInfo) {
|
||||
if s.OutputTypes == nil {
|
||||
s.OutputTypes = make(map[string]*vo.TypeInfo)
|
||||
}
|
||||
s.OutputTypes[key] = t
|
||||
}
|
||||
|
||||
func (s *NodeSchema) AddOutputSource(info ...*vo.FieldInfo) {
|
||||
s.OutputSources = append(s.OutputSources, info...)
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/compose"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
)
|
||||
|
||||
type FieldStreamType string
|
||||
|
||||
const (
|
||||
FieldIsStream FieldStreamType = "yes" // absolutely a stream
|
||||
FieldNotStream FieldStreamType = "no" // absolutely not a stream
|
||||
FieldMaybeStream FieldStreamType = "maybe" // maybe a stream, requires request-time resolution
|
||||
FieldSkipped FieldStreamType = "skipped" // the field source's node is skipped
|
||||
)
|
||||
|
||||
type FieldSkipStatus string
|
||||
|
||||
// SourceInfo contains stream type for a input field source of a node.
|
||||
type SourceInfo struct {
|
||||
// IsIntermediate means this field is itself not a field source, but a map containing one or more field sources.
|
||||
IsIntermediate bool
|
||||
// FieldType the stream type of the field. May require request-time resolution in addition to compile-time.
|
||||
FieldType FieldStreamType
|
||||
// FromNodeKey is the node key that produces this field source. empty if the field is a static value or variable.
|
||||
FromNodeKey vo.NodeKey
|
||||
// FromPath is the path of this field source within the source node. empty if the field is a static value or variable.
|
||||
FromPath compose.FieldPath
|
||||
TypeInfo *vo.TypeInfo
|
||||
// SubSources are SourceInfo for keys within this intermediate Map(Object) field.
|
||||
SubSources map[string]*SourceInfo
|
||||
}
|
||||
|
||||
func (s *SourceInfo) Skipped() bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FieldType == FieldSkipped
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if !sub.Skipped() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SourceInfo) FromNode(nodeKey vo.NodeKey) bool {
|
||||
if !s.IsIntermediate {
|
||||
return s.FromNodeKey == nodeKey
|
||||
}
|
||||
|
||||
for _, sub := range s.SubSources {
|
||||
if sub.FromNode(nodeKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package compose
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
|
@ -32,6 +32,7 @@ type WorkflowSchema struct {
|
|||
Nodes []*NodeSchema `json:"nodes"`
|
||||
Connections []*Connection `json:"connections"`
|
||||
Hierarchy map[vo.NodeKey]vo.NodeKey `json:"hierarchy,omitempty"` // child node key-> parent node key
|
||||
Branches map[vo.NodeKey]*BranchSchema `json:"branches,omitempty"`
|
||||
|
||||
GeneratedNodes []vo.NodeKey `json:"generated_nodes,omitempty"` // generated nodes for the nodes in batch mode
|
||||
|
||||
|
|
@ -71,12 +72,22 @@ func (w *WorkflowSchema) Init() {
|
|||
w.doGetCompositeNodes()
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
if node.requireCheckpoint() {
|
||||
if node.Type == entity.NodeTypeSubWorkflow {
|
||||
node.SubWorkflowSchema.Init()
|
||||
if node.SubWorkflowSchema.requireCheckPoint {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if rc, ok := node.Configs.(RequireCheckpoint); ok {
|
||||
if rc.RequireCheckpoint() {
|
||||
w.requireCheckPoint = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.requireStreaming = w.doRequireStreaming()
|
||||
})
|
||||
}
|
||||
|
|
@ -97,6 +108,22 @@ func (w *WorkflowSchema) GetCompositeNodes() []*CompositeNode {
|
|||
return w.compositeNodes
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) GetBranch(key vo.NodeKey) *BranchSchema {
|
||||
if w.Branches == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.Branches[key]
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) RequireCheckpoint() bool {
|
||||
return w.requireCheckPoint
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) RequireStreaming() bool {
|
||||
return w.requireStreaming
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
||||
if w.Hierarchy == nil {
|
||||
return nil
|
||||
|
|
@ -125,7 +152,7 @@ func (w *WorkflowSchema) doGetCompositeNodes() (cNodes []*CompositeNode) {
|
|||
return cNodes
|
||||
}
|
||||
|
||||
func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
func IsInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return true
|
||||
}
|
||||
|
|
@ -144,7 +171,7 @@ func isInSameWorkflow(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.Node
|
|||
return myParents == theirParents
|
||||
}
|
||||
|
||||
func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
func IsBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
|
|
@ -154,7 +181,7 @@ func isBelowOneLevel(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeK
|
|||
return myParentExists && !theirParentExists
|
||||
}
|
||||
|
||||
func isParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
func IsParentOf(n map[vo.NodeKey]vo.NodeKey, nodeKey, otherNodeKey vo.NodeKey) bool {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
|
|
@ -230,21 +257,14 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
|
|||
consumers := make(map[vo.NodeKey]bool)
|
||||
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if _, ok := sps[entity.Stream]; ok {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.CanGeneratesStream {
|
||||
producers[node.Key] = true
|
||||
}
|
||||
}
|
||||
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
consumers[node.Key] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if len(producers) == 0 || len(consumers) == 0 {
|
||||
|
|
@ -290,7 +310,7 @@ func (w *WorkflowSchema) doRequireStreaming() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig {
|
||||
func (w *WorkflowSchema) FanInMergeConfigs() map[string]compose.FanInMergeConfig {
|
||||
// what we need to do is to see if the workflow requires streaming, if not, then no fan-in merge configs needed
|
||||
// then we find those nodes that have 'transform' or 'collect' as streaming paradigm,
|
||||
// and see if each of those nodes has multiple data predecessors, if so, it's a fan-in node.
|
||||
|
|
@ -301,10 +321,6 @@ func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig
|
|||
|
||||
fanInNodes := make(map[vo.NodeKey]bool)
|
||||
for _, node := range w.Nodes {
|
||||
meta := entity.NodeMetaByNodeType(node.Type)
|
||||
if meta != nil {
|
||||
sps := meta.ExecutableMeta.StreamingParadigms
|
||||
if sps[entity.Transform] || sps[entity.Collect] {
|
||||
if node.StreamConfigs != nil && node.StreamConfigs.RequireStreamingInput {
|
||||
var predecessor *vo.NodeKey
|
||||
for _, source := range node.InputSources {
|
||||
|
|
@ -318,8 +334,6 @@ func (w *WorkflowSchema) fanInMergeConfigs() map[string]compose.FanInMergeConfig
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fanInConfigs := make(map[string]compose.FanInMergeConfig)
|
||||
for nodeKey := range fanInNodes {
|
||||
|
|
@ -37,8 +37,11 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
|
|
@ -73,7 +76,7 @@ func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmd
|
|||
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel)
|
||||
}
|
||||
|
||||
func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
|
||||
func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
|
||||
// Initialize result maps
|
||||
nodeMetaMap := make(map[string][]*entity.NodeTypeMeta)
|
||||
|
||||
|
|
@ -82,7 +85,7 @@ func (i *impl) ListNodeMeta(ctx context.Context, nodeTypes map[entity.NodeType]b
|
|||
if meta.Disabled {
|
||||
return false
|
||||
}
|
||||
nodeType := meta.Type
|
||||
nodeType := meta.Key
|
||||
if nodeTypes == nil || len(nodeTypes) == 0 {
|
||||
return true // No filter, include all
|
||||
}
|
||||
|
|
@ -192,10 +195,10 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
|
|||
if startNode != nil && endNode != nil {
|
||||
break
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotStart {
|
||||
if node.Type == entity.NodeTypeEntry.IDStr() {
|
||||
startNode = node
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotEnd {
|
||||
if node.Type == entity.NodeTypeExit.IDStr() {
|
||||
endNode = node
|
||||
}
|
||||
}
|
||||
|
|
@ -207,7 +210,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nInfo, err := adaptor.VariableToNamedTypeInfo(v)
|
||||
nInfo, err := convert.VariableToNamedTypeInfo(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -220,7 +223,7 @@ func extractInputsAndOutputsNamedInfoList(c *vo.Canvas) (inputs []*vo.NamedTypeI
|
|||
|
||||
if endNode != nil {
|
||||
outputs, err = slices.TransformWithErrorCheck(endNode.Data.Inputs.InputParameters, func(a *vo.Param) (*vo.NamedTypeInfo, error) {
|
||||
return adaptor.BlockInputToNamedTypeInfo(a.Name, a.Input)
|
||||
return convert.BlockInputToNamedTypeInfo(a.Name, a.Input)
|
||||
})
|
||||
if err != nil {
|
||||
logs.Warn(fmt.Sprintf("transform end node inputs to named info failed, err=%v", err))
|
||||
|
|
@ -316,6 +319,34 @@ func (i *impl) GetWorkflowReference(ctx context.Context, id int64) (map[int64]*v
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
type workflowIdentity struct {
|
||||
ID string `json:"id"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func getAllSubWorkflowIdentities(c *vo.Canvas) []*workflowIdentity {
|
||||
workflowEntities := make([]*workflowIdentity, 0)
|
||||
|
||||
var collectSubWorkFlowEntities func(nodes []*vo.Node)
|
||||
collectSubWorkFlowEntities = func(nodes []*vo.Node) {
|
||||
for _, n := range nodes {
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
workflowEntities = append(workflowEntities, &workflowIdentity{
|
||||
ID: n.Data.Inputs.WorkflowID,
|
||||
Version: n.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
}
|
||||
if len(n.Blocks) > 0 {
|
||||
collectSubWorkFlowEntities(n.Blocks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collectSubWorkFlowEntities(c.Nodes)
|
||||
|
||||
return workflowEntities
|
||||
}
|
||||
|
||||
func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.ValidateTreeConfig) ([]*cloudworkflow.ValidateTreeInfo, error) {
|
||||
wfValidateInfos := make([]*cloudworkflow.ValidateTreeInfo, 0)
|
||||
issues, err := validateWorkflowTree(ctx, validateConfig)
|
||||
|
|
@ -337,7 +368,7 @@ func (i *impl) ValidateTree(ctx context.Context, id int64, validateConfig vo.Val
|
|||
fmt.Errorf("failed to unmarshal canvas schema: %w", err))
|
||||
}
|
||||
|
||||
subWorkflowIdentities := c.GetAllSubWorkflowIdentities()
|
||||
subWorkflowIdentities := getAllSubWorkflowIdentities(c)
|
||||
|
||||
if len(subWorkflowIdentities) > 0 {
|
||||
var ids []int64
|
||||
|
|
@ -421,25 +452,21 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
|
|||
}
|
||||
|
||||
for _, n := range canvas.Nodes {
|
||||
if n.Type == vo.BlockTypeBotSubWorkflow {
|
||||
nodeSchema := &compose.NodeSchema{
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
nodeSchema := &schema.NodeSchema{
|
||||
Key: vo.NodeKey(n.ID),
|
||||
Type: entity.NodeTypeSubWorkflow,
|
||||
Name: n.Data.Meta.Title,
|
||||
}
|
||||
err := adaptor.SetInputsForNodeSchema(n, nodeSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
|
||||
err := convert.SetInputsForNodeSchema(n, nodeSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prop := &vo.NodeProperty{
|
||||
Type: string(blockType),
|
||||
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(),
|
||||
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(),
|
||||
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(),
|
||||
Type: nodeSchema.Type.IDStr(),
|
||||
IsEnableUserQuery: isEnableUserQuery(nodeSchema),
|
||||
IsEnableChatHistory: isEnableChatHistory(nodeSchema),
|
||||
IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
|
||||
}
|
||||
nodePropertyMap[string(nodeSchema.Key)] = prop
|
||||
wid, err := strconv.ParseInt(n.Data.Inputs.WorkflowID, 10, 64)
|
||||
|
|
@ -478,20 +505,16 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
|
|||
prop.SubWorkflow = ret
|
||||
|
||||
} else {
|
||||
nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n)
|
||||
nodeSchemas, _, err := adaptor.NodeToNodeSchema(ctx, n, canvas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, nodeSchema := range nodeSchemas {
|
||||
blockType, err := entityNodeTypeToBlockType(nodeSchema.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodePropertyMap[string(nodeSchema.Key)] = &vo.NodeProperty{
|
||||
Type: string(blockType),
|
||||
IsEnableUserQuery: nodeSchema.IsEnableUserQuery(),
|
||||
IsEnableChatHistory: nodeSchema.IsEnableChatHistory(),
|
||||
IsRefGlobalVariable: nodeSchema.IsRefGlobalVariable(),
|
||||
Type: nodeSchema.Type.IDStr(),
|
||||
IsEnableUserQuery: isEnableUserQuery(nodeSchema),
|
||||
IsEnableChatHistory: isEnableChatHistory(nodeSchema),
|
||||
IsRefGlobalVariable: isRefGlobalVariable(nodeSchema),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -500,6 +523,60 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
|
|||
return nodePropertyMap, nil
|
||||
}
|
||||
|
||||
func isEnableUserQuery(s *schema.NodeSchema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.Type != entity.NodeTypeEntry {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.OutputSources) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, source := range s.OutputSources {
|
||||
fieldPath := source.Path
|
||||
if len(fieldPath) == 1 && (fieldPath[0] == "BOT_USER_INPUT" || fieldPath[0] == "USER_INPUT") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isEnableChatHistory(s *schema.NodeSchema) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch s.Type {
|
||||
|
||||
case entity.NodeTypeLLM:
|
||||
llmParam := s.Configs.(*llm.Config).LLMParams
|
||||
return llmParam.EnableChatHistory
|
||||
case entity.NodeTypeIntentDetector:
|
||||
llmParam := s.Configs.(*intentdetector.Config).LLMParams
|
||||
return llmParam.EnableChatHistory
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isRefGlobalVariable(s *schema.NodeSchema) bool {
|
||||
for _, source := range s.InputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, source := range s.OutputSources {
|
||||
if source.IsRefGlobalVariable() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) {
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil {
|
||||
|
|
@ -510,7 +587,7 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
|
|||
var getRefFn func([]*vo.Node) error
|
||||
getRefFn = func(nodes []*vo.Node) error {
|
||||
for _, node := range nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
referredID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
return vo.WrapError(errno.ErrSchemaConversionFail, err)
|
||||
|
|
@ -521,7 +598,8 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
|
|||
ReferType: vo.ReferTypeSubWorkflow,
|
||||
ReferringBizType: vo.ReferringBizTypeWorkflow,
|
||||
}] = struct{}{}
|
||||
} else if node.Type == vo.BlockTypeBotLLM {
|
||||
} else if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.LLM != nil {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
referredID, err := strconv.ParseInt(w.WorkflowID, 10, 64)
|
||||
|
|
@ -536,6 +614,7 @@ func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowRefer
|
|||
}] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(node.Blocks) > 0 {
|
||||
for _, subNode := range node.Blocks {
|
||||
if err := getRefFn([]*vo.Node{subNode}); err != nil {
|
||||
|
|
@ -832,7 +911,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
|
|||
|
||||
validateAndBuildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
|
||||
for _, node := range nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
var (
|
||||
v *vo.DraftInfo
|
||||
wfID int64
|
||||
|
|
@ -883,7 +962,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
|
|||
|
||||
}
|
||||
|
||||
if node.Type == vo.BlockTypeBotLLM {
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
var (
|
||||
|
|
@ -1086,7 +1165,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
|
|||
var buildWorkflowReference func(nodes []*vo.Node, wf *copiedWorkflow) error
|
||||
buildWorkflowReference = func(nodes []*vo.Node, wf *copiedWorkflow) error {
|
||||
for _, node := range nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
var (
|
||||
v *vo.DraftInfo
|
||||
wfID int64
|
||||
|
|
@ -1121,7 +1200,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
|
|||
}
|
||||
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotLLM {
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
var (
|
||||
|
|
@ -1323,8 +1402,8 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
|||
var collectDependence func(nodes []*vo.Node) error
|
||||
collectDependence = func(nodes []*vo.Node) error {
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotAPI:
|
||||
switch entity.IDStrToNodeType(node.Type) {
|
||||
case entity.NodeTypePlugin:
|
||||
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
|
||||
return e.Name, e
|
||||
})
|
||||
|
|
@ -1347,7 +1426,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
|||
ds.PluginIDs = append(ds.PluginIDs, pID)
|
||||
}
|
||||
|
||||
case vo.BlockTypeBotDatasetWrite, vo.BlockTypeBotDataset:
|
||||
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
|
||||
datasetListInfoParam := node.Data.Inputs.DatasetParam[0]
|
||||
datasetIDs := datasetListInfoParam.Input.Value.Content.([]any)
|
||||
for _, id := range datasetIDs {
|
||||
|
|
@ -1357,7 +1436,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
|||
}
|
||||
ds.KnowledgeIDs = append(ds.KnowledgeIDs, k)
|
||||
}
|
||||
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
|
||||
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
|
||||
dsList := node.Data.Inputs.DatabaseInfoList
|
||||
if len(dsList) == 0 {
|
||||
return fmt.Errorf("database info is requird")
|
||||
|
|
@ -1369,7 +1448,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
|||
}
|
||||
ds.DatabaseIDs = append(ds.DatabaseIDs, dsID)
|
||||
}
|
||||
case vo.BlockTypeBotLLM:
|
||||
case entity.NodeTypeLLM:
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
|
||||
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
|
||||
|
|
@ -1396,7 +1475,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
|||
}
|
||||
}
|
||||
|
||||
case vo.BlockTypeBotSubWorkflow:
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
wfID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -1567,8 +1646,8 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
)
|
||||
|
||||
for _, node := range nodes {
|
||||
switch node.Type {
|
||||
case vo.BlockTypeBotSubWorkflow:
|
||||
switch entity.IDStrToNodeType(node.Type) {
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
if !hasWorkflowRelated {
|
||||
continue
|
||||
}
|
||||
|
|
@ -1580,7 +1659,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
|
||||
node.Data.Inputs.WorkflowVersion = wf.Version
|
||||
}
|
||||
case vo.BlockTypeBotAPI:
|
||||
case entity.NodeTypePlugin:
|
||||
if !hasPluginRelated {
|
||||
continue
|
||||
}
|
||||
|
|
@ -1623,7 +1702,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
apiIDParam.Input.Value.Content = strconv.FormatInt(refApiID, 10)
|
||||
}
|
||||
|
||||
case vo.BlockTypeBotLLM:
|
||||
case entity.NodeTypeLLM:
|
||||
if hasWorkflowRelated && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
|
||||
|
|
@ -1669,7 +1748,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
}
|
||||
}
|
||||
|
||||
case vo.BlockTypeBotDataset, vo.BlockTypeBotDatasetWrite:
|
||||
case entity.NodeTypeKnowledgeIndexer, entity.NodeTypeKnowledgeRetriever:
|
||||
if !hasKnowledgeRelated {
|
||||
continue
|
||||
}
|
||||
|
|
@ -1685,7 +1764,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
}
|
||||
}
|
||||
|
||||
case vo.BlockTypeDatabase, vo.BlockTypeDatabaseSelect, vo.BlockTypeDatabaseInsert, vo.BlockTypeDatabaseDelete, vo.BlockTypeDatabaseUpdate:
|
||||
case entity.NodeTypeDatabaseCustomSQL, entity.NodeTypeDatabaseQuery, entity.NodeTypeDatabaseInsert, entity.NodeTypeDatabaseDelete, entity.NodeTypeDatabaseUpdate:
|
||||
if !hasDatabaseRelated {
|
||||
continue
|
||||
}
|
||||
|
|
@ -1713,3 +1792,7 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RegisterAllNodeAdaptors() {
|
||||
adaptor.RegisterAllNodeAdaptors()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,11 +24,9 @@ import (
|
|||
|
||||
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
|
@ -199,158 +197,3 @@ func isIncremental(prev version, next version) bool {
|
|||
|
||||
return next.Patch > prev.Patch
|
||||
}
|
||||
|
||||
func replaceRelatedWorkflowOrPluginInWorkflowNodes(nodes []*vo.Node, relatedWorkflows map[int64]entity.IDVersionPair, relatedPlugins map[int64]vo.PluginEntity) error {
|
||||
for _, node := range nodes {
|
||||
if node.Type == vo.BlockTypeBotSubWorkflow {
|
||||
workflowID, err := strconv.ParseInt(node.Data.Inputs.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if wf, ok := relatedWorkflows[workflowID]; ok {
|
||||
node.Data.Inputs.WorkflowID = strconv.FormatInt(wf.ID, 10)
|
||||
node.Data.Inputs.WorkflowVersion = wf.Version
|
||||
}
|
||||
}
|
||||
if node.Type == vo.BlockTypeBotAPI {
|
||||
apiParams := slices.ToMap(node.Data.Inputs.APIParams, func(e *vo.Param) (string, *vo.Param) {
|
||||
return e.Name, e
|
||||
})
|
||||
pluginIDParam, ok := apiParams["pluginID"]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin id param is not found")
|
||||
}
|
||||
|
||||
pID, err := strconv.ParseInt(pluginIDParam.Input.Value.Content.(string), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pluginVersionParam, ok := apiParams["pluginVersion"]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin version param is not found")
|
||||
}
|
||||
|
||||
if refPlugin, ok := relatedPlugins[pID]; ok {
|
||||
pluginIDParam.Input.Value.Content = refPlugin.PluginID
|
||||
if refPlugin.PluginVersion != nil {
|
||||
pluginVersionParam.Input.Value.Content = *refPlugin.PluginVersion
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
if node.Type == vo.BlockTypeBotLLM {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
wf := node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx]
|
||||
workflowID, err := strconv.ParseInt(wf.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if refWf, ok := relatedWorkflows[workflowID]; ok {
|
||||
wf.WorkflowID = strconv.FormatInt(refWf.ID, 10)
|
||||
wf.WorkflowVersion = refWf.Version
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
|
||||
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
|
||||
pluginID, err := strconv.ParseInt(pl.PluginID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if refPlugin, ok := relatedPlugins[pluginID]; ok {
|
||||
pl.PluginID = strconv.FormatInt(refPlugin.PluginID, 10)
|
||||
if refPlugin.PluginVersion != nil {
|
||||
pl.PluginVersion = *refPlugin.PluginVersion
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if len(node.Blocks) > 0 {
|
||||
err := replaceRelatedWorkflowOrPluginInWorkflowNodes(node.Blocks, relatedWorkflows, relatedPlugins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// entityNodeTypeToBlockType converts an entity.NodeType to the corresponding vo.BlockType.
|
||||
func entityNodeTypeToBlockType(nodeType entity.NodeType) (vo.BlockType, error) {
|
||||
switch nodeType {
|
||||
case entity.NodeTypeEntry:
|
||||
return vo.BlockTypeBotStart, nil
|
||||
case entity.NodeTypeExit:
|
||||
return vo.BlockTypeBotEnd, nil
|
||||
case entity.NodeTypeLLM:
|
||||
return vo.BlockTypeBotLLM, nil
|
||||
case entity.NodeTypePlugin:
|
||||
return vo.BlockTypeBotAPI, nil
|
||||
case entity.NodeTypeCodeRunner:
|
||||
return vo.BlockTypeBotCode, nil
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
return vo.BlockTypeBotDataset, nil
|
||||
case entity.NodeTypeSelector:
|
||||
return vo.BlockTypeCondition, nil
|
||||
case entity.NodeTypeSubWorkflow:
|
||||
return vo.BlockTypeBotSubWorkflow, nil
|
||||
case entity.NodeTypeDatabaseCustomSQL:
|
||||
return vo.BlockTypeDatabase, nil
|
||||
case entity.NodeTypeOutputEmitter:
|
||||
return vo.BlockTypeBotMessage, nil
|
||||
case entity.NodeTypeTextProcessor:
|
||||
return vo.BlockTypeBotText, nil
|
||||
case entity.NodeTypeQuestionAnswer:
|
||||
return vo.BlockTypeQuestion, nil
|
||||
case entity.NodeTypeBreak:
|
||||
return vo.BlockTypeBotBreak, nil
|
||||
case entity.NodeTypeVariableAssigner:
|
||||
return vo.BlockTypeBotAssignVariable, nil
|
||||
case entity.NodeTypeVariableAssignerWithinLoop:
|
||||
return vo.BlockTypeBotLoopSetVariable, nil
|
||||
case entity.NodeTypeLoop:
|
||||
return vo.BlockTypeBotLoop, nil
|
||||
case entity.NodeTypeIntentDetector:
|
||||
return vo.BlockTypeBotIntent, nil
|
||||
case entity.NodeTypeKnowledgeIndexer:
|
||||
return vo.BlockTypeBotDatasetWrite, nil
|
||||
case entity.NodeTypeBatch:
|
||||
return vo.BlockTypeBotBatch, nil
|
||||
case entity.NodeTypeContinue:
|
||||
return vo.BlockTypeBotContinue, nil
|
||||
case entity.NodeTypeInputReceiver:
|
||||
return vo.BlockTypeBotInput, nil
|
||||
case entity.NodeTypeDatabaseUpdate:
|
||||
return vo.BlockTypeDatabaseUpdate, nil
|
||||
case entity.NodeTypeDatabaseQuery:
|
||||
return vo.BlockTypeDatabaseSelect, nil
|
||||
case entity.NodeTypeDatabaseDelete:
|
||||
return vo.BlockTypeDatabaseDelete, nil
|
||||
case entity.NodeTypeHTTPRequester:
|
||||
return vo.BlockTypeBotHttp, nil
|
||||
case entity.NodeTypeDatabaseInsert:
|
||||
return vo.BlockTypeDatabaseInsert, nil
|
||||
case entity.NodeTypeVariableAggregator:
|
||||
return vo.BlockTypeBotVariableMerge, nil
|
||||
case entity.NodeTypeJsonSerialization:
|
||||
return vo.BlockTypeJsonSerialization, nil
|
||||
case entity.NodeTypeJsonDeserialization:
|
||||
return vo.BlockTypeJsonDeserialization, nil
|
||||
case entity.NodeTypeKnowledgeDeleter:
|
||||
return vo.BlockTypeBotDatasetDelete, nil
|
||||
|
||||
default:
|
||||
return "", vo.WrapError(errno.ErrSchemaConversionFail,
|
||||
fmt.Errorf("cannot map entity node type '%s' to a workflow.NodeTemplateType", nodeType))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ func init() {
|
|||
|
||||
code.Register(
|
||||
ErrMissingRequiredParam,
|
||||
"Missing required parameters {param}. Please review the API documentation and ensure all mandatory fields are included in your request.",
|
||||
"Missing required parameters:’{param}‘. Please review the API documentation and ensure all mandatory fields are included in your request.",
|
||||
code.WithAffectStability(false),
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue