feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)
Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com> Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com> Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com> Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com> Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com> Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com> Co-authored-by: July <jiangxujin@bytedance.com> Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
490
backend/domain/workflow/service/conversation_impl.go
Normal file
490
backend/domain/workflow/service/conversation_impl.go
Normal file
@@ -0,0 +1,490 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
conventity "github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/taskgroup"
|
||||
|
||||
workflow2 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
crossconversation "github.com/coze-dev/coze-studio/backend/crossdomain/contract/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type conversationImpl struct {
|
||||
repo workflow.Repository
|
||||
}
|
||||
|
||||
func (c *conversationImpl) CreateDraftConversationTemplate(ctx context.Context, template *vo.CreateConversationTemplateMeta) (int64, error) {
|
||||
var (
|
||||
spaceID = template.SpaceID
|
||||
appID = template.AppID
|
||||
name = template.Name
|
||||
userID = template.UserID
|
||||
)
|
||||
|
||||
existed, err := c.IsDraftConversationNameExist(ctx, appID, userID, template.Name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if existed {
|
||||
return 0, vo.WrapError(errno.ErrConversationNameIsDuplicated, fmt.Errorf("conversation name %s exists", name), errorx.KV("name", name))
|
||||
}
|
||||
|
||||
return c.repo.CreateDraftConversationTemplate(ctx, &vo.CreateConversationTemplateMeta{
|
||||
SpaceID: spaceID,
|
||||
AppID: appID,
|
||||
Name: name,
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) IsDraftConversationNameExist(ctx context.Context, appID int64, userID int64, name string) (bool, error) {
|
||||
_, existed, err := c.repo.GetDynamicConversationByName(ctx, vo.Draft, appID, consts.CozeConnectorID, userID, name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if existed {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
_, existed, err = c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{AppID: ptr.Of(appID), Name: ptr.Of(name)})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if existed {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) UpdateDraftConversationTemplateName(ctx context.Context, appID int64, userID int64, templateID int64, modifiedName string) error {
|
||||
template, existed, err := c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{TemplateID: ptr.Of(templateID)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existed && template.Name == modifiedName {
|
||||
return nil
|
||||
}
|
||||
|
||||
existed, err = c.IsDraftConversationNameExist(ctx, appID, userID, modifiedName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existed {
|
||||
return vo.WrapError(errno.ErrConversationNameIsDuplicated, fmt.Errorf("conversation name %s exists", modifiedName), errorx.KV("name", modifiedName))
|
||||
}
|
||||
|
||||
wfs, err := c.findReplaceWorkflowByConversationName(ctx, appID, template.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.replaceWorkflowsConversationName(ctx, wfs, slices.ToMap(wfs, func(e *entity.Workflow) (int64, string) {
|
||||
return e.ID, modifiedName
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.repo.UpdateDraftConversationTemplateName(ctx, templateID, modifiedName)
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) CheckWorkflowsToReplace(ctx context.Context, appID int64, templateID int64) ([]*entity.Workflow, error) {
|
||||
template, existed, err := c.repo.GetConversationTemplate(ctx, vo.Draft, vo.GetConversationTemplatePolicy{TemplateID: ptr.Of(templateID)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if existed {
|
||||
return c.findReplaceWorkflowByConversationName(ctx, appID, template.Name)
|
||||
}
|
||||
|
||||
return []*entity.Workflow{}, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) DeleteDraftConversationTemplate(ctx context.Context, templateID int64, wfID2ConversationName map[int64]string) (int64, error) {
|
||||
|
||||
if len(wfID2ConversationName) == 0 {
|
||||
return c.repo.DeleteDraftConversationTemplate(ctx, templateID)
|
||||
}
|
||||
workflowIDs := make([]int64, 0)
|
||||
for id := range wfID2ConversationName {
|
||||
workflowIDs = append(workflowIDs, id)
|
||||
}
|
||||
|
||||
wfs, _, err := c.repo.MGetDrafts(ctx, &vo.MGetPolicy{
|
||||
MetaQuery: vo.MetaQuery{
|
||||
IDs: workflowIDs,
|
||||
},
|
||||
QType: workflowModel.FromDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = c.replaceWorkflowsConversationName(ctx, wfs, wfID2ConversationName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.repo.DeleteDraftConversationTemplate(ctx, templateID)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) ListConversationTemplate(ctx context.Context, env vo.Env, policy *vo.ListConversationTemplatePolicy) ([]*entity.ConversationTemplate, error) {
|
||||
var (
|
||||
err error
|
||||
templates []*entity.ConversationTemplate
|
||||
appID = policy.AppID
|
||||
)
|
||||
templates, err = c.repo.ListConversationTemplate(ctx, env, &vo.ListConversationTemplatePolicy{
|
||||
AppID: appID,
|
||||
Page: policy.Page,
|
||||
NameLike: policy.NameLike,
|
||||
Version: policy.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return templates, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) MGetStaticConversation(ctx context.Context, env vo.Env, userID, connectorID int64, templateIDs []int64) ([]*entity.StaticConversation, error) {
|
||||
return c.repo.MGetStaticConversation(ctx, env, userID, connectorID, templateIDs)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) ListDynamicConversation(ctx context.Context, env vo.Env, policy *vo.ListConversationPolicy) ([]*entity.DynamicConversation, error) {
|
||||
return c.repo.ListDynamicConversation(ctx, env, policy)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) ReleaseConversationTemplate(ctx context.Context, appID int64, version string) error {
|
||||
templates, err := c.repo.ListConversationTemplate(ctx, vo.Draft, &vo.ListConversationTemplatePolicy{
|
||||
AppID: appID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(templates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.repo.BatchCreateOnlineConversationTemplate(ctx, templates, version)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) InitApplicationDefaultConversationTemplate(ctx context.Context, spaceID, appID int64, userID int64) error {
|
||||
_, err := c.repo.CreateDraftConversationTemplate(ctx, &vo.CreateConversationTemplateMeta{
|
||||
AppID: appID,
|
||||
SpaceID: spaceID,
|
||||
UserID: userID,
|
||||
Name: "Default",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) findReplaceWorkflowByConversationName(ctx context.Context, appID int64, name string) ([]*entity.Workflow, error) {
|
||||
|
||||
wfs, _, err := c.repo.MGetDrafts(ctx, &vo.MGetPolicy{
|
||||
QType: workflowModel.FromDraft,
|
||||
MetaQuery: vo.MetaQuery{
|
||||
AppID: ptr.Of(appID),
|
||||
Mode: ptr.Of(workflow2.WorkflowMode_ChatFlow),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shouldReplacedWorkflow := func(nodes []*vo.Node) (bool, error) {
|
||||
var startNode *vo.Node
|
||||
for _, node := range nodes {
|
||||
if node.Type == entity.NodeTypeEntry.IDStr() {
|
||||
startNode = node
|
||||
}
|
||||
}
|
||||
if startNode == nil {
|
||||
return false, fmt.Errorf("start node not found for block type")
|
||||
}
|
||||
for _, vAny := range startNode.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v.Name == "CONVERSATION_NAME" && v.DefaultValue == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
|
||||
}
|
||||
|
||||
shouldReplacedWorkflows := make([]*entity.Workflow, 0)
|
||||
for idx := range wfs {
|
||||
wf := wfs[idx]
|
||||
canvas := &vo.Canvas{}
|
||||
err = sonic.UnmarshalString(wf.Canvas, canvas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := shouldReplacedWorkflow(canvas.Nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
shouldReplacedWorkflows = append(shouldReplacedWorkflows, wf)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return shouldReplacedWorkflows, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) replaceWorkflowsConversationName(ctx context.Context, wfs []*entity.Workflow, workflowID2ConversionName map[int64]string) error {
|
||||
|
||||
replaceConversionName := func(nodes []*vo.Node, conversionName string) error {
|
||||
var startNode *vo.Node
|
||||
for _, node := range nodes {
|
||||
if node.Type == entity.NodeTypeEntry.IDStr() {
|
||||
startNode = node
|
||||
}
|
||||
}
|
||||
if startNode == nil {
|
||||
return fmt.Errorf("start node not found for block type")
|
||||
}
|
||||
for idx, vAny := range startNode.Data.Outputs {
|
||||
v, err := vo.ParseVariable(vAny)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Name == "CONVERSATION_NAME" {
|
||||
v.DefaultValue = conversionName
|
||||
}
|
||||
startNode.Data.Outputs[idx] = v
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
tg := taskgroup.NewTaskGroup(ctx, len(wfs))
|
||||
for _, wf := range wfs {
|
||||
wfEntity := wf
|
||||
tg.Go(func() error {
|
||||
canvas := &vo.Canvas{}
|
||||
err := sonic.UnmarshalString(wfEntity.Canvas, canvas)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conversationName := workflowID2ConversionName[wfEntity.ID]
|
||||
err = replaceConversionName(canvas.Nodes, conversationName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replaceCanvas, err := sonic.MarshalString(canvas)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.repo.CreateOrUpdateDraft(ctx, wfEntity.ID, &vo.DraftInfo{
|
||||
DraftMeta: &vo.DraftMeta{
|
||||
TestRunSuccess: false,
|
||||
Modified: true,
|
||||
},
|
||||
Canvas: replaceCanvas,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
})
|
||||
}
|
||||
err := tg.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) DeleteDynamicConversation(ctx context.Context, env vo.Env, templateID int64) (int64, error) {
|
||||
return c.repo.DeleteDynamicConversation(ctx, env, templateID)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetOrCreateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, int64, error) {
|
||||
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
Name: ptr.Of(conversationName),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
conversationIDGenerator := workflow.ConversationIDGenerator(func(ctx context.Context, appID int64, userID, connectorID int64) (*conventity.Conversation, error) {
|
||||
return crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
|
||||
AgentID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
})
|
||||
})
|
||||
|
||||
if existed {
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateStaticConversation(ctx, env, conversationIDGenerator, &vo.CreateStaticConversation{
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
TemplateID: t.TemplateID,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return conversationID, sectionID, nil
|
||||
}
|
||||
|
||||
conversationID, sectionID, _, err := c.repo.GetOrCreateDynamicConversation(ctx, env, conversationIDGenerator, &vo.CreateDynamicConversation{
|
||||
AppID: appID,
|
||||
ConnectorID: connectorID,
|
||||
UserID: userID,
|
||||
Name: conversationName,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return conversationID, sectionID, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) UpdateConversation(ctx context.Context, env vo.Env, appID, connectorID, userID int64, conversationName string) (int64, error) {
|
||||
t, existed, err := c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
Name: ptr.Of(conversationName),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if existed {
|
||||
conv, err := crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
|
||||
AgentID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if conv == nil {
|
||||
return 0, fmt.Errorf("create conversation failed")
|
||||
}
|
||||
err = c.repo.UpdateStaticConversation(ctx, env, t.TemplateID, connectorID, userID, conv.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return conv.ID, nil
|
||||
}
|
||||
|
||||
dy, existed, err := c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, conversationName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if !existed {
|
||||
return 0, fmt.Errorf("conversation name %v not found", conversationName)
|
||||
}
|
||||
|
||||
conv, err := crossconversation.DefaultSVC().CreateConversation(ctx, &conventity.CreateMeta{
|
||||
AgentID: appID,
|
||||
UserID: userID,
|
||||
ConnectorID: connectorID,
|
||||
Scene: common.Scene_SceneWorkflow,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if conv == nil {
|
||||
return 0, fmt.Errorf("create conversation failed")
|
||||
}
|
||||
err = c.repo.UpdateDynamicConversation(ctx, env, dy.ConversationID, conv.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return conv.ID, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetTemplateByName(ctx context.Context, env vo.Env, appID int64, templateName string) (*entity.ConversationTemplate, bool, error) {
|
||||
return c.repo.GetConversationTemplate(ctx, env, vo.GetConversationTemplatePolicy{
|
||||
AppID: ptr.Of(appID),
|
||||
Name: ptr.Of(templateName),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetDynamicConversationByName(ctx context.Context, env vo.Env, appID, connectorID, userID int64, name string) (*entity.DynamicConversation, bool, error) {
|
||||
return c.repo.GetDynamicConversationByName(ctx, env, appID, connectorID, userID, name)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetConversationNameByID(ctx context.Context, env vo.Env, appID, connectorID, conversationID int64) (string, bool, error) {
|
||||
sc, existed, err := c.repo.GetStaticConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if existed {
|
||||
return sc, true, nil
|
||||
}
|
||||
|
||||
dc, existed, err := c.repo.GetDynamicConversationByID(ctx, env, appID, connectorID, conversationID)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if existed {
|
||||
return dc.Name, true, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) Suggest(ctx context.Context, input *vo.SuggestInfo) ([]string, error) {
|
||||
return c.repo.Suggest(ctx, input)
|
||||
}
|
||||
@@ -26,6 +26,8 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
workflowapimodel "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
@@ -62,6 +64,8 @@ func (i *impl) SyncExecute(ctx context.Context, config workflowModel.ExecuteConf
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
isApplicationWorkflow := wfEntity.AppID != nil
|
||||
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
|
||||
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
|
||||
@@ -207,6 +211,8 @@ func (i *impl) AsyncExecute(ctx context.Context, config workflowModel.ExecuteCon
|
||||
return 0, err
|
||||
}
|
||||
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
isApplicationWorkflow := wfEntity.AppID != nil
|
||||
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
|
||||
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
|
||||
@@ -292,6 +298,8 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
return 0, err
|
||||
}
|
||||
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
isApplicationWorkflow := wfEntity.AppID != nil
|
||||
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
|
||||
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
|
||||
@@ -300,6 +308,30 @@ func (i *impl) AsyncExecuteNode(ctx context.Context, nodeID string, config workf
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
|
||||
historyRounds, err = getHistoryRoundsFromNode(ctx, wfEntity, nodeID, i.repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@@ -375,6 +407,8 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
isApplicationWorkflow := wfEntity.AppID != nil
|
||||
if isApplicationWorkflow && config.Mode == workflowModel.ExecuteModeRelease {
|
||||
err = i.checkApplicationWorkflowReleaseVersion(ctx, *wfEntity.AppID, config.ConnectorID, config.ID, config.Version)
|
||||
@@ -383,6 +417,29 @@ func (i *impl) StreamExecute(ctx context.Context, config workflowModel.ExecuteCo
|
||||
}
|
||||
}
|
||||
|
||||
historyRounds := int64(0)
|
||||
if config.WorkflowMode == workflowapimodel.WorkflowMode_ChatFlow {
|
||||
historyRounds, err = i.calculateMaxChatHistoryRounds(ctx, wfEntity, i.repo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if historyRounds > 0 {
|
||||
messages, scMessages, err := i.prefetchChatHistory(ctx, config, historyRounds)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to prefetch chat history: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
config.ConversationHistory = messages
|
||||
}
|
||||
|
||||
if len(scMessages) > 0 {
|
||||
config.ConversationHistorySchemaMessages = scMessages
|
||||
}
|
||||
|
||||
}
|
||||
c := &vo.Canvas{}
|
||||
if err = sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
@@ -718,6 +775,7 @@ func (i *impl) AsyncResume(ctx context.Context, req *entity.ResumeRequest, confi
|
||||
config.AppID = wfExe.AppID
|
||||
config.AgentID = wfExe.AgentID
|
||||
config.CommitID = wfExe.CommitID
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
if config.ConnectorID == 0 {
|
||||
config.ConnectorID = wfExe.ConnectorID
|
||||
@@ -859,6 +917,7 @@ func (i *impl) StreamResume(ctx context.Context, req *entity.ResumeRequest, conf
|
||||
config.AppID = wfExe.AppID
|
||||
config.AgentID = wfExe.AgentID
|
||||
config.CommitID = wfExe.CommitID
|
||||
config.WorkflowMode = wfEntity.Mode
|
||||
|
||||
if config.ConnectorID == 0 {
|
||||
config.ConnectorID = wfExe.ConnectorID
|
||||
@@ -937,3 +996,73 @@ func (i *impl) checkApplicationWorkflowReleaseVersion(ctx context.Context, appID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxHistoryRounds int64 = 30
|
||||
|
||||
func (i *impl) calculateMaxChatHistoryRounds(ctx context.Context, wfEntity *entity.Workflow, repo workflow.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
maxRounds, err := getMaxHistoryRoundsRecursively(ctx, wfEntity, repo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return min(maxRounds, maxHistoryRounds), nil
|
||||
}
|
||||
|
||||
func (i *impl) prefetchChatHistory(ctx context.Context, config workflowModel.ExecuteConfig, historyRounds int64) ([]*crossmessage.WfMessage, []*schema.Message, error) {
|
||||
convID := config.ConversationID
|
||||
agentID := config.AgentID
|
||||
appID := config.AppID
|
||||
userID := config.Operator
|
||||
sectionID := config.SectionID
|
||||
if sectionID == nil {
|
||||
logs.CtxWarnf(ctx, "SectionID is nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if convID == nil || *convID == 0 {
|
||||
logs.CtxWarnf(ctx, "ConversationID is 0 or nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
var resolvedAppID int64
|
||||
if appID != nil {
|
||||
resolvedAppID = *appID
|
||||
} else if agentID != nil {
|
||||
resolvedAppID = *agentID
|
||||
} else {
|
||||
logs.CtxWarnf(ctx, "AppID and AgentID are both nil, skipping chat history")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
runIdsReq := &crossmessage.GetLatestRunIDsRequest{
|
||||
ConversationID: *convID,
|
||||
AppID: resolvedAppID,
|
||||
UserID: userID,
|
||||
Rounds: historyRounds + 1,
|
||||
SectionID: *sectionID,
|
||||
}
|
||||
|
||||
runIds, err := crossmessage.DefaultSVC().GetLatestRunIDs(ctx, runIdsReq)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get latest run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
}
|
||||
if len(runIds) <= 1 {
|
||||
return []*crossmessage.WfMessage{}, []*schema.Message{}, nil
|
||||
}
|
||||
runIds = runIds[1:]
|
||||
|
||||
response, err := crossmessage.DefaultSVC().GetMessagesByRunIDs(ctx, &crossmessage.GetMessagesByRunIDsRequest{
|
||||
ConversationID: *convID,
|
||||
RunIDs: runIds,
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "failed to get messages by run ids: %v", err)
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
return response.Messages, response.SchemaMessages, nil
|
||||
}
|
||||
|
||||
@@ -20,13 +20,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"strconv"
|
||||
|
||||
einoCompose "github.com/cloudwego/eino/compose"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
@@ -38,6 +39,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/intentdetector"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/llm"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
||||
@@ -45,6 +47,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
@@ -56,6 +59,7 @@ type impl struct {
|
||||
repo workflow.Repository
|
||||
*asToolImpl
|
||||
*executableImpl
|
||||
*conversationImpl
|
||||
}
|
||||
|
||||
func NewWorkflowService(repo workflow.Repository) workflow.Service {
|
||||
@@ -67,12 +71,14 @@ func NewWorkflowService(repo workflow.Repository) workflow.Service {
|
||||
executableImpl: &executableImpl{
|
||||
repo: repo,
|
||||
},
|
||||
conversationImpl: &conversationImpl{repo: repo},
|
||||
}
|
||||
}
|
||||
|
||||
func NewWorkflowRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
|
||||
cpStore einoCompose.CheckPointStore, chatModel chatmodel.BaseChatModel, workflowConfig workflow.WorkflowConfig) workflow.Repository {
|
||||
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel, workflowConfig)
|
||||
cpStore einoCompose.CheckPointStore, chatModel chatmodel.BaseChatModel, cfg workflow.WorkflowConfig) (workflow.Repository, error) {
|
||||
return repo.NewRepository(idgen, db, redis, tos, cpStore, chatModel, cfg)
|
||||
|
||||
}
|
||||
|
||||
func (i *impl) ListNodeMeta(_ context.Context, nodeTypes map[entity.NodeType]bool) (map[string][]*entity.NodeTypeMeta, []entity.Category, error) {
|
||||
@@ -440,10 +446,13 @@ func (i *impl) collectNodePropertyMap(ctx context.Context, canvas *vo.Canvas) (m
|
||||
|
||||
var canvasSchema string
|
||||
if n.Data.Inputs.WorkflowVersion != "" {
|
||||
versionInfo, err := i.repo.GetVersion(ctx, wid, n.Data.Inputs.WorkflowVersion)
|
||||
versionInfo, existed, err := i.repo.GetVersion(ctx, wid, n.Data.Inputs.WorkflowVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existed {
|
||||
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", n.Data.Inputs.WorkflowVersion, wid, err), errorx.KV("id", strconv.FormatInt(wid, 10)))
|
||||
}
|
||||
canvasSchema = versionInfo.Canvas
|
||||
} else {
|
||||
draftInfo, err := i.repo.DraftV2(ctx, wid, "")
|
||||
@@ -522,6 +531,9 @@ func isEnableChatHistory(s *schema.NodeSchema) bool {
|
||||
case entity.NodeTypeIntentDetector:
|
||||
llmParam := s.Configs.(*intentdetector.Config).LLMParams
|
||||
return llmParam.EnableChatHistory
|
||||
case entity.NodeTypeKnowledgeRetriever:
|
||||
chatParam := s.Configs.(*knowledge.RetrieveConfig).ChatHistorySetting
|
||||
return chatParam != nil && chatParam.EnableChatHistory
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -541,6 +553,103 @@ func isRefGlobalVariable(s *schema.NodeSchema) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *impl) CreateChatFlowRole(ctx context.Context, role *vo.ChatFlowRoleCreate) (int64, error) {
|
||||
id, err := i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
|
||||
Name: role.Name,
|
||||
Description: role.Description,
|
||||
WorkflowID: role.WorkflowID,
|
||||
CreatorID: role.CreatorID,
|
||||
AudioConfig: role.AudioConfig,
|
||||
UserInputConfig: role.UserInputConfig,
|
||||
AvatarUri: role.AvatarUri,
|
||||
BackgroundImageInfo: role.BackgroundImageInfo,
|
||||
OnboardingInfo: role.OnboardingInfo,
|
||||
SuggestReplyInfo: role.SuggestReplyInfo,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (i *impl) UpdateChatFlowRole(ctx context.Context, workflowID int64, role *vo.ChatFlowRoleUpdate) error {
|
||||
err := i.repo.UpdateChatFlowRoleConfig(ctx, workflowID, role)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) GetChatFlowRole(ctx context.Context, workflowID int64, version string) (*entity.ChatFlowRole, error) {
|
||||
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, workflowID, version)
|
||||
if !isExist {
|
||||
logs.CtxWarnf(ctx, "chat flow role not exist, workflow id %v, version %v", workflowID, version)
|
||||
// Return (nil, nil) on 'NotExist' to align with the production behavior,
|
||||
// where the GET API may be called before the CREATE API during chatflow creation.
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (i *impl) GetWorkflowVersionsByConnector(ctx context.Context, connectorID, workflowID int64, limit int) ([]string, error) {
|
||||
return i.repo.GetVersionListByConnectorAndWorkflowID(ctx, connectorID, workflowID, limit)
|
||||
}
|
||||
|
||||
func (i *impl) DeleteChatFlowRole(ctx context.Context, id int64, workflowID int64) error {
|
||||
return i.repo.DeleteChatFlowRoleConfig(ctx, id, workflowID)
|
||||
}
|
||||
|
||||
func (i *impl) PublishChatFlowRole(ctx context.Context, policy *vo.PublishRolePolicy) error {
|
||||
if policy.WorkflowID == 0 || policy.CreatorID == 0 || policy.Version == "" {
|
||||
logs.CtxErrorf(ctx, "invalid publish role policy, workflow id %v, creator id %v should not be zero, version %v should not be empty", policy.WorkflowID, policy.CreatorID, policy.Version)
|
||||
return vo.WrapError(errno.ErrInvalidParameter, fmt.Errorf("invalid publish role policy, workflow id %v, creator id %v should not be zero, version %v should not be empty", policy.WorkflowID, policy.CreatorID, policy.Version))
|
||||
}
|
||||
wf, err := i.repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: policy.WorkflowID,
|
||||
MetaOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if wf.Mode != cloudworkflow.WorkflowMode_ChatFlow {
|
||||
return vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("workflow id %v, mode %v is not a chatflow", policy.WorkflowID, wf.Mode))
|
||||
}
|
||||
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, policy.WorkflowID, "")
|
||||
if !isExist {
|
||||
logs.CtxErrorf(ctx, "get draft chat flow role nil, workflow id %v", policy.WorkflowID)
|
||||
return vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("get draft chat flow role nil, workflow id %v", policy.WorkflowID))
|
||||
}
|
||||
if err != nil {
|
||||
return vo.WrapIfNeeded(errno.ErrChatFlowRoleOperationFail, err)
|
||||
}
|
||||
|
||||
_, err = i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
|
||||
Name: role.Name,
|
||||
Description: role.Description,
|
||||
WorkflowID: policy.WorkflowID,
|
||||
CreatorID: policy.CreatorID,
|
||||
AudioConfig: role.AudioConfig,
|
||||
UserInputConfig: role.UserInputConfig,
|
||||
AvatarUri: role.AvatarUri,
|
||||
BackgroundImageInfo: role.BackgroundImageInfo,
|
||||
OnboardingInfo: role.OnboardingInfo,
|
||||
SuggestReplyInfo: role.SuggestReplyInfo,
|
||||
Version: policy.Version,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func canvasToRefs(referringID int64, canvasStr string) (map[entity.WorkflowReferenceKey]struct{}, error) {
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(canvasStr, &canvas); err != nil {
|
||||
@@ -659,6 +768,13 @@ func (i *impl) UpdateMeta(ctx context.Context, id int64, metaUpdate *vo.MetaUpda
|
||||
return err
|
||||
}
|
||||
|
||||
if metaUpdate.WorkflowMode != nil && *metaUpdate.WorkflowMode == cloudworkflow.WorkflowMode_ChatFlow {
|
||||
err = i.adaptToChatFlow(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -667,6 +783,35 @@ func (i *impl) CopyWorkflow(ctx context.Context, workflowID int64, policy vo.Cop
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// chat flow should copy role config
|
||||
if wf.Mode == cloudworkflow.WorkflowMode_ChatFlow {
|
||||
role, err, isExist := i.repo.GetChatFlowRoleConfig(ctx, workflowID, "")
|
||||
if !isExist {
|
||||
logs.CtxErrorf(ctx, "get draft chat flow role nil, workflow id %v", workflowID)
|
||||
return nil, vo.WrapError(errno.ErrChatFlowRoleOperationFail, fmt.Errorf("get draft chat flow role nil, workflow id %v", workflowID))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, vo.WrapIfNeeded(errno.ErrChatFlowRoleOperationFail, err)
|
||||
}
|
||||
_, err = i.repo.CreateChatFlowRoleConfig(ctx, &entity.ChatFlowRole{
|
||||
Name: role.Name,
|
||||
Description: role.Description,
|
||||
WorkflowID: wf.ID,
|
||||
CreatorID: wf.CreatorID,
|
||||
AudioConfig: role.AudioConfig,
|
||||
UserInputConfig: role.UserInputConfig,
|
||||
AvatarUri: role.AvatarUri,
|
||||
BackgroundImageInfo: role.BackgroundImageInfo,
|
||||
OnboardingInfo: role.OnboardingInfo,
|
||||
SuggestReplyInfo: role.SuggestReplyInfo,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return wf, nil
|
||||
|
||||
@@ -677,7 +822,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
return nil, fmt.Errorf("connector ids is required")
|
||||
}
|
||||
|
||||
wfs, _, err := i.MGet(ctx, &vo.MGetPolicy{
|
||||
allWorkflowsInApp, _, err := i.MGet(ctx, &vo.MGetPolicy{
|
||||
MetaQuery: vo.MetaQuery{
|
||||
AppID: &appID,
|
||||
},
|
||||
@@ -688,14 +833,15 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
}
|
||||
|
||||
relatedPlugins := make(map[int64]*plugin.PluginEntity, len(config.PluginIDs))
|
||||
relatedWorkflow := make(map[int64]entity.IDVersionPair, len(wfs))
|
||||
relatedWorkflow := make(map[int64]entity.IDVersionPair, len(allWorkflowsInApp))
|
||||
|
||||
for _, wf := range wfs {
|
||||
for _, wf := range allWorkflowsInApp {
|
||||
relatedWorkflow[wf.ID] = entity.IDVersionPair{
|
||||
ID: wf.ID,
|
||||
Version: config.Version,
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range config.PluginIDs {
|
||||
relatedPlugins[id] = &plugin.PluginEntity{
|
||||
PluginID: id,
|
||||
@@ -704,7 +850,22 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
}
|
||||
|
||||
vIssues := make([]*vo.ValidateIssue, 0)
|
||||
for _, wf := range wfs {
|
||||
|
||||
willPublishWorkflows := make([]*entity.Workflow, 0)
|
||||
|
||||
if len(config.WorkflowIDs) == 0 {
|
||||
willPublishWorkflows = allWorkflowsInApp
|
||||
} else {
|
||||
willPublishWorkflows, _, err = i.MGet(ctx, &vo.MGetPolicy{
|
||||
MetaQuery: vo.MetaQuery{
|
||||
AppID: &appID,
|
||||
IDs: config.WorkflowIDs,
|
||||
},
|
||||
QType: workflowModel.FromDraft,
|
||||
})
|
||||
}
|
||||
|
||||
for _, wf := range willPublishWorkflows {
|
||||
issues, err := validateWorkflowTree(ctx, vo.ValidateTreeConfig{
|
||||
CanvasSchema: wf.Canvas,
|
||||
AppID: ptr.Of(appID),
|
||||
@@ -723,7 +884,7 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
return vIssues, nil
|
||||
}
|
||||
|
||||
for _, wf := range wfs {
|
||||
for _, wf := range willPublishWorkflows {
|
||||
c := &vo.Canvas{}
|
||||
err := sonic.UnmarshalString(wf.Canvas, c)
|
||||
if err != nil {
|
||||
@@ -747,9 +908,8 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
}
|
||||
|
||||
userID := ctxutil.MustGetUIDFromCtx(ctx)
|
||||
|
||||
workflowsToPublish := make(map[int64]*vo.VersionInfo)
|
||||
for _, wf := range wfs {
|
||||
for _, wf := range willPublishWorkflows {
|
||||
inputStr, err := sonic.MarshalString(wf.InputParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -774,8 +934,16 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
}
|
||||
}
|
||||
|
||||
workflowIDs := make([]int64, 0, len(wfs))
|
||||
workflowIDs := make([]int64, 0, len(willPublishWorkflows))
|
||||
for id, vInfo := range workflowsToPublish {
|
||||
// if version existed skip
|
||||
_, existed, err := i.repo.GetVersion(ctx, id, config.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existed {
|
||||
continue
|
||||
}
|
||||
wfRefs, err := canvasToRefs(id, vInfo.Canvas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -787,6 +955,24 @@ func (i *impl) ReleaseApplicationWorkflows(ctx context.Context, appID int64, con
|
||||
}
|
||||
}
|
||||
|
||||
err = i.ReleaseConversationTemplate(ctx, appID, config.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, wf := range willPublishWorkflows {
|
||||
if wf.Mode == cloudworkflow.WorkflowMode_ChatFlow {
|
||||
err = i.PublishChatFlowRole(ctx, &vo.PublishRolePolicy{
|
||||
WorkflowID: wf.ID,
|
||||
CreatorID: wf.CreatorID,
|
||||
Version: config.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, connectorID := range config.ConnectorIDs {
|
||||
err = i.repo.BatchCreateConnectorWorkflowVersion(ctx, appID, connectorID, workflowIDs, config.Version)
|
||||
if err != nil {
|
||||
@@ -889,7 +1075,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
|
||||
}
|
||||
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
var (
|
||||
v *vo.DraftInfo
|
||||
@@ -1012,7 +1198,7 @@ func (i *impl) CopyWorkflowFromAppToLibrary(ctx context.Context, workflowID int6
|
||||
return err
|
||||
}
|
||||
|
||||
cwf, err := i.repo.CopyWorkflow(ctx, wf.id, vo.CopyWorkflowPolicy{
|
||||
cwf, err := i.CopyWorkflow(ctx, wf.id, vo.CopyWorkflowPolicy{
|
||||
TargetAppID: ptr.Of(int64(0)),
|
||||
ModifiedCanvasSchema: ptr.Of(modifiedCanvasString),
|
||||
})
|
||||
@@ -1120,7 +1306,7 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
|
||||
|
||||
}
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, w := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
var (
|
||||
v *vo.DraftInfo
|
||||
@@ -1246,6 +1432,11 @@ func (i *impl) DuplicateWorkflowsByAppID(ctx context.Context, sourceAppID, targe
|
||||
}
|
||||
}
|
||||
|
||||
err = i.repo.CopyTemplateConversationByAppID(ctx, sourceAppID, targetAppID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return copiedWorkflowArray, nil
|
||||
|
||||
}
|
||||
@@ -1368,7 +1559,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
||||
ds.DatabaseIDs = append(ds.DatabaseIDs, dsID)
|
||||
}
|
||||
case entity.NodeTypeLLM:
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
|
||||
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.PluginFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.PluginFCParam.PluginList {
|
||||
if node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx].IsDraft {
|
||||
pl := node.Data.Inputs.FCParam.PluginFCParam.PluginList[idx]
|
||||
@@ -1382,7 +1573,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
||||
|
||||
}
|
||||
}
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.KnowledgeFCParam != nil {
|
||||
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.KnowledgeFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.KnowledgeFCParam.KnowledgeList {
|
||||
kn := node.Data.Inputs.FCParam.KnowledgeFCParam.KnowledgeList[idx]
|
||||
kid, err := strconv.ParseInt(kn.ID, 10, 64)
|
||||
@@ -1394,7 +1585,7 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
||||
}
|
||||
}
|
||||
|
||||
if node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
if node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for idx := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
if node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx].IsDraft {
|
||||
wID, err := strconv.ParseInt(node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList[idx].WorkflowID, 10, 64)
|
||||
@@ -1467,6 +1658,165 @@ func (i *impl) GetWorkflowDependenceResource(ctx context.Context, workflowID int
|
||||
|
||||
}
|
||||
|
||||
func (i *impl) checkBotAgentNode(node *vo.Node) error {
|
||||
if node.Type == entity.NodeTypeCreateConversation.IDStr() || node.Type == entity.NodeTypeConversationDelete.IDStr() || node.Type == entity.NodeTypeConversationUpdate.IDStr() || node.Type == entity.NodeTypeConversationList.IDStr() {
|
||||
return errors.New("conversation-related nodes are not supported in chatflow")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) validateNodesRecursively(ctx context.Context, nodes []*vo.Node, checkType cloudworkflow.CheckType, visited map[string]struct{}, repo workflow.Repository) error {
|
||||
queue := make([]*vo.Node, 0, len(nodes))
|
||||
queue = append(queue, nodes...)
|
||||
|
||||
for len(queue) > 0 {
|
||||
node := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
if node == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var checkErr error
|
||||
switch checkType {
|
||||
case cloudworkflow.CheckType_BotAgent:
|
||||
checkErr = i.checkBotAgentNode(node)
|
||||
default:
|
||||
// For now, we only handle BotAgent check, so we can do nothing here.
|
||||
// In the future, if there are other check types that need to be validated on every node, this logic will need to be adjusted.
|
||||
}
|
||||
if checkErr != nil {
|
||||
return checkErr
|
||||
}
|
||||
|
||||
// Enqueue nested nodes for BFS traversal. This handles Loop, Batch, and other nodes with nested blocks.
|
||||
if len(node.Blocks) > 0 {
|
||||
queue = append(queue, node.Blocks...)
|
||||
}
|
||||
|
||||
if node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil {
|
||||
workflowIDStr := node.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
|
||||
}
|
||||
|
||||
version := node.Data.Inputs.WorkflowVersion
|
||||
qType := workflowModel.FromDraft
|
||||
if version != "" {
|
||||
qType = workflowModel.FromSpecificVersion
|
||||
}
|
||||
|
||||
visitedKey := fmt.Sprintf("%d:%s", workflowID, version)
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
continue
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: qType,
|
||||
Version: version,
|
||||
})
|
||||
if err != nil {
|
||||
delete(visited, visitedKey)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(subWfEntity.Canvas, &canvas); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", subWfEntity.ID, err)
|
||||
}
|
||||
|
||||
queue = append(queue, canvas.Nodes...)
|
||||
}
|
||||
|
||||
if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLM != nil && node.Data.Inputs.FCParam != nil && node.Data.Inputs.FCParam.WorkflowFCParam != nil {
|
||||
for _, subWfInfo := range node.Data.Inputs.FCParam.WorkflowFCParam.WorkflowList {
|
||||
if subWfInfo.WorkflowID == "" {
|
||||
continue
|
||||
}
|
||||
workflowID, err := strconv.ParseInt(subWfInfo.WorkflowID, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workflow ID in large model node %s: %w", node.ID, err)
|
||||
}
|
||||
|
||||
version := subWfInfo.WorkflowVersion
|
||||
qType := workflowModel.FromDraft
|
||||
if version != "" {
|
||||
qType = workflowModel.FromSpecificVersion
|
||||
}
|
||||
|
||||
visitedKey := fmt.Sprintf("%d:%s", workflowID, version)
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
continue
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: qType,
|
||||
Version: version,
|
||||
})
|
||||
if err != nil {
|
||||
delete(visited, visitedKey)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("failed to get sub-workflow entity %d from large model node: %w", workflowID, err)
|
||||
}
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(subWfEntity.Canvas, &canvas); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal canvas for workflow %d from large model node: %w", subWfEntity.ID, err)
|
||||
}
|
||||
|
||||
queue = append(queue, canvas.Nodes...)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *impl) WorkflowSchemaCheck(ctx context.Context, wf *entity.Workflow, checks []cloudworkflow.CheckType) ([]*cloudworkflow.CheckResult, error) {
|
||||
checkResults := make([]*cloudworkflow.CheckResult, 0, len(checks))
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(wf.Canvas, &canvas); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wf.ID, err)
|
||||
}
|
||||
|
||||
for _, checkType := range checks {
|
||||
visited := make(map[string]struct{})
|
||||
visitedKey := fmt.Sprintf("%d:%s", wf.ID, wf.GetVersion())
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
err := i.validateNodesRecursively(ctx, canvas.Nodes, checkType, visited, i.repo)
|
||||
|
||||
if err != nil {
|
||||
checkResults = append(checkResults, &cloudworkflow.CheckResult{
|
||||
IsPass: false,
|
||||
Reason: err.Error(),
|
||||
Type: checkType,
|
||||
})
|
||||
} else {
|
||||
checkResults = append(checkResults, &cloudworkflow.CheckResult{
|
||||
IsPass: true,
|
||||
Type: checkType,
|
||||
Reason: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
return checkResults, nil
|
||||
}
|
||||
|
||||
func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workflow, int64, error) {
|
||||
if policy.MetaOnly {
|
||||
metas, total, err := i.repo.MGetMetas(ctx, &policy.MetaQuery)
|
||||
@@ -1527,11 +1877,13 @@ func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workf
|
||||
index := 0
|
||||
|
||||
for id, version := range policy.Versions {
|
||||
v, err := i.repo.GetVersion(ctx, id, version)
|
||||
v, existed, err := i.repo.GetVersion(ctx, id, version)
|
||||
if err != nil {
|
||||
return nil, total, err
|
||||
}
|
||||
|
||||
if !existed {
|
||||
return nil, total, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", version, id, err), errorx.KV("id", strconv.FormatInt(id, 10)))
|
||||
}
|
||||
inputs, outputs, err := ioF(v.InputParamsStr, v.OutputParamsStr)
|
||||
if err != nil {
|
||||
return nil, total, err
|
||||
@@ -1562,6 +1914,14 @@ func (i *impl) MGet(ctx context.Context, policy *vo.MGetPolicy) ([]*entity.Workf
|
||||
}
|
||||
}
|
||||
|
||||
func (i *impl) BindConvRelatedInfo(ctx context.Context, convID int64, info entity.ConvRelatedInfo) error {
|
||||
return i.repo.BindConvRelatedInfo(ctx, convID, info)
|
||||
}
|
||||
|
||||
func (i *impl) GetConvRelatedInfo(ctx context.Context, convID int64) (*entity.ConvRelatedInfo, bool, func() error, error) {
|
||||
return i.repo.GetConvRelatedInfo(ctx, convID)
|
||||
}
|
||||
|
||||
func (i *impl) calculateTestRunSuccess(ctx context.Context, c *vo.Canvas, wid int64) (bool, error) {
|
||||
sc, err := adaptor.CanvasToWorkflowSchema(ctx, c)
|
||||
if err != nil { // not even legal, test run can't possibly be successful
|
||||
@@ -1766,3 +2126,58 @@ func replaceRelatedWorkflowOrExternalResourceInWorkflowNodes(nodes []*vo.Node, r
|
||||
func RegisterAllNodeAdaptors() {
|
||||
adaptor.RegisterAllNodeAdaptors()
|
||||
}
|
||||
func (i *impl) adaptToChatFlow(ctx context.Context, wID int64) error {
|
||||
wfEntity, err := i.repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: wID,
|
||||
QType: workflowModel.FromDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
canvas := &vo.Canvas{}
|
||||
err = sonic.UnmarshalString(wfEntity.Canvas, canvas)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var startNode *vo.Node
|
||||
for _, node := range canvas.Nodes {
|
||||
if node.Type == entity.NodeTypeEntry.IDStr() {
|
||||
startNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if startNode == nil {
|
||||
return fmt.Errorf("can not find start node")
|
||||
}
|
||||
|
||||
vMap := make(map[string]bool)
|
||||
for _, o := range startNode.Data.Outputs {
|
||||
v, err := vo.ParseVariable(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vMap[v.Name] = true
|
||||
}
|
||||
|
||||
if _, ok := vMap["USER_INPUT"]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: "USER_INPUT",
|
||||
Type: vo.VariableTypeString,
|
||||
})
|
||||
}
|
||||
if _, ok := vMap["CONVERSATION_NAME"]; !ok {
|
||||
startNode.Data.Outputs = append(startNode.Data.Outputs, &vo.Variable{
|
||||
Name: "CONVERSATION_NAME",
|
||||
Type: vo.VariableTypeString,
|
||||
DefaultValue: "Default",
|
||||
})
|
||||
}
|
||||
canvasStr, err := sonic.MarshalString(canvas)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.Save(ctx, wID, canvasStr)
|
||||
}
|
||||
|
||||
@@ -22,11 +22,15 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
cloudworkflow "github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
|
||||
wf "github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/adaptor"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/validate"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/variable"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
@@ -102,17 +106,17 @@ func validateWorkflowTree(ctx context.Context, config vo.ValidateTreeConfig) ([]
|
||||
return issues, nil
|
||||
}
|
||||
|
||||
func convertToValidationError(issue *validate.Issue) *cloudworkflow.ValidateErrorData {
|
||||
e := &cloudworkflow.ValidateErrorData{}
|
||||
func convertToValidationError(issue *validate.Issue) *workflow.ValidateErrorData {
|
||||
e := &workflow.ValidateErrorData{}
|
||||
e.Message = issue.Message
|
||||
if issue.NodeErr != nil {
|
||||
e.Type = cloudworkflow.ValidateErrorType_BotValidateNodeErr
|
||||
e.NodeError = &cloudworkflow.NodeError{
|
||||
e.Type = workflow.ValidateErrorType_BotValidateNodeErr
|
||||
e.NodeError = &workflow.NodeError{
|
||||
NodeID: issue.NodeErr.NodeID,
|
||||
}
|
||||
} else if issue.PathErr != nil {
|
||||
e.Type = cloudworkflow.ValidateErrorType_BotValidatePathErr
|
||||
e.PathError = &cloudworkflow.PathError{
|
||||
e.Type = workflow.ValidateErrorType_BotValidatePathErr
|
||||
e.PathError = &workflow.PathError{
|
||||
Start: issue.PathErr.StartNode,
|
||||
End: issue.PathErr.EndNode,
|
||||
}
|
||||
@@ -121,8 +125,8 @@ func convertToValidationError(issue *validate.Issue) *cloudworkflow.ValidateErro
|
||||
return e
|
||||
}
|
||||
|
||||
func toValidateErrorData(issues []*validate.Issue) []*cloudworkflow.ValidateErrorData {
|
||||
validateErrors := make([]*cloudworkflow.ValidateErrorData, 0, len(issues))
|
||||
func toValidateErrorData(issues []*validate.Issue) []*workflow.ValidateErrorData {
|
||||
validateErrors := make([]*workflow.ValidateErrorData, 0, len(issues))
|
||||
for _, issue := range issues {
|
||||
validateErrors = append(validateErrors, convertToValidationError(issue))
|
||||
}
|
||||
@@ -197,3 +201,214 @@ func isIncremental(prev version, next version) bool {
|
||||
|
||||
return next.Patch > prev.Patch
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursively(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository) (int64, error) {
|
||||
visited := make(map[string]struct{})
|
||||
maxRounds := int64(0)
|
||||
err := getMaxHistoryRoundsRecursiveHelper(ctx, wfEntity, repo, visited, &maxRounds)
|
||||
return maxRounds, err
|
||||
}
|
||||
|
||||
func getMaxHistoryRoundsRecursiveHelper(ctx context.Context, wfEntity *entity.Workflow, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
|
||||
var canvas vo.Canvas
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, &canvas); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal canvas for workflow %d: %w", wfEntity.ID, err)
|
||||
}
|
||||
|
||||
return collectMaxHistoryRounds(ctx, canvas.Nodes, repo, visited, maxRounds)
|
||||
}
|
||||
|
||||
func collectMaxHistoryRounds(ctx context.Context, nodes []*vo.Node, repo wf.Repository, visited map[string]struct{}, maxRounds *int64) error {
|
||||
for _, node := range nodes {
|
||||
if node == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.ChatHistorySetting != nil && node.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
if node.Data.Inputs.ChatHistorySetting.ChatHistoryRound > *maxRounds {
|
||||
*maxRounds = node.Data.Inputs.ChatHistorySetting.ChatHistoryRound
|
||||
}
|
||||
} else if node.Type == entity.NodeTypeLLM.IDStr() && node.Data != nil && node.Data.Inputs != nil && node.Data.Inputs.LLMParam != nil {
|
||||
param := node.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chatHistoryEnabled {
|
||||
if chatHistoryRound > *maxRounds {
|
||||
*maxRounds = chatHistoryRound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isSubWorkflow := node.Type == entity.NodeTypeSubWorkflow.IDStr() && node.Data != nil && node.Data.Inputs != nil
|
||||
if isSubWorkflow {
|
||||
workflowIDStr := node.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", node.ID, err)
|
||||
}
|
||||
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(node.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: node.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(node.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, node.Blocks, repo, visited, maxRounds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getHistoryRoundsFromNode(ctx context.Context, wfEntity *entity.Workflow, nodeID string, repo wf.Repository) (int64, error) {
|
||||
if wfEntity == nil {
|
||||
return 0, nil
|
||||
}
|
||||
visited := make(map[string]struct{})
|
||||
visitedKey := fmt.Sprintf("%d:%s", wfEntity.ID, wfEntity.GetVersion())
|
||||
if _, ok := visited[visitedKey]; ok {
|
||||
return 0, nil
|
||||
}
|
||||
visited[visitedKey] = struct{}{}
|
||||
maxRounds := int64(0)
|
||||
c := &vo.Canvas{}
|
||||
if err := sonic.UnmarshalString(wfEntity.Canvas, c); err != nil {
|
||||
return 0, fmt.Errorf("failed to unmarshal canvas: %w", err)
|
||||
}
|
||||
var (
|
||||
n *vo.Node
|
||||
nodeFinder func(nodes []*vo.Node) *vo.Node
|
||||
)
|
||||
nodeFinder = func(nodes []*vo.Node) *vo.Node {
|
||||
for i := range nodes {
|
||||
if nodes[i].ID == nodeID {
|
||||
return nodes[i]
|
||||
}
|
||||
if len(nodes[i].Blocks) > 0 {
|
||||
if n := nodeFinder(nodes[i].Blocks); n != nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
n = nodeFinder(c.Nodes)
|
||||
if n.Type == entity.NodeTypeLLM.IDStr() {
|
||||
if n.Data == nil || n.Data.Inputs == nil {
|
||||
return 0, nil
|
||||
}
|
||||
param := n.Data.Inputs.LLMParam
|
||||
bs, _ := sonic.Marshal(param)
|
||||
llmParam := make(vo.LLMParam, 0)
|
||||
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var chatHistoryEnabled bool
|
||||
var chatHistoryRound int64
|
||||
for _, param := range llmParam {
|
||||
switch param.Name {
|
||||
case "enableChatHistory":
|
||||
if val, ok := param.Input.Value.Content.(bool); ok {
|
||||
b := val
|
||||
chatHistoryEnabled = b
|
||||
}
|
||||
case "chatHistoryRound":
|
||||
if strVal, ok := param.Input.Value.Content.(string); ok {
|
||||
int64Val, err := strconv.ParseInt(strVal, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
chatHistoryRound = int64Val
|
||||
}
|
||||
}
|
||||
}
|
||||
if chatHistoryEnabled {
|
||||
return chatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeIntentDetector.IDStr() || n.Type == entity.NodeTypeKnowledgeRetriever.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil && n.Data.Inputs.ChatHistorySetting != nil && n.Data.Inputs.ChatHistorySetting.EnableChatHistory {
|
||||
return n.Data.Inputs.ChatHistorySetting.ChatHistoryRound, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if n.Type == entity.NodeTypeSubWorkflow.IDStr() {
|
||||
if n.Data != nil && n.Data.Inputs != nil {
|
||||
workflowIDStr := n.Data.Inputs.WorkflowID
|
||||
if workflowIDStr == "" {
|
||||
return 0, nil
|
||||
}
|
||||
workflowID, err := strconv.ParseInt(workflowIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid workflow ID in sub-workflow node %s: %w", n.ID, err)
|
||||
}
|
||||
subWfEntity, err := repo.GetEntity(ctx, &vo.GetPolicy{
|
||||
ID: workflowID,
|
||||
QType: ternary.IFElse(len(n.Data.Inputs.WorkflowVersion) == 0, workflowModel.FromDraft, workflowModel.FromSpecificVersion),
|
||||
Version: n.Data.Inputs.WorkflowVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get sub-workflow entity %d: %w", workflowID, err)
|
||||
}
|
||||
if err := getMaxHistoryRoundsRecursiveHelper(ctx, subWfEntity, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(n.Blocks) > 0 {
|
||||
if err := collectMaxHistoryRounds(ctx, n.Blocks, repo, visited, &maxRounds); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return maxRounds, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user