feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -24,6 +24,7 @@ import (
type Message interface {
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"sort"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
@@ -51,9 +52,9 @@ func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
req.MessageType = []*message.MessageType{ptr.Of(message.MessageTypeQuestion)}
// get message with query
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
@@ -62,8 +63,11 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
resp.HasMore = hasMore
if len(messageList) > 0 {
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
resp.NextCursor = messageList[0].CreatedAt
sort.Slice(messageList, func(i, j int) bool {
return messageList[i].CreatedAt > messageList[j].CreatedAt
})
resp.PrevCursor = messageList[len(messageList)-1].ID
resp.NextCursor = messageList[0].ID
var runIDs []int64
for _, m := range messageList {
@@ -82,6 +86,23 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
return resp, nil
}
func (m *messageImpl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
resp.Direction = req.Direction
resp.HasMore = hasMore
resp.Messages = messageList
if len(messageList) > 0 {
resp.PrevCursor = messageList[0].ID
resp.NextCursor = messageList[len(messageList)-1].ID
}
return resp, nil
}
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
}
@@ -96,7 +117,7 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
}
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
return m.MessageRepo.Delete(ctx, req)
}
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"encoding/json"
"testing"
"time"
@@ -145,20 +146,26 @@ func TestCreateMessage(t *testing.T) {
func TestEditMessage(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
extData := map[string]string{
"test": "test",
}
ext, _ := json.Marshal(extData)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 123,
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 124,
Ext: string(ext),
},
)
@@ -177,7 +184,7 @@ func TestEditMessage(t *testing.T) {
Url: "https://xxxxx.xxxx/file",
Name: "test_file",
}
content := []*message.InputMetaData{
_ = []*message.InputMetaData{
{
Type: message.InputTypeText,
Text: "解析图片中的内容",
@@ -197,56 +204,293 @@ func TestEditMessage(t *testing.T) {
}
resp, err := NewService(components).Edit(ctx, &entity.Message{
ID: 2,
Content: "test edit message",
MultiContent: content,
ID: 2,
Content: "test edit message",
Ext: map[string]string{"newext": "true"},
// MultiContent: content,
})
_ = resp
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
msg, err := NewService(components).GetByID(ctx, 2)
assert.NoError(t, err)
assert.Equal(t, int64(124), msOne[0].RunID)
assert.Equal(t, int64(2), msg.ID)
assert.Equal(t, "test edit message", msg.Content)
var modelContent *schema.Message
err = json.Unmarshal([]byte(msg.ModelContent), &modelContent)
assert.NoError(t, err)
assert.Equal(t, "test edit message", modelContent.Content)
assert.Equal(t, "true", msg.Ext["newext"])
}
func TestGetByRunIDs(t *testing.T) {
//func TestGetByRunIDs(t *testing.T) {
// ctx := context.Background()
//
// mockDBGen := orm.NewMockDB()
//
// mockDBGen.AddTable(&model.Message{}).
// AddRows(
// &model.Message{
// ID: 1,
// ConversationID: 1,
// UserID: "1",
// RunID: 123,
// Content: "test content123",
// },
// &model.Message{
// ID: 2,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// &model.Message{
// ID: 3,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// )
// mockDB, err := mockDBGen.DB()
// assert.NoError(t, err)
// components := &Components{
// MessageRepo: repository.NewMessageRepo(mockDB, nil),
// }
//
// resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
//
// assert.NoError(t, err)
//
// assert.Len(t, resp, 2)
//}
func TestListWithoutPair(t *testing.T) {
ctx := context.Background()
t.Run("success_with_messages", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Hello",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "World",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
RunID: 123,
Content: "test content123",
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
&model.Message{
ID: 3,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
assert.NoError(t, err)
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
assert.Len(t, resp, 2)
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
assert.Equal(t, "Hello", resp.Messages[0].Content)
assert.Equal(t, "World", resp.Messages[1].Content)
})
t.Run("empty_result", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 999,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 0)
})
t.Run("pagination_has_more", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Message 1",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 3000,
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Message 2",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 2000,
},
&model.Message{
ID: 3,
ConversationID: 100,
UserID: "user123",
RunID: 202,
Content: "Message 3",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 1000,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 2,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.True(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
})
t.Run("direction_prev", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Test message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionPrev,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionPrev, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 1)
})
t.Run("with_message_type_filter", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Answer message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Question message",
MessageType: string(message.MessageTypeQuestion),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
MessageType: []*message.MessageType{&[]message.MessageType{message.MessageTypeAnswer}[0]},
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Len(t, resp.Messages, 1)
assert.Equal(t, "Answer message", resp.Messages[0].Content)
})
}