feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -70,10 +70,15 @@ type RepositoryImpl struct {
workflow.ExecuteHistoryStore
builtinModel cm.BaseChatModel
workflow.WorkflowConfig
workflow.Suggester
}
func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
cpStore einoCompose.CheckPointStore, chatModel cm.BaseChatModel, workflowConfig workflow.WorkflowConfig) workflow.Repository {
cpStore einoCompose.CheckPointStore, chatModel cm.BaseChatModel, workflowConfig workflow.WorkflowConfig) (workflow.Repository, error) {
sg, err := NewSuggester(chatModel)
if err != nil {
return nil, err
}
return &RepositoryImpl{
IDGenerator: idgen,
query: query.Use(db),
@@ -90,9 +95,12 @@ func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, to
query: query.Use(db),
redis: redis,
},
builtinModel: chatModel,
Suggester: sg,
WorkflowConfig: workflowConfig,
}
}, nil
}
func (r *RepositoryImpl) CreateMeta(ctx context.Context, meta *vo.Meta) (int64, error) {
@@ -320,13 +328,16 @@ func (r *RepositoryImpl) CreateVersion(ctx context.Context, id int64, info *vo.V
func (r *RepositoryImpl) CreateOrUpdateDraft(ctx context.Context, id int64, draft *vo.DraftInfo) error {
d := &model.WorkflowDraft{
ID: id,
Canvas: draft.Canvas,
InputParams: draft.InputParamsStr,
OutputParams: draft.OutputParamsStr,
Modified: draft.Modified,
TestRunSuccess: draft.TestRunSuccess,
CommitID: draft.CommitID,
ID: id,
Canvas: draft.Canvas,
InputParams: draft.InputParamsStr,
OutputParams: draft.OutputParamsStr,
CommitID: draft.CommitID,
}
if draft.DraftMeta != nil {
d.Modified = draft.DraftMeta.Modified
d.TestRunSuccess = draft.DraftMeta.TestRunSuccess
}
if err := r.query.WorkflowDraft.WithContext(ctx).Save(d); err != nil {
@@ -500,6 +511,10 @@ func (r *RepositoryImpl) UpdateMeta(ctx context.Context, id int64, metaUpdate *v
expressions = append(expressions, r.query.WorkflowMeta.LatestVersion.Value(*metaUpdate.LatestPublishedVersion))
}
if metaUpdate.WorkflowMode != nil {
expressions = append(expressions, r.query.WorkflowMeta.Mode.Value(int32(*metaUpdate.WorkflowMode)))
}
if len(expressions) == 0 {
return nil
}
@@ -551,10 +566,13 @@ func (r *RepositoryImpl) GetEntity(ctx context.Context, policy *vo.GetPolicy) (_
draftMeta = draft.DraftMeta
commitID = draft.CommitID
case workflowModel.FromSpecificVersion:
v, err := r.GetVersion(ctx, policy.ID, policy.Version)
v, existed, err := r.GetVersion(ctx, policy.ID, policy.Version)
if err != nil {
return nil, err
}
if !existed {
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", policy.Version, policy.ID, err), errorx.KV("id", strconv.FormatInt(policy.ID, 10)))
}
canvas = v.Canvas
inputParams = v.InputParamsStr
outputParams = v.OutputParamsStr
@@ -604,7 +622,117 @@ func (r *RepositoryImpl) GetEntity(ctx context.Context, policy *vo.GetPolicy) (_
}, nil
}
func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version string) (_ *vo.VersionInfo, err error) {
func (r *RepositoryImpl) CreateChatFlowRoleConfig(ctx context.Context, chatFlowRole *entity.ChatFlowRole) (int64, error) {
id, err := r.GenID(ctx)
if err != nil {
return 0, vo.WrapError(errno.ErrIDGenError, err)
}
chatFlowRoleConfig := &model.ChatFlowRoleConfig{
ID: id,
WorkflowID: chatFlowRole.WorkflowID,
Name: chatFlowRole.Name,
Description: chatFlowRole.Description,
Avatar: chatFlowRole.AvatarUri,
AudioConfig: chatFlowRole.AudioConfig,
BackgroundImageInfo: chatFlowRole.BackgroundImageInfo,
OnboardingInfo: chatFlowRole.OnboardingInfo,
SuggestReplyInfo: chatFlowRole.SuggestReplyInfo,
UserInputConfig: chatFlowRole.UserInputConfig,
CreatorID: chatFlowRole.CreatorID,
Version: chatFlowRole.Version,
}
if err := r.query.ChatFlowRoleConfig.WithContext(ctx).Create(chatFlowRoleConfig); err != nil {
return 0, vo.WrapError(errno.ErrDatabaseError, fmt.Errorf("create chat flow role: %w", err))
}
return id, nil
}
func (r *RepositoryImpl) UpdateChatFlowRoleConfig(ctx context.Context, workflowID int64, chatFlowRole *vo.ChatFlowRoleUpdate) error {
var expressions []field.AssignExpr
if chatFlowRole.Name != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.Name.Value(*chatFlowRole.Name))
}
if chatFlowRole.Description != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.Description.Value(*chatFlowRole.Description))
}
if chatFlowRole.AvatarUri != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.Avatar.Value(*chatFlowRole.AvatarUri))
}
if chatFlowRole.AudioConfig != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.AudioConfig.Value(*chatFlowRole.AudioConfig))
}
if chatFlowRole.BackgroundImageInfo != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.BackgroundImageInfo.Value(*chatFlowRole.BackgroundImageInfo))
}
if chatFlowRole.OnboardingInfo != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.OnboardingInfo.Value(*chatFlowRole.OnboardingInfo))
}
if chatFlowRole.SuggestReplyInfo != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.SuggestReplyInfo.Value(*chatFlowRole.SuggestReplyInfo))
}
if chatFlowRole.UserInputConfig != nil {
expressions = append(expressions, r.query.ChatFlowRoleConfig.UserInputConfig.Value(*chatFlowRole.UserInputConfig))
}
if len(expressions) == 0 {
return nil
}
_, err := r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).
UpdateColumnSimple(expressions...)
if err != nil {
return vo.WrapError(errno.ErrDatabaseError, fmt.Errorf("update chat flow role: %w", err))
}
return nil
}
func (r *RepositoryImpl) GetChatFlowRoleConfig(ctx context.Context, workflowID int64, version string) (_ *entity.ChatFlowRole, err error, isExist bool) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrDatabaseError, err)
}
}()
role := &model.ChatFlowRoleConfig{}
if version != "" {
role, err = r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID), r.query.ChatFlowRoleConfig.Version.Eq(version)).First()
} else {
role, err = r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).First()
}
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err, false
}
return nil, fmt.Errorf("failed to get chat flow role for chatflowID %d: %w", workflowID, err), true
}
res := &entity.ChatFlowRole{
ID: role.ID,
WorkflowID: role.WorkflowID,
Name: role.Name,
Description: role.Description,
AvatarUri: role.Avatar,
AudioConfig: role.AudioConfig,
BackgroundImageInfo: role.BackgroundImageInfo,
OnboardingInfo: role.OnboardingInfo,
SuggestReplyInfo: role.SuggestReplyInfo,
UserInputConfig: role.UserInputConfig,
CreatorID: role.CreatorID,
CreatedAt: time.UnixMilli(role.CreatedAt),
}
if role.UpdatedAt > 0 {
res.UpdatedAt = time.UnixMilli(role.UpdatedAt)
}
return res, err, true
}
func (r *RepositoryImpl) DeleteChatFlowRoleConfig(ctx context.Context, id int64, workflowID int64) error {
_, err := r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.ID.Eq(id), r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).Delete()
return err
}
func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version string) (_ *vo.VersionInfo, existed bool, err error) {
defer func() {
if err != nil {
err = vo.WrapIfNeeded(errno.ErrDatabaseError, err)
@@ -616,9 +744,9 @@ func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version strin
First()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", version, id, err), errorx.KV("id", strconv.FormatInt(id, 10)))
return nil, false, nil
}
return nil, fmt.Errorf("failed to get workflow version %s for ID %d: %w", version, id, err)
return nil, false, fmt.Errorf("failed to get workflow version %s for ID %d: %w", version, id, err)
}
return &vo.VersionInfo{
@@ -634,7 +762,29 @@ func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version strin
OutputParamsStr: wfVersion.OutputParams,
},
CommitID: wfVersion.CommitID,
}, nil
}, true, nil
}
func (r *RepositoryImpl) GetVersionListByConnectorAndWorkflowID(ctx context.Context, connectorID, workflowID int64, limit int) (_ []string, err error) {
if limit <= 0 {
return nil, vo.WrapError(errno.ErrInvalidParameter, errors.New("limit must be greater than 0"))
}
connectorWorkflowVersion := r.query.ConnectorWorkflowVersion
vl, err := connectorWorkflowVersion.WithContext(ctx).
Where(connectorWorkflowVersion.ConnectorID.Eq(connectorID),
connectorWorkflowVersion.WorkflowID.Eq(workflowID)).
Order(connectorWorkflowVersion.CreatedAt.Desc()).
Limit(limit).
Find()
if err != nil {
return nil, vo.WrapError(errno.ErrDatabaseError, err)
}
var versionList []string
for _, v := range vl {
versionList = append(versionList, v.Version)
}
return versionList, nil
}
func (r *RepositoryImpl) IsApplicationConnectorWorkflowVersion(ctx context.Context, connectorID, workflowID int64, version string) (b bool, err error) {
@@ -767,6 +917,10 @@ func (r *RepositoryImpl) MGetDrafts(ctx context.Context, policy *vo.MGetPolicy)
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
}
if q.Mode != nil {
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*q.Mode)))
}
type combinedDraft struct {
model.WorkflowDraft
Name string `gorm:"column:name"`
@@ -933,6 +1087,10 @@ func (r *RepositoryImpl) MGetLatestVersion(ctx context.Context, policy *vo.MGetP
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
}
if q.Mode != nil {
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*q.Mode)))
}
type combinedVersion struct {
model.WorkflowMeta
Version string `gorm:"column:version"` // release version
@@ -1157,6 +1315,10 @@ func (r *RepositoryImpl) MGetMetas(ctx context.Context, query *vo.MetaQuery) (
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
}
if query.Mode != nil {
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*query.Mode)))
}
var result []*model.WorkflowMeta
workflowMetaDo := r.query.WorkflowMeta.WithContext(ctx).Debug().Where(conditions...)
@@ -1520,6 +1682,7 @@ func (r *RepositoryImpl) CopyWorkflow(ctx context.Context, workflowID int64, pol
IconURI: wfMeta.IconURI,
Desc: wfMeta.Description,
AppID: ternary.IFElse(wfMeta.AppID == 0, (*int64)(nil), ptr.Of(wfMeta.AppID)),
Mode: workflowModel.WorkflowMode(wfMeta.Mode),
},
CanvasInfo: &vo.CanvasInfo{
Canvas: wfDraft.Canvas,
@@ -1594,6 +1757,10 @@ func (r *RepositoryImpl) GetKnowledgeRecallChatModel() cm.BaseChatModel {
return r.builtinModel
}
func (r *RepositoryImpl) GetObjectUrl(ctx context.Context, objectKey string, opts ...storage.GetOptFn) (string, error) {
return r.tos.GetObjectUrl(ctx, objectKey, opts...)
}
func filterDisabledAPIParameters(parametersCfg []*workflow3.APIParameter, m map[string]any) map[string]any {
result := make(map[string]any, len(m))
responseParameterMap := slices.ToMap(parametersCfg, func(p *workflow3.APIParameter) (string, *workflow3.APIParameter) {