feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -0,0 +1,490 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package service
import (
"context"
"fmt"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
workflow2 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type conversationImpl struct {
repo workflow.Repository
}
func (c *conversationImpl) CreateDraftConversationTemplate(ctx context.Context, template *vo.CreateConversationTemplateMeta) (int64, error) {
var (
spaceID = template.SpaceID
appID = template.AppID
name = template.Name
userID = template.UserID
)
existed, err := c.IsDraftConversationNameExist(ctx, appID, userID, template.Name)
if err != nil {
return 0, err
}
if existed {
return 0, vo.WrapError(errno.ErrConversationNameIsDuplicated, fmt.Errorf("conversation name %s exists", name), errorx.KV("name", name))
}
return c.repo.CreateDraftConversationTemplate(ctx, &vo.CreateConversationTemplateMeta{
SpaceID: spaceID,
AppID: appID,
Name: name,
UserID: userID,
})
}
func (c *conversationImpl) IsDraftConversationNameExist(ctx context.Context, appID int64, userID int64, name string) (bool, error) {
_, existed, err := c.repo.GetDynamicConversationByName(ctx, vo.Draft, appID, consts.CozeConnectorID, userID, name)
if err != nil {
return false, err
}
if existed {
return true, nil
}
_, existed, err = c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{AppID: ptr.Of(appID), Name: ptr.Of(name)})
if err != nil {
return false, err
}
if existed {
return true, nil
}
return false, nil
}
func (c *conversationImpl) UpdateDraftConversationTemplateName(ctx context.Context, appID int64, userID int64, templateID int64, modifiedName string) error {
template, existed, err := c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{TemplateID: ptr.Of(templateID)})
if err != nil {
return err
}
if existed && template.Name == modifiedName {
return nil
}
existed, err = c.IsDraftConversationNameExist(ctx, appID, userID, modifiedName)
if err != nil {
return err
}
if existed {
return vo.WrapError(errno.ErrConversationNameIsDuplicated, fmt.Errorf("conversation name %s exists", modifiedName), errorx.KV("name", modifiedName))
}
wfs, err := c.findReplaceWorkflowByConversationName(ctx, appID, template.Name)
if err != nil {
return err
}
err = c.replaceWorkflowsConversationName(ctx, wfs, slices.ToMap(wfs, func(e *entity.Workflow) (int64, string) {
return e.ID, modifiedName
}))
if err != nil {
return err
}
return c.repo.UpdateDraftConversationTemplateName(ctx, templateID, modifiedName)
}
func (c *conversationImpl) CheckWorkflowsToReplace(ctx context.Context, appID int64, templateID int64) ([]*entity.Workflow, error) {
template, existed, err := c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{TemplateID: ptr.Of(templateID)})
if err != nil {
return nil, err
}
if existed {
return c.findReplaceWorkflowByConversationName(ctx, appID, template.Name)
}
return []*entity.Workflow{}, nil
}
func (c *conversationImpl) DeleteDraftConversationTemplate(ctx context.Context, templateID int64, wfID2ConversationName map[int64]string) (int64, error) {
if len(wfID2ConversationName) == 0 {
return c.repo.DeleteDraftConversationTemplate(ctx, templateID)
}
workflowIDs := make([]int64, 0)
for id := range wfID2ConversationName {
workflowIDs = append(workflowIDs, id)
}
wfs, _, err := c.repo.MGetDrafts(ctx, &vo.MGetPolicy{
MetaQuery: vo.MetaQuery{
IDs: workflowIDs,
},
QType: workflowModel.FromDraft,
})
if err != nil {
return 0, err
}
err = c.replaceWorkflowsConversationName(ctx, wfs, wfID2ConversationName)
if err != nil {
return 0, err
}
return c.repo.DeleteDraftConversationTemplate(ctx, templateID)
}
func (c *conversationImpl) ListConversationTemplate(ctx context.Context, env vo.Env, policy *vo.ListConversationTemplatePolicy) ([]*entity.ConversationTemplate, error) {
var (
err error
templates []*entity.ConversationTemplate
appID = policy.AppID
)
templates, err = c.repo.ListConversationTemplate(ctx, env, &vo.ListConversationTemplatePolicy{
AppID: appID,
Page: policy.Page,
NameLike: policy.NameLike,
Version: policy.Version,
})
if err != nil {
return nil, err
}
return templates, nil
}
func (c *conversationImpl) MGetStaticConversation(ctx context.Context, env vo.Env, userID, connectorID int64, templateIDs []int64) ([]*entity.StaticConversation, error) {
return c.repo.MGetStaticConversation(ctx, env, userID, connectorID, templateIDs)
}
func (c *conversationImpl) ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error) {
return c.repo.ListDynamicConversation(ctx, env, policy)
}
func (c *conversationImpl) ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error {
templates, err := c.repo.ListConversationTemplate(ctx, vo.Draft, &vo.ListConversationTemplatePolicy{
AppID: appID,
})
if err != nil {
return err
}
if len(templates) == 0 {
return nil
}
return c.repo.BatchCreateOnlineConversationTemplate(ctx, templates, version)
}
func (c *conversationImpl) InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID, appID int64, userID int64) error {
_, err := c.repo.CreateDraftConversationTemplate(ctx, &vo.CreateConversationTemplateMeta{
AppID: appID,
SpaceID: spaceID,
UserID: userID,
Name: "Default",
})
if err != nil {
return err
}
return nil
}
func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Context, appID int64, name string) ([]*entity.Workflow, error) {
wfs, _, err := c.repo.MGetDrafts(ctx, &vo.MGetPolicy{
QType: workflowModel.FromDraft,
MetaQuery: vo.MetaQuery{
AppID: ptr.Of(appID),
Mode: ptr.Of(workflow2.WorkflowMode_ChatFlow),
},
})
if err != nil {
return nil, err
}
shouldReplacedWorkflow := func(nodes []*vo.Node) (bool, error) {
var startNode *vo.Node
for _, node := range nodes {
if node.Type == entity.NodeTypeEntry.IDStr() {
startNode = node
}
}
if startNode == nil {
return false, fmt.Errorf("start node not found for block type")
}
for _, vAny := range startNode.Data.Outputs {
v, err := vo.ParseVariable(vAny)
if err != nil {
return false, err
}
if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
return true, nil
}
}
return false, nil
}
shouldReplacedWorkflows := make([]*entity.Workflow, 0)
for idx := range wfs {
wf := wfs[idx]
canvas := &vo.Canvas{}
err = sonic.UnmarshalString(wf.Canvas, canvas)
if err != nil {
return nil, err
}
ok, err := shouldReplacedWorkflow(canvas.Nodes)
if err != nil {
return nil, err
}
if ok {
shouldReplacedWorkflows = append(shouldReplacedWorkflows, wf)
}
}
return shouldReplacedWorkflows, nil
}
func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context, wfs []*entity.Workflow, workflowID2ConversionName map[int64]string) error {
replaceConversionName := func(nodes []*vo.Node, conversionName string) error {
var startNode *vo.Node
for _, node := range nodes {
if node.Type == entity.NodeTypeEntry.IDStr() {
startNode = node
}
}
if startNode == nil {
return fmt.Errorf("start node not found for block type")
}
for idx, vAny := range startNode.Data.Outputs {
v, err := vo.ParseVariable(vAny)
if err != nil {
return err
}
if v.Name == "CONVERSATION_NAME" {
v.DefaultValue = conversionName
}
startNode.Data.Outputs[idx] = v
}
return nil
}
tg := taskgroup.NewTaskGroup(ctx, len(wfs))
for _, wf := range wfs {
wfEntity := wf
tg.Go(func() error {
canvas := &vo.Canvas{}
err := sonic.UnmarshalString(wfEntity.Canvas, canvas)
if err != nil {
return err
}
conversationName := workflowID2ConversionName[wfEntity.ID]
err = replaceConversionName(canvas.Nodes, conversationName)
if err != nil {
return err
}
replaceCanvas, err := sonic.MarshalString(canvas)
if err != nil {
return err
}
err = c.repo.CreateOrUpdateDraft(ctx, wfEntity.ID, &vo.DraftInfo{
DraftMeta: &vo.DraftMeta{
TestRunSuccess: false,
Modified: true,
},
Canvas: replaceCanvas,
})
if err != nil {
return err
}
return nil
})
}
err := tg.Wait()
if err != nil {
return err
}
return nil
}
func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo.Env, templateID int64) (int64, error) {
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
}
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
AppID: ptr.Of(appID),
Name: ptr.Of(conversationName),
})
if err != nil {
return 0, 0, err
}
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
AgentID: appID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneWorkflow,
})
})
if existed {
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
AppID: appID,
ConnectorID: connectorID,
UserID: userID,
TemplateID: t.TemplateID,
})
if err != nil {
return 0, 0, err
}
return conversationID, sectionID, nil
}
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
AppID: appID,
ConnectorID: connectorID,
UserID: userID,
Name: conversationName,
})
if err != nil {
return 0, 0, err
}
return conversationID, sectionID, nil
}
func (c *conversationImpl) UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error) {
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
AppID: ptr.Of(appID),
Name: ptr.Of(conversationName),
})
if err != nil {
return 0, err
}
if existed {
conv, err := crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
AgentID: appID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneWorkflow,
})
if err != nil {
return 0, err
}
if conv == nil {
return 0, fmt.Errorf("create conversation failed")
}
err = c.repo.UpdateStaticConversation(ctx, env, t.TemplateID, connectorID, userID, conv.ID)
if err != nil {
return 0, err
}
return conv.ID, nil
}
dy, existed, err := c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, conversationName)
if err != nil {
return 0, err
}
if !existed {
return 0, fmt.Errorf("conversation name %v not found", conversationName)
}
conv, err := crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
AgentID: appID,
UserID: userID,
ConnectorID: connectorID,
Scene: common.Scene_SceneWorkflow,
})
if err != nil {
return 0, err
}
if conv == nil {
return 0, fmt.Errorf("create conversation failed")
}
err = c.repo.UpdateDynamicConversation(ctx, env, dy.ConversationID, conv.ID)
if err != nil {
return 0, err
}
return conv.ID, nil
}
func (c *conversationImpl) GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error) {
return c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
AppID: ptr.Of(appID),
Name: ptr.Of(templateName),
})
}
func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error) {
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
}
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
if err != nil {
return "", false, err
}
if existed {
return sc, true, nil
}
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
if err != nil {
return "", false, err
}
if existed {
return dc.Name, true, nil
}
return "", false, nil
}
func (c *conversationImpl) Suggest(ctx context.Context, input *vo.SuggestInfo) ([]string, error) {
return c.repo.Suggest(ctx, input)
}

View File

@@ -26,6 +26,8 @@ import (
"github.com/cloudwego/eino/schema"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
workflowapimodel "github.com/coze-dev/coze-studio/backend/api/model/workflow"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
@@ -62,6 +64,8 @@ func (i *impl) SyncExecute(ctx context.Context, config workflowModel.ExecuteConf
return nil, "", err
}
config.WorkflowMode = wfEntity.Mode
isApplicationWorkflow := wfEntity.AppID != nil
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
@@ -207,6 +211,8 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
return 0, err
}
config.WorkflowMode = wfEntity.Mode
isApplicationWorkflow := wfEntity.AppID != nil
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
@@ -292,6 +298,8 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
return 0, err
}
config.WorkflowMode = wfEntity.Mode
isApplicationWorkflow := wfEntity.AppID != nil
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
@@ -300,6 +308,30 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
if err != nil {
return 0, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
@@ -375,6 +407,8 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
return nil, err
}
config.WorkflowMode = wfEntity.Mode
isApplicationWorkflow := wfEntity.AppID != nil
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
@@ -383,6 +417,29 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
}
}
historyRounds := int64(0)
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
if err != nil {
return nil, err
}
}
if historyRounds > 0 {
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
if err != nil {
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
}
if len(messages) > 0 {
config.ConversationHistory = messages
}
if len(scMessages) > 0 {
config.ConversationHistorySchemaMessages = scMessages
}
}
c := &vo.Canvas{}
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
@@ -718,6 +775,7 @@ func (i *impl) AsyncResume(ctx context.Context, req *entity.ResumeRequest, confi
config.AppID = wfExe.AppID
config.AgentID = wfExe.AgentID
config.CommitID = wfExe.CommitID
config.WorkflowMode = wfEntity.Mode
if config.ConnectorID == 0 {
config.ConnectorID = wfExe.ConnectorID
@@ -859,6 +917,7 @@ func (i *impl) StreamResume(ctx context.Context, req *entity.ResumeRequest, conf
config.AppID = wfExe.AppID
config.AgentID = wfExe.AgentID
config.CommitID = wfExe.CommitID
config.WorkflowMode = wfEntity.Mode
if config.ConnectorID == 0 {
config.ConnectorID = wfExe.ConnectorID
@@ -937,3 +996,73 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
return nil
}
const maxHistoryRounds int64 = 30
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
if wfEntity == nil {
return 0, nil
}
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
if err != nil {
return 0, err
}
return min(maxRounds, maxHistoryRounds), nil
}
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
convID := config.ConversationID
agentID := config.AgentID
appID := config.AppID
userID := config.Operator
sectionID := config.SectionID
if sectionID == nil {
logs.CtxWarnf(ctx, "SectionID is nil, skipping chat history")
return nil, nil, nil
}
if convID == nil || *convID == 0 {
logs.CtxWarnf(ctx, "ConversationID is 0 or nil, skipping chat history")
return nil, nil, nil
}
var resolvedAppID int64
if appID != nil {
resolvedAppID = *appID
} else if agentID != nil {
resolvedAppID = *agentID
} else {
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
return nil, nil, nil
}
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
ConversationID: *convID,
AppID: resolvedAppID,
UserID: userID,
Rounds: historyRounds + 1,
SectionID: *sectionID,
}
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
if err != nil {
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
return nil, nil, nil
}
if len(runIds) <= 1 {
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
}
runIds = runIds[1:]
response, err := crossmessage.DefaultSVC().GetMessagesByRunIDs(ctx, &crossmessage.GetMessagesByRunIDsRequest{
ConversationID: *convID,
RunIDs: runIds,
})
if err != nil {
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
return nil, nil, nil
}
return response.Messages, response.SchemaMessages, nil
}

View File

@@ -20,13 +20,14 @@ import (
"context"
"errors"
"fmt"
"strconv"
"github.com/spf13/cast"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"strconv"
einoCompose "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
@@ -38,6 +39,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
@@ -45,6 +47,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
@@ -56,6 +59,7 @@ type impl struct {
repo workflow.Repository
*asToolImpl
*executableImpl
*conversationImpl
}
func NewWorkflowService(repo workflow.Repository) workflow.Service {
@@ -67,12 +71,14 @@ func NewWorkflowService(repo workflow.Repository) workflow.Service {
executableImpl: &executableImpl{
repo: repo,
},
conversationImpl: &conversationImpl{repo: repo},
}
}
func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
cpStore einoCompose.CheckPointStore, chatModel chatmodel.BaseChatModel, workflowConfig workflow.WorkflowConfig) workflow.Repository {
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel, workflowConfig)
cpStore einoCompose.CheckPointStore, chatModel chatmodel.BaseChatModel, cfg workflow.WorkflowConfig) (workflow.Repository, error) {
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel, cfg)
}
func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
@@ -440,10 +446,13 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
var canvasSchema string
if n.Data.Inputs.WorkflowVersion != "" {
versionInfo, err := i.repo.GetVersion(ctx, wid, n.Data.Inputs.WorkflowVersion)
versionInfo, existed, err := i.repo.GetVersion(ctx, wid, n.Data.Inputs.WorkflowVersion)
if err != nil {
return nil, err
}
if !existed {
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", n.Data.Inputs.WorkflowVersion, wid, err), errorx.KV("id", strconv.FormatInt(wid, 10)))
}
canvasSchema = versionInfo.Canvas
} else {
draftInfo, err := i.repo.DraftV2(ctx, wid, "")
@@ -522,6 +531,9 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
case entity.NodeTypeIntentDetector:
llmParam := s.Configs.(*intentdetector.Config).LLMParams
return llmParam.EnableChatHistory
case entity.NodeTypeKnowledgeRetriever:
chatParam := s.Configs.(*knowledge.RetrieveConfig).ChatHistorySetting
return chatParam != nil && chatParam.EnableChatHistory
default:
return false
}
@@ -541,6 +553,103 @@ func isRefGlobalVariable(s *schema.NodeSchema) bool {
return false
}
func (i *impl) CreateChatFlowRole(ctx context.Context, role *vo.ChatFlowRoleCreate) (int64, error) {
id, err := i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
Name: role.Name,
Description: role.Description,
WorkflowID: role.WorkflowID,
CreatorID: role.CreatorID,
AudioConfig: role.AudioConfig,
UserInputConfig: role.UserInputConfig,
AvatarUri: role.AvatarUri,
BackgroundImageInfo: role.BackgroundImageInfo,
OnboardingInfo: role.OnboardingInfo,
SuggestReplyInfo: role.SuggestReplyInfo,
})
if err != nil {
return 0, err
}
return id, nil
}
func (i *impl) UpdateChatFlowRole(ctx context.Context, workflowID int64, role *vo.ChatFlowRoleUpdate) error {
err := i.repo.UpdateChatFlowRoleConfig(ctx, workflowID, role)
if err != nil {
return err
}
return nil
}
func (i *impl) GetChatFlowRole(ctx context.Context, workflowID int64, version string) (*entity.ChatFlowRole, error) {
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, workflowID, version)
if !isExist {
logs.CtxWarnf(ctx, "chat flow role not exist, workflow id %v, version %v", workflowID, version)
// Return (nil, nil) on 'NotExist' to align with the production behavior,
// where the GET API may be called before the CREATE API during chatflow creation.
return nil, nil
}
if err != nil {
return nil, err
}
return role, nil
}
func (i *impl) GetWorkflowVersionsByConnector(ctx context.Context, connectorID, workflowID int64, limit int) ([]string, error) {
return i.repo.GetVersionListByConnectorAndWorkflowID(ctx, connectorID, workflowID, limit)
}
func (i *impl) DeleteChatFlowRole(ctx context.Context, id int64, workflowID int64) error {
return i.repo.DeleteChatFlowRoleConfig(ctx, id, workflowID)
}
func (i *impl) PublishChatFlowRole(ctx context.Context, policy *vo.PublishRolePolicy) error {
if policy.WorkflowID == 0 || policy.CreatorID == 0 || policy.Version == "" {
logs.CtxErrorf(ctx, "invalid publish role policy, workflow id %v, creator id %v should not be zero, version %v should not be empty", policy.WorkflowID, policy.CreatorID, policy.Version)
return vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("invalid publish role policy, workflow id %v, creator id %v should not be zero, version %v should not be empty", policy.WorkflowID, policy.CreatorID, policy.Version))
}
wf, err := i.repo.GetEntity(ctx, &vo.GetPolicy{
ID: policy.WorkflowID,
MetaOnly: true,
})
if err != nil {
return err
}
if wf.Mode != cloudworkflow.WorkflowMode_ChatFlow {
return vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("workflow id %v, mode %v is not a chatflow", policy.WorkflowID, wf.Mode))
}
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, policy.WorkflowID, "")
if !isExist {
logs.CtxErrorf(ctx, "get draft chat flow role nil, workflow id %v", policy.WorkflowID)
return vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("get draft chat flow role nil, workflow id %v", policy.WorkflowID))
}
if err != nil {
return vo.WrapIfNeeded(errno.ErrChatFlowRoleOperationFail, err)
}
_, err = i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
Name: role.Name,
Description: role.Description,
WorkflowID: policy.WorkflowID,
CreatorID: policy.CreatorID,
AudioConfig: role.AudioConfig,
UserInputConfig: role.UserInputConfig,
AvatarUri: role.AvatarUri,
BackgroundImageInfo: role.BackgroundImageInfo,
OnboardingInfo: role.OnboardingInfo,
SuggestReplyInfo: role.SuggestReplyInfo,
Version: policy.Version,
})
if err != nil {
return err
}
return nil
}
func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) {
var canvas vo.Canvas
if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil {
@@ -659,6 +768,13 @@ func (i *impl) UpdateMeta(ctx context.Context, id int64, metaUpdate *vo.MetaUpda
return err
}
if metaUpdate.WorkflowMode != nil && *metaUpdate.WorkflowMode == cloudworkflow.WorkflowMode_ChatFlow {
err = i.adaptToChatFlow(ctx, id)
if err != nil {
return err
}
}
return nil
}
@@ -667,6 +783,35 @@ func (i *impl) CopyWorkflow(ctx context.Context, workflowID int64, policy vo.Cop
if err != nil {
return nil, err
}
// chat flow should copy role config
if wf.Mode == cloudworkflow.WorkflowMode_ChatFlow {
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, workflowID, "")
if !isExist {
logs.CtxErrorf(ctx, "get draft chat flow role nil, workflow id %v", workflowID)
return nil, vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("get draft chat flow role nil, workflow id %v", workflowID))
}
if err != nil {
return nil, vo.WrapIfNeeded(errno.ErrChatFlowRoleOperationFail, err)
}
_, err = i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
Name: role.Name,
Description: role.Description,
WorkflowID: wf.ID,
CreatorID: wf.CreatorID,
AudioConfig: role.AudioConfig,
UserInputConfig: role.UserInputConfig,
AvatarUri: role.AvatarUri,
BackgroundImageInfo: role.BackgroundImageInfo,
OnboardingInfo: role.OnboardingInfo,
SuggestReplyInfo: role.SuggestReplyInfo,
})
if err != nil {
return nil, err
}
}
return wf, nil
@@ -677,7 +822,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
return nil, fmt.Errorf("connector ids is required")
}
wfs, _, err := i.MGet(ctx, &vo.MGetPolicy{
allWorkflowsInApp, _, err := i.MGet(ctx, &vo.MGetPolicy{
MetaQuery: vo.MetaQuery{
AppID: &appID,
},
@@ -688,14 +833,15 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
}
relatedPlugins := make(map[int64]*plugin.PluginEntity, len(config.PluginIDs))
relatedWorkflow := make(map[int64]entity.IDVersionPair, len(wfs))
relatedWorkflow := make(map[int64]entity.IDVersionPair, len(allWorkflowsInApp))
for _, wf := range wfs {
for _, wf := range allWorkflowsInApp {
relatedWorkflow[wf.ID] = entity.IDVersionPair{
ID: wf.ID,
Version: config.Version,
}
}
for _, id := range config.PluginIDs {
relatedPlugins[id] = &plugin.PluginEntity{
PluginID: id,
@@ -704,7 +850,22 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
}
vIssues := make([]*vo.ValidateIssue, 0)
for _, wf := range wfs {
willPublishWorkflows := make([]*entity.Workflow, 0)
if len(config.WorkflowIDs) == 0 {
willPublishWorkflows = allWorkflowsInApp
} else {
willPublishWorkflows, _, err = i.MGet(ctx, &vo.MGetPolicy{
MetaQuery: vo.MetaQuery{
AppID: &appID,
IDs: config.WorkflowIDs,
},
QType: workflowModel.FromDraft,
})
}
for _, wf := range willPublishWorkflows {
issues, err := validateWorkflowTree(ctx, vo.ValidateTreeConfig{
CanvasSchema: wf.Canvas,
AppID: ptr.Of(appID),
@@ -723,7 +884,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
return vIssues, nil
}
for _, wf := range wfs {
for _, wf := range willPublishWorkflows {
c := &vo.Canvas{}
err := sonic.UnmarshalString(wf.Canvas, c)
if err != nil {
@@ -747,9 +908,8 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
}
userID := ctxutil.MustGetUIDFromCtx(ctx)
workflowsToPublish := make(map[int64]*vo.VersionInfo)
for _, wf := range wfs {
for _, wf := range willPublishWorkflows {
inputStr, err := sonic.MarshalString(wf.InputParams)
if err != nil {
return nil, err
@@ -774,8 +934,16 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
}
}
workflowIDs := make([]int64, 0, len(wfs))
workflowIDs := make([]int64, 0, len(willPublishWorkflows))
for id, vInfo := range workflowsToPublish {
// if version existed skip
_, existed, err := i.repo.GetVersion(ctx, id, config.Version)
if err != nil {
return nil, err
}
if existed {
continue
}
wfRefs, err := canvasToRefs(id, vInfo.Canvas)
if err != nil {
return nil, err
@@ -787,6 +955,24 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
}
}
err = i.ReleaseConversationTemplate(ctx, appID, config.Version)
if err != nil {
return nil, err
}
for _, wf := range willPublishWorkflows {
if wf.Mode == cloudworkflow.WorkflowMode_ChatFlow {
err = i.PublishChatFlowRole(ctx, &vo.PublishRolePolicy{
WorkflowID: wf.ID,
CreatorID: wf.CreatorID,
Version: config.Version,
})
if err != nil {
return nil, err
}
}
}
for _, connectorID := range config.ConnectorIDs {
err = i.repo.BatchCreateConnectorWorkflowVersion(ctx, appID, connectorID, workflowIDs, config.Version)
if err != nil {
@@ -889,7 +1075,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
}
if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var (
v *vo.DraftInfo
@@ -1012,7 +1198,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
return err
}
cwf, err := i.repo.CopyWorkflow(ctx, wf.id, vo.CopyWorkflowPolicy{
cwf, err := i.CopyWorkflow(ctx, wf.id, vo.CopyWorkflowPolicy{
TargetAppID: ptr.Of(int64(0)),
ModifiedCanvasSchema: ptr.Of(modifiedCanvasString),
})
@@ -1120,7 +1306,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
}
if node.Type == entity.NodeTypeLLM.IDStr() {
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
var (
v *vo.DraftInfo
@@ -1246,6 +1432,11 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
}
}
err = i.repo.CopyTemplateConversationByAppID(ctx, sourceAppID, targetAppID)
if err != nil {
return nil, err
}
return copiedWorkflowArray, nil
}
@@ -1368,7 +1559,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
ds.DatabaseIDs = append(ds.DatabaseIDs, dsID)
}
case entity.NodeTypeLLM:
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
if node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx].IsDraft {
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
@@ -1382,7 +1573,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
}
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.KnowledgeFCParam != nil {
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.KnowledgeFCParam != nil {
for idx := range node.Data.Inputs.FCParam.KnowledgeFCParam.KnowledgeList {
kn := node.Data.Inputs.FCParam.KnowledgeFCParam.KnowledgeList[idx]
kid, err := strconv.ParseInt(kn.ID, 10, 64)
@@ -1394,7 +1585,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
}
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
if node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx].IsDraft {
wID, err := strconv.ParseInt(node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx].WorkflowID, 10, 64)
@@ -1467,6 +1658,165 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
}
func (i *impl) checkBotAgentNode(node *vo.Node) error {
if node.Type == entity.NodeTypeCreateConversation.IDStr() || node.Type == entity.NodeTypeConversationDelete.IDStr() || node.Type == entity.NodeTypeConversationUpdate.IDStr() || node.Type == entity.NodeTypeConversationList.IDStr() {
return errors.New("conversation-related nodes are not supported in chatflow")
}
return nil
}
func (i *impl) validateNodesRecursively(ctx context.Context, nodes []*vo.Node, checkType cloudworkflow.CheckType, visited map[string]struct{}, repo workflow.Repository) error {
queue := make([]*vo.Node, 0, len(nodes))
queue = append(queue, nodes...)
for len(queue) > 0 {
node := queue[0]
queue = queue[1:]
if node == nil {
continue
}
var checkErr error
switch checkType {
case cloudworkflow.CheckType_BotAgent:
checkErr = i.checkBotAgentNode(node)
default:
// For now, we only handle BotAgent check, so we can do nothing here.
// In the future, if there are other check types that need to be validated on every node, this logic will need to be adjusted.
}
if checkErr != nil {
return checkErr
}
// Enqueue nested nodes for BFS traversal. This handles Loop, Batch, and other nodes with nested blocks.
if len(node.Blocks) > 0 {
queue = append(queue, node.Blocks...)
}
if node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil {
workflowIDStr := node.Data.Inputs.WorkflowID
if workflowIDStr == "" {
continue
}
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
if err != nil {
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
}
version := node.Data.Inputs.WorkflowVersion
qType := workflowModel.FromDraft
if version != "" {
qType = workflowModel.FromSpecificVersion
}
visitedKey := fmt.Sprintf("%d:%s", workflowID, version)
if _, ok := visited[visitedKey]; ok {
continue
}
visited[visitedKey] = struct{}{}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: qType,
Version: version,
})
if err != nil {
delete(visited, visitedKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
}
var canvas vo.Canvas
if err := sonic.UnmarshalString(subWfEntity.Canvas, &canvas); err != nil {
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", subWfEntity.ID, err)
}
queue = append(queue, canvas.Nodes...)
}
if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
for _, subWfInfo := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
if subWfInfo.WorkflowID == "" {
continue
}
workflowID, err := strconv.ParseInt(subWfInfo.WorkflowID, 10, 64)
if err != nil {
return fmt.Errorf("invalid workflow ID in large model node %s: %w", node.ID, err)
}
version := subWfInfo.WorkflowVersion
qType := workflowModel.FromDraft
if version != "" {
qType = workflowModel.FromSpecificVersion
}
visitedKey := fmt.Sprintf("%d:%s", workflowID, version)
if _, ok := visited[visitedKey]; ok {
continue
}
visited[visitedKey] = struct{}{}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: qType,
Version: version,
})
if err != nil {
delete(visited, visitedKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return fmt.Errorf("failed to get sub-workflow entity %d from large model node: %w", workflowID, err)
}
var canvas vo.Canvas
if err := sonic.UnmarshalString(subWfEntity.Canvas, &canvas); err != nil {
return fmt.Errorf("failed to unmarshal canvas for workflow %d from large model node: %w", subWfEntity.ID, err)
}
queue = append(queue, canvas.Nodes...)
}
}
}
return nil
}
func (i *impl) WorkflowSchemaCheck(ctx context.Context, wf *entity.Workflow, checks []cloudworkflow.CheckType) ([]*cloudworkflow.CheckResult, error) {
checkResults := make([]*cloudworkflow.CheckResult, 0, len(checks))
var canvas vo.Canvas
if err := sonic.UnmarshalString(wf.Canvas, &canvas); err != nil {
return nil, fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wf.ID, err)
}
for _, checkType := range checks {
visited := make(map[string]struct{})
visitedKey := fmt.Sprintf("%d:%s", wf.ID, wf.GetVersion())
visited[visitedKey] = struct{}{}
err := i.validateNodesRecursively(ctx, canvas.Nodes, checkType, visited, i.repo)
if err != nil {
checkResults = append(checkResults, &cloudworkflow.CheckResult{
IsPass: false,
Reason: err.Error(),
Type: checkType,
})
} else {
checkResults = append(checkResults, &cloudworkflow.CheckResult{
IsPass: true,
Type: checkType,
Reason: "",
})
}
}
return checkResults, nil
}
func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workflow, int64, error) {
if policy.MetaOnly {
metas, total, err := i.repo.MGetMetas(ctx, &policy.MetaQuery)
@@ -1527,11 +1877,13 @@ func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workf
index := 0
for id, version := range policy.Versions {
v, err := i.repo.GetVersion(ctx, id, version)
v, existed, err := i.repo.GetVersion(ctx, id, version)
if err != nil {
return nil, total, err
}
if !existed {
return nil, total, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", version, id, err), errorx.KV("id", strconv.FormatInt(id, 10)))
}
inputs, outputs, err := ioF(v.InputParamsStr, v.OutputParamsStr)
if err != nil {
return nil, total, err
@@ -1562,6 +1914,14 @@ func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workf
}
}
func (i *impl) BindConvRelatedInfo(ctx context.Context, convID int64, info entity.ConvRelatedInfo) error {
return i.repo.BindConvRelatedInfo(ctx, convID, info)
}
func (i *impl) GetConvRelatedInfo(ctx context.Context, convID int64) (*entity.ConvRelatedInfo, bool, func() error, error) {
return i.repo.GetConvRelatedInfo(ctx, convID)
}
func (i *impl) calculateTestRunSuccess(ctx context.Context, c *vo.Canvas, wid int64) (bool, error) {
sc, err := adaptor.CanvasToWorkflowSchema(ctx, c)
if err != nil { // not even legal, test run can't possibly be successful
@@ -1766,3 +2126,58 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
func RegisterAllNodeAdaptors() {
adaptor.RegisterAllNodeAdaptors()
}
func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
wfEntity, err := i.repo.GetEntity(ctx, &vo.GetPolicy{
ID: wID,
QType: workflowModel.FromDraft,
})
if err != nil {
return err
}
canvas := &vo.Canvas{}
err = sonic.UnmarshalString(wfEntity.Canvas, canvas)
if err != nil {
return err
}
var startNode *vo.Node
for _, node := range canvas.Nodes {
if node.Type == entity.NodeTypeEntry.IDStr() {
startNode = node
break
}
}
if startNode == nil {
return fmt.Errorf("can not find start node")
}
vMap := make(map[string]bool)
for _, o := range startNode.Data.Outputs {
v, err := vo.ParseVariable(o)
if err != nil {
return err
}
vMap[v.Name] = true
}
if _, ok := vMap["USER_INPUT"]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
Name: "USER_INPUT",
Type: vo.VariableTypeString,
})
}
if _, ok := vMap["CONVERSATION_NAME"]; !ok {
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
Name: "CONVERSATION_NAME",
Type: vo.VariableTypeString,
DefaultValue: "Default",
})
}
canvasStr, err := sonic.MarshalString(canvas)
if err != nil {
return err
}
return i.Save(ctx, wID, canvasStr)
}

View File

@@ -22,11 +22,15 @@ import (
"strconv"
"strings"
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/workflow"
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
@@ -102,17 +106,17 @@ func validateWorkflowTree(ctx context.Context, config vo.ValidateTreeConfig) ([]
return issues, nil
}
func convertToValidationError(issue *validate.Issue) *cloudworkflow.ValidateErrorData {
e := &cloudworkflow.ValidateErrorData{}
func convertToValidationError(issue *validate.Issue) *workflow.ValidateErrorData {
e := &workflow.ValidateErrorData{}
e.Message = issue.Message
if issue.NodeErr != nil {
e.Type = cloudworkflow.ValidateErrorType_BotValidateNodeErr
e.NodeError = &cloudworkflow.NodeError{
e.Type = workflow.ValidateErrorType_BotValidateNodeErr
e.NodeError = &workflow.NodeError{
NodeID: issue.NodeErr.NodeID,
}
} else if issue.PathErr != nil {
e.Type = cloudworkflow.ValidateErrorType_BotValidatePathErr
e.PathError = &cloudworkflow.PathError{
e.Type = workflow.ValidateErrorType_BotValidatePathErr
e.PathError = &workflow.PathError{
Start: issue.PathErr.StartNode,
End: issue.PathErr.EndNode,
}
@@ -121,8 +125,8 @@ func convertToValidationError(issue *validate.Issue) *cloudworkflow.ValidateErro
return e
}
func toValidateErrorData(issues []*validate.Issue) []*cloudworkflow.ValidateErrorData {
validateErrors := make([]*cloudworkflow.ValidateErrorData, 0, len(issues))
func toValidateErrorData(issues []*validate.Issue) []*workflow.ValidateErrorData {
validateErrors := make([]*workflow.ValidateErrorData, 0, len(issues))
for _, issue := range issues {
validateErrors = append(validateErrors, convertToValidationError(issue))
}
@@ -197,3 +201,214 @@ func isIncremental(prev version, next version) bool {
return next.Patch > prev.Patch
}
func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
visited := make(map[string]struct{})
maxRounds := int64(0)
err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
return maxRounds, err
}
func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
if _, ok := visited[visitedKey]; ok {
return nil
}
visited[visitedKey] = struct{}{}
var canvas vo.Canvas
if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
}
return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
}
func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
for _, node := range nodes {
if node == nil {
continue
}
if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
*maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
}
} else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
param := node.Data.Inputs.LLMParam
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return err
}
var chatHistoryEnabled bool
var chatHistoryRound int64
for _, param := range llmParam {
switch param.Name {
case "enableChatHistory":
if val, ok := param.Input.Value.Content.(bool); ok {
b := val
chatHistoryEnabled = b
}
case "chatHistoryRound":
if strVal, ok := param.Input.Value.Content.(string); ok {
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return err
}
chatHistoryRound = int64Val
}
}
}
if chatHistoryEnabled {
if chatHistoryRound > *maxRounds {
*maxRounds = chatHistoryRound
}
}
}
isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
if isSubWorkflow {
workflowIDStr := node.Data.Inputs.WorkflowID
if workflowIDStr == "" {
continue
}
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
if err != nil {
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
Version: node.Data.Inputs.WorkflowVersion,
})
if err != nil {
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
}
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
return err
}
}
if len(node.Blocks) > 0 {
if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
return err
}
}
}
return nil
}
func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
if wfEntity == nil {
return 0, nil
}
visited := make(map[string]struct{})
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
if _, ok := visited[visitedKey]; ok {
return 0, nil
}
visited[visitedKey] = struct{}{}
maxRounds := int64(0)
c := &vo.Canvas{}
if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
}
var (
n *vo.Node
nodeFinder func(nodes []*vo.Node) *vo.Node
)
nodeFinder = func(nodes []*vo.Node) *vo.Node {
for i := range nodes {
if nodes[i].ID == nodeID {
return nodes[i]
}
if len(nodes[i].Blocks) > 0 {
if n := nodeFinder(nodes[i].Blocks); n != nil {
return n
}
}
}
return nil
}
n = nodeFinder(c.Nodes)
if n.Type == entity.NodeTypeLLM.IDStr() {
if n.Data == nil || n.Data.Inputs == nil {
return 0, nil
}
param := n.Data.Inputs.LLMParam
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return 0, err
}
var chatHistoryEnabled bool
var chatHistoryRound int64
for _, param := range llmParam {
switch param.Name {
case "enableChatHistory":
if val, ok := param.Input.Value.Content.(bool); ok {
b := val
chatHistoryEnabled = b
}
case "chatHistoryRound":
if strVal, ok := param.Input.Value.Content.(string); ok {
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return 0, err
}
chatHistoryRound = int64Val
}
}
}
if chatHistoryEnabled {
return chatHistoryRound, nil
}
return 0, nil
}
if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
}
return 0, nil
}
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
if n.Data != nil && n.Data.Inputs != nil {
workflowIDStr := n.Data.Inputs.WorkflowID
if workflowIDStr == "" {
return 0, nil
}
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
}
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
ID: workflowID,
QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
Version: n.Data.Inputs.WorkflowVersion,
})
if err != nil {
return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
}
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
return 0, err
}
return maxRounds, nil
}
}
if len(n.Blocks) > 0 {
if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
return 0, err
}
}
return maxRounds, nil
}