feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)
Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com> Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com> Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com> Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com> Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com> Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com> Co-authored-by: July <jiangxujin@bytedance.com> Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
@@ -70,10 +70,15 @@ type RepositoryImpl struct {
|
||||
workflow.ExecuteHistoryStore
|
||||
builtinModel cm.BaseChatModel
|
||||
workflow.WorkflowConfig
|
||||
workflow.Suggester
|
||||
}
|
||||
|
||||
func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, tos storage.Storage,
|
||||
cpStore einoCompose.CheckPointStore, chatModel cm.BaseChatModel, workflowConfig workflow.WorkflowConfig) workflow.Repository {
|
||||
cpStore einoCompose.CheckPointStore, chatModel cm.BaseChatModel, workflowConfig workflow.WorkflowConfig) (workflow.Repository, error) {
|
||||
sg, err := NewSuggester(chatModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RepositoryImpl{
|
||||
IDGenerator: idgen,
|
||||
query: query.Use(db),
|
||||
@@ -90,9 +95,12 @@ func NewRepository(idgen idgen.IDGenerator, db *gorm.DB, redis cache.Cmdable, to
|
||||
query: query.Use(db),
|
||||
redis: redis,
|
||||
},
|
||||
|
||||
builtinModel: chatModel,
|
||||
Suggester: sg,
|
||||
WorkflowConfig: workflowConfig,
|
||||
}
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) CreateMeta(ctx context.Context, meta *vo.Meta) (int64, error) {
|
||||
@@ -320,13 +328,16 @@ func (r *RepositoryImpl) CreateVersion(ctx context.Context, id int64, info *vo.V
|
||||
|
||||
func (r *RepositoryImpl) CreateOrUpdateDraft(ctx context.Context, id int64, draft *vo.DraftInfo) error {
|
||||
d := &model.WorkflowDraft{
|
||||
ID: id,
|
||||
Canvas: draft.Canvas,
|
||||
InputParams: draft.InputParamsStr,
|
||||
OutputParams: draft.OutputParamsStr,
|
||||
Modified: draft.Modified,
|
||||
TestRunSuccess: draft.TestRunSuccess,
|
||||
CommitID: draft.CommitID,
|
||||
ID: id,
|
||||
Canvas: draft.Canvas,
|
||||
InputParams: draft.InputParamsStr,
|
||||
OutputParams: draft.OutputParamsStr,
|
||||
CommitID: draft.CommitID,
|
||||
}
|
||||
|
||||
if draft.DraftMeta != nil {
|
||||
d.Modified = draft.DraftMeta.Modified
|
||||
d.TestRunSuccess = draft.DraftMeta.TestRunSuccess
|
||||
}
|
||||
|
||||
if err := r.query.WorkflowDraft.WithContext(ctx).Save(d); err != nil {
|
||||
@@ -500,6 +511,10 @@ func (r *RepositoryImpl) UpdateMeta(ctx context.Context, id int64, metaUpdate *v
|
||||
expressions = append(expressions, r.query.WorkflowMeta.LatestVersion.Value(*metaUpdate.LatestPublishedVersion))
|
||||
}
|
||||
|
||||
if metaUpdate.WorkflowMode != nil {
|
||||
expressions = append(expressions, r.query.WorkflowMeta.Mode.Value(int32(*metaUpdate.WorkflowMode)))
|
||||
}
|
||||
|
||||
if len(expressions) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -551,10 +566,13 @@ func (r *RepositoryImpl) GetEntity(ctx context.Context, policy *vo.GetPolicy) (_
|
||||
draftMeta = draft.DraftMeta
|
||||
commitID = draft.CommitID
|
||||
case workflowModel.FromSpecificVersion:
|
||||
v, err := r.GetVersion(ctx, policy.ID, policy.Version)
|
||||
v, existed, err := r.GetVersion(ctx, policy.ID, policy.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existed {
|
||||
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", policy.Version, policy.ID, err), errorx.KV("id", strconv.FormatInt(policy.ID, 10)))
|
||||
}
|
||||
canvas = v.Canvas
|
||||
inputParams = v.InputParamsStr
|
||||
outputParams = v.OutputParamsStr
|
||||
@@ -604,7 +622,117 @@ func (r *RepositoryImpl) GetEntity(ctx context.Context, policy *vo.GetPolicy) (_
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version string) (_ *vo.VersionInfo, err error) {
|
||||
func (r *RepositoryImpl) CreateChatFlowRoleConfig(ctx context.Context, chatFlowRole *entity.ChatFlowRole) (int64, error) {
|
||||
id, err := r.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, vo.WrapError(errno.ErrIDGenError, err)
|
||||
}
|
||||
chatFlowRoleConfig := &model.ChatFlowRoleConfig{
|
||||
ID: id,
|
||||
WorkflowID: chatFlowRole.WorkflowID,
|
||||
Name: chatFlowRole.Name,
|
||||
Description: chatFlowRole.Description,
|
||||
Avatar: chatFlowRole.AvatarUri,
|
||||
AudioConfig: chatFlowRole.AudioConfig,
|
||||
BackgroundImageInfo: chatFlowRole.BackgroundImageInfo,
|
||||
OnboardingInfo: chatFlowRole.OnboardingInfo,
|
||||
SuggestReplyInfo: chatFlowRole.SuggestReplyInfo,
|
||||
UserInputConfig: chatFlowRole.UserInputConfig,
|
||||
CreatorID: chatFlowRole.CreatorID,
|
||||
Version: chatFlowRole.Version,
|
||||
}
|
||||
|
||||
if err := r.query.ChatFlowRoleConfig.WithContext(ctx).Create(chatFlowRoleConfig); err != nil {
|
||||
return 0, vo.WrapError(errno.ErrDatabaseError, fmt.Errorf("create chat flow role: %w", err))
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) UpdateChatFlowRoleConfig(ctx context.Context, workflowID int64, chatFlowRole *vo.ChatFlowRoleUpdate) error {
|
||||
var expressions []field.AssignExpr
|
||||
if chatFlowRole.Name != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.Name.Value(*chatFlowRole.Name))
|
||||
}
|
||||
if chatFlowRole.Description != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.Description.Value(*chatFlowRole.Description))
|
||||
}
|
||||
if chatFlowRole.AvatarUri != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.Avatar.Value(*chatFlowRole.AvatarUri))
|
||||
}
|
||||
if chatFlowRole.AudioConfig != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.AudioConfig.Value(*chatFlowRole.AudioConfig))
|
||||
}
|
||||
if chatFlowRole.BackgroundImageInfo != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.BackgroundImageInfo.Value(*chatFlowRole.BackgroundImageInfo))
|
||||
}
|
||||
if chatFlowRole.OnboardingInfo != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.OnboardingInfo.Value(*chatFlowRole.OnboardingInfo))
|
||||
}
|
||||
if chatFlowRole.SuggestReplyInfo != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.SuggestReplyInfo.Value(*chatFlowRole.SuggestReplyInfo))
|
||||
}
|
||||
if chatFlowRole.UserInputConfig != nil {
|
||||
expressions = append(expressions, r.query.ChatFlowRoleConfig.UserInputConfig.Value(*chatFlowRole.UserInputConfig))
|
||||
}
|
||||
|
||||
if len(expressions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).
|
||||
UpdateColumnSimple(expressions...)
|
||||
if err != nil {
|
||||
return vo.WrapError(errno.ErrDatabaseError, fmt.Errorf("update chat flow role: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetChatFlowRoleConfig(ctx context.Context, workflowID int64, version string) (_ *entity.ChatFlowRole, err error, isExist bool) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrDatabaseError, err)
|
||||
}
|
||||
}()
|
||||
role := &model.ChatFlowRoleConfig{}
|
||||
if version != "" {
|
||||
role, err = r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID), r.query.ChatFlowRoleConfig.Version.Eq(version)).First()
|
||||
} else {
|
||||
role, err = r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).First()
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err, false
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get chat flow role for chatflowID %d: %w", workflowID, err), true
|
||||
}
|
||||
res := &entity.ChatFlowRole{
|
||||
ID: role.ID,
|
||||
WorkflowID: role.WorkflowID,
|
||||
Name: role.Name,
|
||||
Description: role.Description,
|
||||
AvatarUri: role.Avatar,
|
||||
AudioConfig: role.AudioConfig,
|
||||
BackgroundImageInfo: role.BackgroundImageInfo,
|
||||
OnboardingInfo: role.OnboardingInfo,
|
||||
SuggestReplyInfo: role.SuggestReplyInfo,
|
||||
UserInputConfig: role.UserInputConfig,
|
||||
CreatorID: role.CreatorID,
|
||||
CreatedAt: time.UnixMilli(role.CreatedAt),
|
||||
}
|
||||
if role.UpdatedAt > 0 {
|
||||
res.UpdatedAt = time.UnixMilli(role.UpdatedAt)
|
||||
}
|
||||
return res, err, true
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) DeleteChatFlowRoleConfig(ctx context.Context, id int64, workflowID int64) error {
|
||||
_, err := r.query.ChatFlowRoleConfig.WithContext(ctx).Where(r.query.ChatFlowRoleConfig.ID.Eq(id), r.query.ChatFlowRoleConfig.WorkflowID.Eq(workflowID)).Delete()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version string) (_ *vo.VersionInfo, existed bool, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = vo.WrapIfNeeded(errno.ErrDatabaseError, err)
|
||||
@@ -616,9 +744,9 @@ func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version strin
|
||||
First()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, vo.WrapError(errno.ErrWorkflowNotFound, fmt.Errorf("workflow version %s not found for ID %d: %w", version, id, err), errorx.KV("id", strconv.FormatInt(id, 10)))
|
||||
return nil, false, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get workflow version %s for ID %d: %w", version, id, err)
|
||||
return nil, false, fmt.Errorf("failed to get workflow version %s for ID %d: %w", version, id, err)
|
||||
}
|
||||
|
||||
return &vo.VersionInfo{
|
||||
@@ -634,7 +762,29 @@ func (r *RepositoryImpl) GetVersion(ctx context.Context, id int64, version strin
|
||||
OutputParamsStr: wfVersion.OutputParams,
|
||||
},
|
||||
CommitID: wfVersion.CommitID,
|
||||
}, nil
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetVersionListByConnectorAndWorkflowID(ctx context.Context, connectorID, workflowID int64, limit int) (_ []string, err error) {
|
||||
if limit <= 0 {
|
||||
return nil, vo.WrapError(errno.ErrInvalidParameter, errors.New("limit must be greater than 0"))
|
||||
}
|
||||
|
||||
connectorWorkflowVersion := r.query.ConnectorWorkflowVersion
|
||||
vl, err := connectorWorkflowVersion.WithContext(ctx).
|
||||
Where(connectorWorkflowVersion.ConnectorID.Eq(connectorID),
|
||||
connectorWorkflowVersion.WorkflowID.Eq(workflowID)).
|
||||
Order(connectorWorkflowVersion.CreatedAt.Desc()).
|
||||
Limit(limit).
|
||||
Find()
|
||||
if err != nil {
|
||||
return nil, vo.WrapError(errno.ErrDatabaseError, err)
|
||||
}
|
||||
var versionList []string
|
||||
for _, v := range vl {
|
||||
versionList = append(versionList, v.Version)
|
||||
}
|
||||
return versionList, nil
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) IsApplicationConnectorWorkflowVersion(ctx context.Context, connectorID, workflowID int64, version string) (b bool, err error) {
|
||||
@@ -767,6 +917,10 @@ func (r *RepositoryImpl) MGetDrafts(ctx context.Context, policy *vo.MGetPolicy)
|
||||
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
|
||||
}
|
||||
|
||||
if q.Mode != nil {
|
||||
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*q.Mode)))
|
||||
}
|
||||
|
||||
type combinedDraft struct {
|
||||
model.WorkflowDraft
|
||||
Name string `gorm:"column:name"`
|
||||
@@ -933,6 +1087,10 @@ func (r *RepositoryImpl) MGetLatestVersion(ctx context.Context, policy *vo.MGetP
|
||||
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
|
||||
}
|
||||
|
||||
if q.Mode != nil {
|
||||
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*q.Mode)))
|
||||
}
|
||||
|
||||
type combinedVersion struct {
|
||||
model.WorkflowMeta
|
||||
Version string `gorm:"column:version"` // release version
|
||||
@@ -1157,6 +1315,10 @@ func (r *RepositoryImpl) MGetMetas(ctx context.Context, query *vo.MetaQuery) (
|
||||
conditions = append(conditions, r.query.WorkflowMeta.AppID.Eq(0))
|
||||
}
|
||||
|
||||
if query.Mode != nil {
|
||||
conditions = append(conditions, r.query.WorkflowMeta.Mode.Eq(int32(*query.Mode)))
|
||||
}
|
||||
|
||||
var result []*model.WorkflowMeta
|
||||
|
||||
workflowMetaDo := r.query.WorkflowMeta.WithContext(ctx).Debug().Where(conditions...)
|
||||
@@ -1520,6 +1682,7 @@ func (r *RepositoryImpl) CopyWorkflow(ctx context.Context, workflowID int64, pol
|
||||
IconURI: wfMeta.IconURI,
|
||||
Desc: wfMeta.Description,
|
||||
AppID: ternary.IFElse(wfMeta.AppID == 0, (*int64)(nil), ptr.Of(wfMeta.AppID)),
|
||||
Mode: workflowModel.WorkflowMode(wfMeta.Mode),
|
||||
},
|
||||
CanvasInfo: &vo.CanvasInfo{
|
||||
Canvas: wfDraft.Canvas,
|
||||
@@ -1594,6 +1757,10 @@ func (r *RepositoryImpl) GetKnowledgeRecallChatModel() cm.BaseChatModel {
|
||||
return r.builtinModel
|
||||
}
|
||||
|
||||
func (r *RepositoryImpl) GetObjectUrl(ctx context.Context, objectKey string, opts ...storage.GetOptFn) (string, error) {
|
||||
return r.tos.GetObjectUrl(ctx, objectKey, opts...)
|
||||
}
|
||||
|
||||
func filterDisabledAPIParameters(parametersCfg []*workflow3.APIParameter, m map[string]any) map[string]any {
|
||||
result := make(map[string]any, len(m))
|
||||
responseParameterMap := slices.ToMap(parametersCfg, func(p *workflow3.APIParameter) (string, *workflow3.APIParameter) {
|
||||
|
||||
Reference in New Issue
Block a user