feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)
Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com> Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com> Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com> Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com> Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com> Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com> Co-authored-by: July <jiangxujin@bytedance.com> Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
@@ -71,27 +72,41 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
|
||||
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if messageType != nil {
|
||||
do = do.Where(m.MessageType.Eq(string(*messageType)))
|
||||
if len(listMeta.RunID) > 0 {
|
||||
do = do.Where(m.RunID.In(slices.Transform(listMeta.RunID, func(t *int64) int64 {
|
||||
return *t
|
||||
})...))
|
||||
}
|
||||
if len(listMeta.MessageType) > 0 {
|
||||
do = do.Where(m.MessageType.In(slices.Transform(listMeta.MessageType, func(t *message.MessageType) string {
|
||||
return string(*t)
|
||||
})...))
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
if listMeta.Limit > 0 {
|
||||
do = do.Limit(int(listMeta.Limit) + 1)
|
||||
}
|
||||
|
||||
if cursor > 0 {
|
||||
if direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(cursor))
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(cursor))
|
||||
if listMeta.Cursor > 0 {
|
||||
msg, err := m.Where(m.ID.Eq(listMeta.Cursor)).First()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if listMeta.Direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(msg.CreatedAt))
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(msg.CreatedAt))
|
||||
do = do.Order(m.CreatedAt.Asc())
|
||||
}
|
||||
} else {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
}
|
||||
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
messageList, err := do.Find()
|
||||
|
||||
var hasMore bool
|
||||
@@ -103,9 +118,9 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(messageList) > limit {
|
||||
if len(messageList) > int(listMeta.Limit) {
|
||||
hasMore = true
|
||||
messageList = messageList[:limit]
|
||||
messageList = messageList[:int(listMeta.Limit)]
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(messageList), hasMore, nil
|
||||
@@ -113,7 +128,8 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
|
||||
|
||||
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if orderBy == "DESC" {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
@@ -133,19 +149,37 @@ func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy
|
||||
|
||||
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
|
||||
m := dao.query.Message
|
||||
columns := dao.buildEditColumns(msg)
|
||||
|
||||
originMsg, err := dao.GetByID(ctx, msgID)
|
||||
if originMsg == nil {
|
||||
return 0, errorx.New(errno.ErrRecordNotFound)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
columns := dao.buildEditColumns(msg, originMsg)
|
||||
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if do.RowsAffected == 0 {
|
||||
return 0, errorx.New(errno.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
return do.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message, originMsg *entity.Message) map[string]interface{} {
|
||||
columns := make(map[string]interface{})
|
||||
table := dao.query.Message
|
||||
if msg.Content != "" {
|
||||
msg.Role = originMsg.Role
|
||||
columns[table.Content.ColumnName().String()] = msg.Content
|
||||
modelContent, err := dao.buildModelContent(msg)
|
||||
if err == nil {
|
||||
columns[table.ModelContent.ColumnName().String()] = modelContent
|
||||
}
|
||||
}
|
||||
if msg.MessageType != "" {
|
||||
columns[table.MessageType.ColumnName().String()] = msg.MessageType
|
||||
@@ -170,6 +204,11 @@ func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interfa
|
||||
|
||||
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
if msg.Ext != nil {
|
||||
if originMsg.Ext != nil {
|
||||
for k, v := range originMsg.Ext {
|
||||
msg.Ext[k] = v
|
||||
}
|
||||
}
|
||||
ext, err := sonic.MarshalString(msg.Ext)
|
||||
if err == nil {
|
||||
columns[table.Ext.ColumnName().String()] = ext
|
||||
@@ -192,8 +231,8 @@ func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Messag
|
||||
return dao.messagePO2DO(po), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
|
||||
if len(msgIDs) == 0 && len(runIDs) == 0 {
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, delMeta *entity.DeleteMeta) error {
|
||||
if len(delMeta.MessageIDs) == 0 && len(delMeta.RunIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -202,11 +241,14 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx)
|
||||
|
||||
if len(runIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(runIDs...))
|
||||
if len(delMeta.RunIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(delMeta.RunIDs...))
|
||||
}
|
||||
if len(msgIDs) > 0 {
|
||||
do = do.Where(m.ID.In(msgIDs...))
|
||||
if len(delMeta.MessageIDs) > 0 {
|
||||
do = do.Where(m.ID.In(delMeta.MessageIDs...))
|
||||
}
|
||||
if delMeta.ConversationID != nil && ptr.From(delMeta.ConversationID) > 0 {
|
||||
do = do.Where(m.ConversationID.Eq(*delMeta.ConversationID))
|
||||
}
|
||||
_, err := do.UpdateColumns(&updateColumns)
|
||||
return err
|
||||
@@ -284,6 +326,9 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
|
||||
var multiContent []schema.ChatMessagePart
|
||||
for _, contentData := range msgDO.MultiContent {
|
||||
if contentData.Type == message.InputTypeText {
|
||||
if len(msgDO.Content) == 0 && len(contentData.Text) > 0 {
|
||||
msgDO.Content = contentData.Text
|
||||
}
|
||||
continue
|
||||
}
|
||||
one := schema.ChatMessagePart{}
|
||||
|
||||
Reference in New Issue
Block a user