feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -16,19 +16,22 @@
package entity
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
)
type Message = message.Message
type ListMeta struct {
ConversationID int64 `json:"conversation_id"`
RunID []*int64 `json:"run_id"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
OrderBy *string `json:"order_by"`
Limit int `json:"limit"`
Cursor int64 `json:"cursor"` // message id
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
ConversationID int64 `json:"conversation_id"`
RunID []*int64 `json:"run_id"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
OrderBy *string `json:"order_by"`
Limit int `json:"limit"`
Cursor int64 `json:"cursor"` // message id
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
MessageType []*message.MessageType `json:"message_type"`
}
type ListResult struct {
@@ -45,8 +48,9 @@ type GetByRunIDsRequest struct {
}
type DeleteMeta struct {
MessageIDs []int64 `json:"message_ids"`
RunIDs []int64 `json:"run_ids"`
ConversationID *int64 `json:"conversation_id"`
MessageIDs []int64 `json:"message_ids"`
RunIDs []int64 `json:"run_ids"`
}
type BrokenMeta struct {

View File

@@ -31,6 +31,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
@@ -71,27 +72,41 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
return dao.messagePO2DO(poData), nil
}
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
if messageType != nil {
do = do.Where(m.MessageType.Eq(string(*messageType)))
if len(listMeta.RunID) > 0 {
do = do.Where(m.RunID.In(slices.Transform(listMeta.RunID, func(t *int64) int64 {
return *t
})...))
}
if len(listMeta.MessageType) > 0 {
do = do.Where(m.MessageType.In(slices.Transform(listMeta.MessageType, func(t *message.MessageType) string {
return string(*t)
})...))
}
if limit > 0 {
do = do.Limit(int(limit) + 1)
if listMeta.Limit > 0 {
do = do.Limit(int(listMeta.Limit) + 1)
}
if cursor > 0 {
if direction == entity.ScrollPageDirectionPrev {
do = do.Where(m.CreatedAt.Lt(cursor))
} else {
do = do.Where(m.CreatedAt.Gt(cursor))
if listMeta.Cursor > 0 {
msg, err := m.Where(m.ID.Eq(listMeta.Cursor)).First()
if err != nil {
return nil, false, err
}
if listMeta.Direction == entity.ScrollPageDirectionPrev {
do = do.Where(m.CreatedAt.Lt(msg.CreatedAt))
do = do.Order(m.CreatedAt.Desc())
} else {
do = do.Where(m.CreatedAt.Gt(msg.CreatedAt))
do = do.Order(m.CreatedAt.Asc())
}
} else {
do = do.Order(m.CreatedAt.Desc())
}
do = do.Order(m.CreatedAt.Desc())
messageList, err := do.Find()
var hasMore bool
@@ -103,9 +118,9 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
return nil, false, err
}
if len(messageList) > limit {
if len(messageList) > int(listMeta.Limit) {
hasMore = true
messageList = messageList[:limit]
messageList = messageList[:int(listMeta.Limit)]
}
return dao.batchMessagePO2DO(messageList), hasMore, nil
@@ -113,7 +128,8 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
if orderBy == "DESC" {
do = do.Order(m.CreatedAt.Desc())
} else {
@@ -133,19 +149,37 @@ func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
m := dao.query.Message
columns := dao.buildEditColumns(msg)
originMsg, err := dao.GetByID(ctx, msgID)
if originMsg == nil {
return 0, errorx.New(errno.ErrRecordNotFound)
}
if err != nil {
return 0, err
}
columns := dao.buildEditColumns(msg, originMsg)
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
if err != nil {
return 0, err
}
if do.RowsAffected == 0 {
return 0, errorx.New(errno.ErrRecordNotFound)
}
return do.RowsAffected, nil
}
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
func (dao *MessageDAO) buildEditColumns(msg *message.Message, originMsg *entity.Message) map[string]interface{} {
columns := make(map[string]interface{})
table := dao.query.Message
if msg.Content != "" {
msg.Role = originMsg.Role
columns[table.Content.ColumnName().String()] = msg.Content
modelContent, err := dao.buildModelContent(msg)
if err == nil {
columns[table.ModelContent.ColumnName().String()] = modelContent
}
}
if msg.MessageType != "" {
columns[table.MessageType.ColumnName().String()] = msg.MessageType
@@ -170,6 +204,11 @@ func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interfa
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
if msg.Ext != nil {
if originMsg.Ext != nil {
for k, v := range originMsg.Ext {
msg.Ext[k] = v
}
}
ext, err := sonic.MarshalString(msg.Ext)
if err == nil {
columns[table.Ext.ColumnName().String()] = ext
@@ -192,8 +231,8 @@ func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Messag
return dao.messagePO2DO(po), nil
}
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
if len(msgIDs) == 0 && len(runIDs) == 0 {
func (dao *MessageDAO) Delete(ctx context.Context, delMeta *entity.DeleteMeta) error {
if len(delMeta.MessageIDs) == 0 && len(delMeta.RunIDs) == 0 {
return nil
}
@@ -202,11 +241,14 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
m := dao.query.Message
do := m.WithContext(ctx)
if len(runIDs) > 0 {
do = do.Where(m.RunID.In(runIDs...))
if len(delMeta.RunIDs) > 0 {
do = do.Where(m.RunID.In(delMeta.RunIDs...))
}
if len(msgIDs) > 0 {
do = do.Where(m.ID.In(msgIDs...))
if len(delMeta.MessageIDs) > 0 {
do = do.Where(m.ID.In(delMeta.MessageIDs...))
}
if delMeta.ConversationID != nil && ptr.From(delMeta.ConversationID) > 0 {
do = do.Where(m.ConversationID.Eq(*delMeta.ConversationID))
}
_, err := do.UpdateColumns(&updateColumns)
return err
@@ -284,6 +326,9 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
var multiContent []schema.ChatMessagePart
for _, contentData := range msgDO.MultiContent {
if contentData.Type == message.InputTypeText {
if len(msgDO.Content) == 0 && len(contentData.Text) > 0 {
msgDO.Content = contentData.Text
}
continue
}
one := schema.ChatMessagePart{}

View File

@@ -34,10 +34,9 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
type MessageRepo interface {
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
List(ctx context.Context, conversationID int64, limit int, cursor int64,
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error)
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
Delete(ctx context.Context, delMeta *entity.DeleteMeta) error
}

View File

@@ -24,6 +24,7 @@ import (
type Message interface {
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"sort"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
@@ -51,9 +52,9 @@ func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
req.MessageType = []*message.MessageType{ptr.Of(message.MessageTypeQuestion)}
// get message with query
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
@@ -62,8 +63,11 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
resp.HasMore = hasMore
if len(messageList) > 0 {
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
resp.NextCursor = messageList[0].CreatedAt
sort.Slice(messageList, func(i, j int) bool {
return messageList[i].CreatedAt > messageList[j].CreatedAt
})
resp.PrevCursor = messageList[len(messageList)-1].ID
resp.NextCursor = messageList[0].ID
var runIDs []int64
for _, m := range messageList {
@@ -82,6 +86,23 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
return resp, nil
}
func (m *messageImpl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
resp.Direction = req.Direction
resp.HasMore = hasMore
resp.Messages = messageList
if len(messageList) > 0 {
resp.PrevCursor = messageList[0].ID
resp.NextCursor = messageList[len(messageList)-1].ID
}
return resp, nil
}
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
}
@@ -96,7 +117,7 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
}
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
return m.MessageRepo.Delete(ctx, req)
}
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"encoding/json"
"testing"
"time"
@@ -145,20 +146,26 @@ func TestCreateMessage(t *testing.T) {
func TestEditMessage(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
extData := map[string]string{
"test": "test",
}
ext, _ := json.Marshal(extData)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 123,
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 124,
Ext: string(ext),
},
)
@@ -177,7 +184,7 @@ func TestEditMessage(t *testing.T) {
Url: "https://xxxxx.xxxx/file",
Name: "test_file",
}
content := []*message.InputMetaData{
_ = []*message.InputMetaData{
{
Type: message.InputTypeText,
Text: "解析图片中的内容",
@@ -197,56 +204,293 @@ func TestEditMessage(t *testing.T) {
}
resp, err := NewService(components).Edit(ctx, &entity.Message{
ID: 2,
Content: "test edit message",
MultiContent: content,
ID: 2,
Content: "test edit message",
Ext: map[string]string{"newext": "true"},
// MultiContent: content,
})
_ = resp
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
msg, err := NewService(components).GetByID(ctx, 2)
assert.NoError(t, err)
assert.Equal(t, int64(124), msOne[0].RunID)
assert.Equal(t, int64(2), msg.ID)
assert.Equal(t, "test edit message", msg.Content)
var modelContent *schema.Message
err = json.Unmarshal([]byte(msg.ModelContent), &modelContent)
assert.NoError(t, err)
assert.Equal(t, "test edit message", modelContent.Content)
assert.Equal(t, "true", msg.Ext["newext"])
}
func TestGetByRunIDs(t *testing.T) {
//func TestGetByRunIDs(t *testing.T) {
// ctx := context.Background()
//
// mockDBGen := orm.NewMockDB()
//
// mockDBGen.AddTable(&model.Message{}).
// AddRows(
// &model.Message{
// ID: 1,
// ConversationID: 1,
// UserID: "1",
// RunID: 123,
// Content: "test content123",
// },
// &model.Message{
// ID: 2,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// &model.Message{
// ID: 3,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// )
// mockDB, err := mockDBGen.DB()
// assert.NoError(t, err)
// components := &Components{
// MessageRepo: repository.NewMessageRepo(mockDB, nil),
// }
//
// resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
//
// assert.NoError(t, err)
//
// assert.Len(t, resp, 2)
//}
func TestListWithoutPair(t *testing.T) {
ctx := context.Background()
t.Run("success_with_messages", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Hello",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "World",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
RunID: 123,
Content: "test content123",
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
&model.Message{
ID: 3,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
assert.NoError(t, err)
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
assert.Len(t, resp, 2)
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
assert.Equal(t, "Hello", resp.Messages[0].Content)
assert.Equal(t, "World", resp.Messages[1].Content)
})
t.Run("empty_result", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 999,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 0)
})
t.Run("pagination_has_more", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Message 1",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 3000,
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Message 2",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 2000,
},
&model.Message{
ID: 3,
ConversationID: 100,
UserID: "user123",
RunID: 202,
Content: "Message 3",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 1000,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 2,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.True(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
})
t.Run("direction_prev", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Test message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionPrev,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionPrev, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 1)
})
t.Run("with_message_type_filter", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Answer message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Question message",
MessageType: string(message.MessageTypeQuestion),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
MessageType: []*message.MessageType{&[]message.MessageType{message.MessageTypeAnswer}[0]},
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Len(t, resp.Messages, 1)
assert.Equal(t, "Answer message", resp.Messages[0].Content)
})
}