feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
32
backend/domain/conversation/message/entity/const.go
Normal file
32
backend/domain/conversation/message/entity/const.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type ScrollPageDirection string
|
||||
|
||||
const (
|
||||
ScrollPageDirectionPrev ScrollPageDirection = "up"
|
||||
ScrollPageDirectionNext ScrollPageDirection = "down"
|
||||
)
|
||||
|
||||
type MessageStatus int32
|
||||
|
||||
const (
|
||||
MessageStatusAvailable MessageStatus = 1
|
||||
MessageStatusDeleted MessageStatus = 2
|
||||
MessageStatusBroken MessageStatus = 4
|
||||
)
|
||||
64
backend/domain/conversation/message/entity/knowledge.go
Normal file
64
backend/domain/conversation/message/entity/knowledge.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type VerboseInfo struct {
|
||||
MessageType string `json:"msg_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type VerboseData struct {
|
||||
Chunks []RecallDataInfo `json:"chunks"`
|
||||
OriReq string `json:"ori_req"`
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type RecallDataInfo struct {
|
||||
Slice string `json:"slice"`
|
||||
Score float64 `json:"score"`
|
||||
Meta MetaInfo `json:"meta"`
|
||||
}
|
||||
|
||||
type MetaInfo struct {
|
||||
Dataset DatasetInfo `json:"dataset"`
|
||||
Document DocumentInfo `json:"document"`
|
||||
Link LinkInfo `json:"link"`
|
||||
Card CardInfo `json:"card"`
|
||||
}
|
||||
|
||||
type DatasetInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type DocumentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
FormatType int32 `json:"format_type"`
|
||||
SourceType int32 `json:"source_type"`
|
||||
}
|
||||
|
||||
type LinkInfo struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type CardInfo struct {
|
||||
Title string `json:"title"`
|
||||
Con string `json:"con"`
|
||||
Index string `json:"index"`
|
||||
}
|
||||
55
backend/domain/conversation/message/entity/message.go
Normal file
55
backend/domain/conversation/message/entity/message.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
|
||||
type Message = message.Message
|
||||
|
||||
type ListMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []*int64 `json:"run_id"`
|
||||
UserID string `json:"user_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
OrderBy *string `json:"order_by"`
|
||||
Limit int `json:"limit"`
|
||||
Cursor int64 `json:"cursor"` // message id
|
||||
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
|
||||
}
|
||||
|
||||
type ListResult struct {
|
||||
Messages []*Message `json:"messages"`
|
||||
PrevCursor int64 `json:"prev_cursor"`
|
||||
NextCursor int64 `json:"next_cursor"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Direction ScrollPageDirection `json:"direction"`
|
||||
}
|
||||
|
||||
type GetByRunIDsRequest struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []int64 `json:"run_id"`
|
||||
}
|
||||
|
||||
type DeleteMeta struct {
|
||||
MessageIDs []int64 `json:"message_ids"`
|
||||
RunIDs []int64 `json:"run_ids"`
|
||||
}
|
||||
|
||||
type BrokenMeta struct {
|
||||
ID int64 `json:"id"`
|
||||
Position *int32 `json:"position"`
|
||||
}
|
||||
55
backend/domain/conversation/message/entity/message_ext.go
Normal file
55
backend/domain/conversation/message/entity/message_ext.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type MessageExtKey string
|
||||
|
||||
const (
|
||||
MessageExtKeyInputTokens MessageExtKey = "input_tokens"
|
||||
MessageExtKeyOutputTokens MessageExtKey = "output_tokens"
|
||||
MessageExtKeyToken MessageExtKey = "token"
|
||||
MessageExtKeyPluginStatus MessageExtKey = "plugin_status"
|
||||
MessageExtKeyTimeCost MessageExtKey = "time_cost"
|
||||
MessageExtKeyWorkflowTokens MessageExtKey = "workflow_tokens"
|
||||
MessageExtKeyBotState MessageExtKey = "bot_state"
|
||||
MessageExtKeyPluginRequest MessageExtKey = "plugin_request"
|
||||
MessageExtKeyToolName MessageExtKey = "tool_name"
|
||||
MessageExtKeyPlugin MessageExtKey = "plugin"
|
||||
MessageExtKeyMockHitInfo MessageExtKey = "mock_hit_info"
|
||||
MessageExtKeyMessageTitle MessageExtKey = "message_title"
|
||||
MessageExtKeyStreamPluginRunning MessageExtKey = "stream_plugin_running"
|
||||
MessageExtKeyExecuteDisplayName MessageExtKey = "execute_display_name"
|
||||
MessageExtKeyTaskType MessageExtKey = "task_type"
|
||||
MessageExtKeyCallID MessageExtKey = "call_id"
|
||||
ExtKeyResumeInfo MessageExtKey = "resume_info"
|
||||
ExtKeyBreakPoint MessageExtKey = "break_point"
|
||||
ExtKeyToolCallsIDs MessageExtKey = "tool_calls_ids"
|
||||
ExtKeyRequiresAction MessageExtKey = "requires_action"
|
||||
)
|
||||
|
||||
type BotStateExt struct {
|
||||
BotID string `json:"bot_id"`
|
||||
AgentName string `json:"agent_name"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Awaiting string `json:"awaiting"`
|
||||
}
|
||||
|
||||
type UsageExt struct {
|
||||
TotalCount int64 `json:"total_count"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
}
|
||||
365
backend/domain/conversation/message/internal/dal/message.go
Normal file
365
backend/domain/conversation/message/internal/dal/message.go
Normal file
@@ -0,0 +1,365 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type MessageDAO struct {
|
||||
query *query.Query
|
||||
idgen idgen.IDGenerator
|
||||
}
|
||||
|
||||
func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
|
||||
return &MessageDAO{
|
||||
query: query.Use(db),
|
||||
idgen: idgen,
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
poData, err := dao.messageDO2PO(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
do := dao.query.Message.WithContext(ctx).Debug()
|
||||
cErr := do.Create(poData)
|
||||
if cErr != nil {
|
||||
return nil, cErr
|
||||
}
|
||||
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if messageType != nil {
|
||||
do = do.Where(m.MessageType.Eq(string(*messageType)))
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
}
|
||||
|
||||
if cursor > 0 {
|
||||
if direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(cursor))
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(cursor))
|
||||
}
|
||||
}
|
||||
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
messageList, err := do.Find()
|
||||
|
||||
var hasMore bool
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, hasMore, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(messageList) > limit {
|
||||
hasMore = true
|
||||
messageList = messageList[:limit]
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(messageList), hasMore, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
|
||||
if orderBy == "DESC" {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
do = do.Order(m.CreatedAt.Asc())
|
||||
}
|
||||
poList, err := do.Find()
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(poList), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
|
||||
m := dao.query.Message
|
||||
columns := dao.buildEditColumns(msg)
|
||||
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return do.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
|
||||
columns := make(map[string]interface{})
|
||||
table := dao.query.Message
|
||||
if msg.Content != "" {
|
||||
columns[table.Content.ColumnName().String()] = msg.Content
|
||||
}
|
||||
if msg.MessageType != "" {
|
||||
columns[table.MessageType.ColumnName().String()] = msg.MessageType
|
||||
}
|
||||
if msg.ContentType != "" {
|
||||
columns[table.ContentType.ColumnName().String()] = msg.ContentType
|
||||
}
|
||||
if len(msg.ReasoningContent) > 0 {
|
||||
columns[table.ReasoningContent.ColumnName().String()] = msg.ReasoningContent
|
||||
}
|
||||
|
||||
if msg.Position > 0 {
|
||||
columns[table.BrokenPosition.ColumnName().String()] = msg.Position
|
||||
}
|
||||
if msg.Status > 0 {
|
||||
columns[table.Status.ColumnName().String()] = msg.Status
|
||||
}
|
||||
|
||||
if len(msg.ModelContent) > 0 {
|
||||
columns[table.ModelContent.ColumnName().String()] = msg.ModelContent
|
||||
}
|
||||
|
||||
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
if msg.Ext != nil {
|
||||
ext, err := sonic.MarshalString(msg.Ext)
|
||||
if err == nil {
|
||||
columns[table.Ext.ColumnName().String()] = ext
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Where(m.ID.Eq(msgID))
|
||||
po, err := do.First()
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dao.messagePO2DO(po), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
|
||||
if len(msgIDs) == 0 && len(runIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateColumns := make(map[string]interface{})
|
||||
updateColumns["status"] = int32(entity.MessageStatusDeleted)
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx)
|
||||
|
||||
if len(runIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(runIDs...))
|
||||
}
|
||||
if len(msgIDs) > 0 {
|
||||
do = do.Where(m.ID.In(msgIDs...))
|
||||
}
|
||||
_, err := do.UpdateColumns(&updateColumns)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
|
||||
id, gErr := dao.idgen.GenID(ctx)
|
||||
if gErr != nil {
|
||||
return nil, gErr
|
||||
}
|
||||
msgPO := &model.Message{
|
||||
ID: id,
|
||||
ConversationID: msgDo.ConversationID,
|
||||
RunID: msgDo.RunID,
|
||||
AgentID: msgDo.AgentID,
|
||||
SectionID: msgDo.SectionID,
|
||||
UserID: msgDo.UserID,
|
||||
Role: string(msgDo.Role),
|
||||
ContentType: string(msgDo.ContentType),
|
||||
MessageType: string(msgDo.MessageType),
|
||||
DisplayContent: msgDo.DisplayContent,
|
||||
Content: msgDo.Content,
|
||||
BrokenPosition: msgDo.Position,
|
||||
Status: int32(entity.MessageStatusAvailable),
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
mc, err := dao.buildModelContent(msgDo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgPO.ModelContent = mc
|
||||
|
||||
ext, err := json.Marshal(msgDo.Ext)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
||||
}
|
||||
msgPO.Ext = string(ext)
|
||||
|
||||
return msgPO, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) {
|
||||
modelContent := msgDO.ModelContent
|
||||
if modelContent != "" {
|
||||
return modelContent, nil
|
||||
}
|
||||
|
||||
modelContentObj := &schema.Message{
|
||||
Role: msgDO.Role,
|
||||
Name: msgDO.Name,
|
||||
}
|
||||
if msgDO.Content == "" && len(msgDO.MultiContent) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var multiContent []schema.ChatMessagePart
|
||||
for _, contentData := range msgDO.MultiContent {
|
||||
if contentData.Type == message.InputTypeText {
|
||||
continue
|
||||
}
|
||||
one := schema.ChatMessagePart{}
|
||||
switch contentData.Type {
|
||||
case message.InputTypeImage:
|
||||
one.Type = schema.ChatMessagePartTypeImageURL
|
||||
one.ImageURL = &schema.ChatMessageImageURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeFile:
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.FileURL = &schema.ChatMessageFileURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeVideo:
|
||||
one.Type = schema.ChatMessagePartTypeVideoURL
|
||||
one.VideoURL = &schema.ChatMessageVideoURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeAudio:
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.AudioURL = &schema.ChatMessageAudioURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
}
|
||||
multiContent = append(multiContent, one)
|
||||
}
|
||||
if len(multiContent) > 0 {
|
||||
multiContent = append(multiContent, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: msgDO.Content,
|
||||
})
|
||||
} else {
|
||||
modelContentObj.Content = msgDO.Content
|
||||
}
|
||||
|
||||
modelContentObj.MultiContent = multiContent
|
||||
|
||||
mcObjByte, err := json.Marshal(modelContentObj)
|
||||
if err != nil {
|
||||
return "", errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
||||
}
|
||||
|
||||
return string(mcObjByte), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) batchMessagePO2DO(msgPOs []*model.Message) []*entity.Message {
|
||||
return slices.Transform(msgPOs, func(msgPO *model.Message) *entity.Message {
|
||||
msgDO := &entity.Message{
|
||||
ID: msgPO.ID,
|
||||
AgentID: msgPO.AgentID,
|
||||
ConversationID: msgPO.ConversationID,
|
||||
SectionID: msgPO.SectionID,
|
||||
UserID: msgPO.UserID,
|
||||
RunID: msgPO.RunID,
|
||||
Role: schema.RoleType(msgPO.Role),
|
||||
ContentType: message.ContentType(msgPO.ContentType),
|
||||
MessageType: message.MessageType(msgPO.MessageType),
|
||||
Position: msgPO.BrokenPosition,
|
||||
ModelContent: msgPO.ModelContent,
|
||||
Content: msgPO.Content,
|
||||
Status: message.MessageStatus(msgPO.Status),
|
||||
DisplayContent: msgPO.DisplayContent,
|
||||
CreatedAt: msgPO.CreatedAt,
|
||||
UpdatedAt: msgPO.UpdatedAt,
|
||||
ReasoningContent: msgPO.ReasoningContent,
|
||||
}
|
||||
|
||||
var ext map[string]string
|
||||
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
||||
if err == nil {
|
||||
msgDO.Ext = ext
|
||||
}
|
||||
|
||||
return msgDO
|
||||
})
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) messagePO2DO(msgPO *model.Message) *entity.Message {
|
||||
msgDO := &entity.Message{
|
||||
ID: msgPO.ID,
|
||||
AgentID: msgPO.AgentID,
|
||||
ConversationID: msgPO.ConversationID,
|
||||
SectionID: msgPO.SectionID,
|
||||
UserID: msgPO.UserID,
|
||||
RunID: msgPO.RunID,
|
||||
Role: schema.RoleType(msgPO.Role),
|
||||
ContentType: message.ContentType(msgPO.ContentType),
|
||||
MessageType: message.MessageType(msgPO.MessageType),
|
||||
ModelContent: msgPO.ModelContent,
|
||||
Content: msgPO.Content,
|
||||
DisplayContent: msgPO.DisplayContent,
|
||||
CreatedAt: msgPO.CreatedAt,
|
||||
UpdatedAt: msgPO.UpdatedAt,
|
||||
}
|
||||
|
||||
var ext map[string]string
|
||||
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
||||
if err == nil {
|
||||
msgDO.Ext = ext
|
||||
}
|
||||
|
||||
return msgDO
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNameMessage = "message"
|
||||
|
||||
// Message 消息表
|
||||
type Message struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID
|
||||
RunID int64 `gorm:"column:run_id;not null;comment:对应的run_id" json:"run_id"` // 对应的run_id
|
||||
ConversationID int64 `gorm:"column:conversation_id;not null;comment:conversation id" json:"conversation_id"` // conversation id
|
||||
UserID string `gorm:"column:user_id;not null;comment:user id" json:"user_id"` // user id
|
||||
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
|
||||
Role string `gorm:"column:role;not null;comment:角色: user、assistant、system" json:"role"` // 角色: user、assistant、system
|
||||
ContentType string `gorm:"column:content_type;not null;comment:内容类型 1 text" json:"content_type"` // 内容类型 1 text
|
||||
Content string `gorm:"column:content;comment:内容" json:"content"` // 内容
|
||||
MessageType string `gorm:"column:message_type;not null;comment:消息类型:" json:"message_type"` // 消息类型:
|
||||
DisplayContent string `gorm:"column:display_content;comment:展示内容" json:"display_content"` // 展示内容
|
||||
Ext string `gorm:"column:ext;comment:message 扩展字段" json:"ext"` // message 扩展字段
|
||||
SectionID int64 `gorm:"column:section_id;comment:段落id" json:"section_id"` // 段落id
|
||||
BrokenPosition int32 `gorm:"column:broken_position;default:-1;comment:打断位置" json:"broken_position"` // 打断位置
|
||||
Status int32 `gorm:"column:status;not null;comment:消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending" json:"status"` // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
|
||||
ModelContent string `gorm:"column:model_content;comment:模型输入内容" json:"model_content"` // 模型输入内容
|
||||
MetaInfo string `gorm:"column:meta_info;comment:引用、高亮等文本标记信息" json:"meta_info"` // 引用、高亮等文本标记信息
|
||||
ReasoningContent string `gorm:"column:reasoning_content;comment:思考内容" json:"reasoning_content"` // 思考内容
|
||||
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
}
|
||||
|
||||
// TableName Message's table name
|
||||
func (*Message) TableName() string {
|
||||
return TableNameMessage
|
||||
}
|
||||
103
backend/domain/conversation/message/internal/dal/query/gen.go
Normal file
103
backend/domain/conversation/message/internal/dal/query/gen.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gorm.io/gen"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
)
|
||||
|
||||
var (
|
||||
Q = new(Query)
|
||||
Message *message
|
||||
)
|
||||
|
||||
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
|
||||
*Q = *Use(db, opts...)
|
||||
Message = &Q.Message
|
||||
}
|
||||
|
||||
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: newMessage(db, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
db *gorm.DB
|
||||
|
||||
Message message
|
||||
}
|
||||
|
||||
func (q *Query) Available() bool { return q.db != nil }
|
||||
|
||||
func (q *Query) clone(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: q.Message.clone(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) ReadDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
|
||||
}
|
||||
|
||||
func (q *Query) WriteDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
|
||||
}
|
||||
|
||||
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: q.Message.replaceDB(db),
|
||||
}
|
||||
}
|
||||
|
||||
type queryCtx struct {
|
||||
Message IMessageDo
|
||||
}
|
||||
|
||||
func (q *Query) WithContext(ctx context.Context) *queryCtx {
|
||||
return &queryCtx{
|
||||
Message: q.Message.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
|
||||
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
|
||||
}
|
||||
|
||||
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
|
||||
tx := q.db.Begin(opts...)
|
||||
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
|
||||
}
|
||||
|
||||
type QueryTx struct {
|
||||
*Query
|
||||
Error error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Commit() error {
|
||||
return q.db.Commit().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Rollback() error {
|
||||
return q.db.Rollback().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) SavePoint(name string) error {
|
||||
return q.db.SavePoint(name).Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) RollbackTo(name string) error {
|
||||
return q.db.RollbackTo(name).Error
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
)
|
||||
|
||||
func newMessage(db *gorm.DB, opts ...gen.DOOption) message {
|
||||
_message := message{}
|
||||
|
||||
_message.messageDo.UseDB(db, opts...)
|
||||
_message.messageDo.UseModel(&model.Message{})
|
||||
|
||||
tableName := _message.messageDo.TableName()
|
||||
_message.ALL = field.NewAsterisk(tableName)
|
||||
_message.ID = field.NewInt64(tableName, "id")
|
||||
_message.RunID = field.NewInt64(tableName, "run_id")
|
||||
_message.ConversationID = field.NewInt64(tableName, "conversation_id")
|
||||
_message.UserID = field.NewString(tableName, "user_id")
|
||||
_message.AgentID = field.NewInt64(tableName, "agent_id")
|
||||
_message.Role = field.NewString(tableName, "role")
|
||||
_message.ContentType = field.NewString(tableName, "content_type")
|
||||
_message.Content = field.NewString(tableName, "content")
|
||||
_message.MessageType = field.NewString(tableName, "message_type")
|
||||
_message.DisplayContent = field.NewString(tableName, "display_content")
|
||||
_message.Ext = field.NewString(tableName, "ext")
|
||||
_message.SectionID = field.NewInt64(tableName, "section_id")
|
||||
_message.BrokenPosition = field.NewInt32(tableName, "broken_position")
|
||||
_message.Status = field.NewInt32(tableName, "status")
|
||||
_message.ModelContent = field.NewString(tableName, "model_content")
|
||||
_message.MetaInfo = field.NewString(tableName, "meta_info")
|
||||
_message.ReasoningContent = field.NewString(tableName, "reasoning_content")
|
||||
_message.CreatedAt = field.NewInt64(tableName, "created_at")
|
||||
_message.UpdatedAt = field.NewInt64(tableName, "updated_at")
|
||||
|
||||
_message.fillFieldMap()
|
||||
|
||||
return _message
|
||||
}
|
||||
|
||||
// message 消息表
|
||||
type message struct {
|
||||
messageDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int64 // 主键ID
|
||||
RunID field.Int64 // 对应的run_id
|
||||
ConversationID field.Int64 // conversation id
|
||||
UserID field.String // user id
|
||||
AgentID field.Int64 // agent_id
|
||||
Role field.String // 角色: user、assistant、system
|
||||
ContentType field.String // 内容类型 1 text
|
||||
Content field.String // 内容
|
||||
MessageType field.String // 消息类型:
|
||||
DisplayContent field.String // 展示内容
|
||||
Ext field.String // message 扩展字段
|
||||
SectionID field.Int64 // 段落id
|
||||
BrokenPosition field.Int32 // 打断位置
|
||||
Status field.Int32 // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
|
||||
ModelContent field.String // 模型输入内容
|
||||
MetaInfo field.String // 引用、高亮等文本标记信息
|
||||
ReasoningContent field.String // 思考内容
|
||||
CreatedAt field.Int64 // 创建时间
|
||||
UpdatedAt field.Int64 // 更新时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
|
||||
func (m message) Table(newTableName string) *message {
|
||||
m.messageDo.UseTable(newTableName)
|
||||
return m.updateTableName(newTableName)
|
||||
}
|
||||
|
||||
func (m message) As(alias string) *message {
|
||||
m.messageDo.DO = *(m.messageDo.As(alias).(*gen.DO))
|
||||
return m.updateTableName(alias)
|
||||
}
|
||||
|
||||
func (m *message) updateTableName(table string) *message {
|
||||
m.ALL = field.NewAsterisk(table)
|
||||
m.ID = field.NewInt64(table, "id")
|
||||
m.RunID = field.NewInt64(table, "run_id")
|
||||
m.ConversationID = field.NewInt64(table, "conversation_id")
|
||||
m.UserID = field.NewString(table, "user_id")
|
||||
m.AgentID = field.NewInt64(table, "agent_id")
|
||||
m.Role = field.NewString(table, "role")
|
||||
m.ContentType = field.NewString(table, "content_type")
|
||||
m.Content = field.NewString(table, "content")
|
||||
m.MessageType = field.NewString(table, "message_type")
|
||||
m.DisplayContent = field.NewString(table, "display_content")
|
||||
m.Ext = field.NewString(table, "ext")
|
||||
m.SectionID = field.NewInt64(table, "section_id")
|
||||
m.BrokenPosition = field.NewInt32(table, "broken_position")
|
||||
m.Status = field.NewInt32(table, "status")
|
||||
m.ModelContent = field.NewString(table, "model_content")
|
||||
m.MetaInfo = field.NewString(table, "meta_info")
|
||||
m.ReasoningContent = field.NewString(table, "reasoning_content")
|
||||
m.CreatedAt = field.NewInt64(table, "created_at")
|
||||
m.UpdatedAt = field.NewInt64(table, "updated_at")
|
||||
|
||||
m.fillFieldMap()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *message) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
_f, ok := m.fieldMap[fieldName]
|
||||
if !ok || _f == nil {
|
||||
return nil, false
|
||||
}
|
||||
_oe, ok := _f.(field.OrderExpr)
|
||||
return _oe, ok
|
||||
}
|
||||
|
||||
func (m *message) fillFieldMap() {
|
||||
m.fieldMap = make(map[string]field.Expr, 19)
|
||||
m.fieldMap["id"] = m.ID
|
||||
m.fieldMap["run_id"] = m.RunID
|
||||
m.fieldMap["conversation_id"] = m.ConversationID
|
||||
m.fieldMap["user_id"] = m.UserID
|
||||
m.fieldMap["agent_id"] = m.AgentID
|
||||
m.fieldMap["role"] = m.Role
|
||||
m.fieldMap["content_type"] = m.ContentType
|
||||
m.fieldMap["content"] = m.Content
|
||||
m.fieldMap["message_type"] = m.MessageType
|
||||
m.fieldMap["display_content"] = m.DisplayContent
|
||||
m.fieldMap["ext"] = m.Ext
|
||||
m.fieldMap["section_id"] = m.SectionID
|
||||
m.fieldMap["broken_position"] = m.BrokenPosition
|
||||
m.fieldMap["status"] = m.Status
|
||||
m.fieldMap["model_content"] = m.ModelContent
|
||||
m.fieldMap["meta_info"] = m.MetaInfo
|
||||
m.fieldMap["reasoning_content"] = m.ReasoningContent
|
||||
m.fieldMap["created_at"] = m.CreatedAt
|
||||
m.fieldMap["updated_at"] = m.UpdatedAt
|
||||
}
|
||||
|
||||
func (m message) clone(db *gorm.DB) message {
|
||||
m.messageDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m message) replaceDB(db *gorm.DB) message {
|
||||
m.messageDo.ReplaceDB(db)
|
||||
return m
|
||||
}
|
||||
|
||||
type messageDo struct{ gen.DO }
|
||||
|
||||
type IMessageDo interface {
|
||||
gen.SubQuery
|
||||
Debug() IMessageDo
|
||||
WithContext(ctx context.Context) IMessageDo
|
||||
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
|
||||
ReplaceDB(db *gorm.DB)
|
||||
ReadDB() IMessageDo
|
||||
WriteDB() IMessageDo
|
||||
As(alias string) gen.Dao
|
||||
Session(config *gorm.Session) IMessageDo
|
||||
Columns(cols ...field.Expr) gen.Columns
|
||||
Clauses(conds ...clause.Expression) IMessageDo
|
||||
Not(conds ...gen.Condition) IMessageDo
|
||||
Or(conds ...gen.Condition) IMessageDo
|
||||
Select(conds ...field.Expr) IMessageDo
|
||||
Where(conds ...gen.Condition) IMessageDo
|
||||
Order(conds ...field.Expr) IMessageDo
|
||||
Distinct(cols ...field.Expr) IMessageDo
|
||||
Omit(cols ...field.Expr) IMessageDo
|
||||
Join(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
Group(cols ...field.Expr) IMessageDo
|
||||
Having(conds ...gen.Condition) IMessageDo
|
||||
Limit(limit int) IMessageDo
|
||||
Offset(offset int) IMessageDo
|
||||
Count() (count int64, err error)
|
||||
Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo
|
||||
Unscoped() IMessageDo
|
||||
Create(values ...*model.Message) error
|
||||
CreateInBatches(values []*model.Message, batchSize int) error
|
||||
Save(values ...*model.Message) error
|
||||
First() (*model.Message, error)
|
||||
Take() (*model.Message, error)
|
||||
Last() (*model.Message, error)
|
||||
Find() ([]*model.Message, error)
|
||||
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error)
|
||||
FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error
|
||||
Pluck(column field.Expr, dest interface{}) error
|
||||
Delete(...*model.Message) (info gen.ResultInfo, err error)
|
||||
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
Updates(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateFrom(q gen.SubQuery) gen.Dao
|
||||
Attrs(attrs ...field.AssignExpr) IMessageDo
|
||||
Assign(attrs ...field.AssignExpr) IMessageDo
|
||||
Joins(fields ...field.RelationField) IMessageDo
|
||||
Preload(fields ...field.RelationField) IMessageDo
|
||||
FirstOrInit() (*model.Message, error)
|
||||
FirstOrCreate() (*model.Message, error)
|
||||
FindByPage(offset int, limit int) (result []*model.Message, count int64, err error)
|
||||
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
|
||||
Scan(result interface{}) (err error)
|
||||
Returning(value interface{}, columns ...string) IMessageDo
|
||||
UnderlyingDB() *gorm.DB
|
||||
schema.Tabler
|
||||
}
|
||||
|
||||
func (m messageDo) Debug() IMessageDo {
|
||||
return m.withDO(m.DO.Debug())
|
||||
}
|
||||
|
||||
func (m messageDo) WithContext(ctx context.Context) IMessageDo {
|
||||
return m.withDO(m.DO.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (m messageDo) ReadDB() IMessageDo {
|
||||
return m.Clauses(dbresolver.Read)
|
||||
}
|
||||
|
||||
func (m messageDo) WriteDB() IMessageDo {
|
||||
return m.Clauses(dbresolver.Write)
|
||||
}
|
||||
|
||||
func (m messageDo) Session(config *gorm.Session) IMessageDo {
|
||||
return m.withDO(m.DO.Session(config))
|
||||
}
|
||||
|
||||
func (m messageDo) Clauses(conds ...clause.Expression) IMessageDo {
|
||||
return m.withDO(m.DO.Clauses(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Returning(value interface{}, columns ...string) IMessageDo {
|
||||
return m.withDO(m.DO.Returning(value, columns...))
|
||||
}
|
||||
|
||||
func (m messageDo) Not(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Not(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Or(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Or(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Select(conds ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Select(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Where(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Where(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Order(conds ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Order(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Distinct(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Distinct(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Omit(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Omit(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Join(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Join(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.LeftJoin(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.RightJoin(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) Group(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Group(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Having(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Having(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Limit(limit int) IMessageDo {
|
||||
return m.withDO(m.DO.Limit(limit))
|
||||
}
|
||||
|
||||
func (m messageDo) Offset(offset int) IMessageDo {
|
||||
return m.withDO(m.DO.Offset(offset))
|
||||
}
|
||||
|
||||
func (m messageDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo {
|
||||
return m.withDO(m.DO.Scopes(funcs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Unscoped() IMessageDo {
|
||||
return m.withDO(m.DO.Unscoped())
|
||||
}
|
||||
|
||||
func (m messageDo) Create(values ...*model.Message) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.DO.Create(values)
|
||||
}
|
||||
|
||||
func (m messageDo) CreateInBatches(values []*model.Message, batchSize int) error {
|
||||
return m.DO.CreateInBatches(values, batchSize)
|
||||
}
|
||||
|
||||
// Save : !!! underlying implementation is different with GORM
|
||||
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
|
||||
func (m messageDo) Save(values ...*model.Message) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.DO.Save(values)
|
||||
}
|
||||
|
||||
func (m messageDo) First() (*model.Message, error) {
|
||||
if result, err := m.DO.First(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Take() (*model.Message, error) {
|
||||
if result, err := m.DO.Take(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Last() (*model.Message, error) {
|
||||
if result, err := m.DO.Last(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Find() ([]*model.Message, error) {
|
||||
result, err := m.DO.Find()
|
||||
return result.([]*model.Message), err
|
||||
}
|
||||
|
||||
func (m messageDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error) {
|
||||
buf := make([]*model.Message, 0, batchSize)
|
||||
err = m.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
|
||||
defer func() { results = append(results, buf...) }()
|
||||
return fc(tx, batch)
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (m messageDo) FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error {
|
||||
return m.DO.FindInBatches(result, batchSize, fc)
|
||||
}
|
||||
|
||||
func (m messageDo) Attrs(attrs ...field.AssignExpr) IMessageDo {
|
||||
return m.withDO(m.DO.Attrs(attrs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Assign(attrs ...field.AssignExpr) IMessageDo {
|
||||
return m.withDO(m.DO.Assign(attrs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Joins(fields ...field.RelationField) IMessageDo {
|
||||
for _, _f := range fields {
|
||||
m = *m.withDO(m.DO.Joins(_f))
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
func (m messageDo) Preload(fields ...field.RelationField) IMessageDo {
|
||||
for _, _f := range fields {
|
||||
m = *m.withDO(m.DO.Preload(_f))
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
func (m messageDo) FirstOrInit() (*model.Message, error) {
|
||||
if result, err := m.DO.FirstOrInit(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) FirstOrCreate() (*model.Message, error) {
|
||||
if result, err := m.DO.FirstOrCreate(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) FindByPage(offset int, limit int) (result []*model.Message, count int64, err error) {
|
||||
result, err = m.Offset(offset).Limit(limit).Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if size := len(result); 0 < limit && 0 < size && size < limit {
|
||||
count = int64(size + offset)
|
||||
return
|
||||
}
|
||||
|
||||
count, err = m.Offset(-1).Limit(-1).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (m messageDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
|
||||
count, err = m.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Offset(offset).Limit(limit).Scan(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (m messageDo) Scan(result interface{}) (err error) {
|
||||
return m.DO.Scan(result)
|
||||
}
|
||||
|
||||
func (m messageDo) Delete(models ...*model.Message) (result gen.ResultInfo, err error) {
|
||||
return m.DO.Delete(models)
|
||||
}
|
||||
|
||||
func (m *messageDo) withDO(do gen.Dao) *messageDo {
|
||||
m.DO = *do.(*gen.DO)
|
||||
return m
|
||||
}
|
||||
42
backend/domain/conversation/message/repository/repository.go
Normal file
42
backend/domain/conversation/message/repository/repository.go
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
|
||||
return dal.NewMessageDAO(db, idGen)
|
||||
}
|
||||
|
||||
type MessageRepo interface {
|
||||
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
List(ctx context.Context, conversationID int64, limit int, cursor int64,
|
||||
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
|
||||
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
|
||||
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
|
||||
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
|
||||
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
|
||||
}
|
||||
33
backend/domain/conversation/message/service/message.go
Normal file
33
backend/domain/conversation/message/service/message.go
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Message, error)
|
||||
Edit(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
Delete(ctx context.Context, req *entity.DeleteMeta) error
|
||||
Broken(ctx context.Context, req *entity.BrokenMeta) error
|
||||
}
|
||||
130
backend/domain/conversation/message/service/message_impl.go
Normal file
130
backend/domain/conversation/message/service/message_impl.go
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type messageImpl struct {
|
||||
Components
|
||||
}
|
||||
|
||||
type Components struct {
|
||||
MessageRepo repository.MessageRepo
|
||||
}
|
||||
|
||||
func NewService(c *Components) Message {
|
||||
return &messageImpl{
|
||||
Components: *c,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
// create message
|
||||
msg, err := m.MessageRepo.Create(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
resp := &entity.ListResult{}
|
||||
|
||||
// get message with query
|
||||
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp.Direction = req.Direction
|
||||
resp.HasMore = hasMore
|
||||
|
||||
if len(messageList) > 0 {
|
||||
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
|
||||
resp.NextCursor = messageList[0].CreatedAt
|
||||
|
||||
var runIDs []int64
|
||||
for _, m := range messageList {
|
||||
runIDs = append(runIDs, m.RunID)
|
||||
}
|
||||
orderBy := "DESC"
|
||||
if req.OrderBy != nil {
|
||||
orderBy = *req.OrderBy
|
||||
}
|
||||
allMessageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, orderBy)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Messages = allMessageList
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
|
||||
messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return messageList, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) {
|
||||
_, err := m.MessageRepo.Edit(ctx, req.ID, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
|
||||
err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
|
||||
msg, err := m.MessageRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
|
||||
|
||||
_, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{
|
||||
Status: message.MessageStatusBroken,
|
||||
Position: ptr.From(req.Position),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
247
backend/domain/conversation/message/service/message_test.go
Normal file
247
backend/domain/conversation/message/service/message_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
|
||||
)
|
||||
|
||||
// Test_NewListMessage tests the NewListMessage function
|
||||
func TestListMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
resp, err := NewService(components).List(ctx, &entity.ListMeta{
|
||||
ConversationID: 1,
|
||||
Limit: 1,
|
||||
UserID: "1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, resp.Messages, 0)
|
||||
}
|
||||
|
||||
// Test_NewListMessage tests the NewListMessage function
|
||||
func TestCreateMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1)
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{})
|
||||
mockDB, err := mockDBGen.DB()
|
||||
|
||||
// redisCli := redis.New()
|
||||
// idGen, _ := idgen.New(redisCli)
|
||||
// mockDB, err := mysql.New()
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, idGen),
|
||||
}
|
||||
imageInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/image",
|
||||
Name: "test_img",
|
||||
}
|
||||
fileInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
{
|
||||
Type: message.InputTypeText,
|
||||
Text: "解析图片中的内容",
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeImage,
|
||||
FileData: []*message.FileData{
|
||||
imageInput,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeFile,
|
||||
FileData: []*message.FileData{
|
||||
fileInput,
|
||||
},
|
||||
},
|
||||
}
|
||||
service := NewService(components)
|
||||
insert := &entity.Message{
|
||||
ID: 7498710126354759680,
|
||||
ConversationID: 7496795464885338112,
|
||||
AgentID: 7366055842027922437,
|
||||
UserID: "6666666",
|
||||
RunID: 7498710102375923712,
|
||||
Content: "你是谁?",
|
||||
MultiContent: content,
|
||||
Role: schema.Assistant,
|
||||
MessageType: message.MessageTypeFunctionCall,
|
||||
SectionID: 7496795464897921024,
|
||||
ModelContent: "{\"role\":\"tool\",\"content\":\"tool call\"}",
|
||||
ContentType: message.ContentTypeMix,
|
||||
}
|
||||
resp, err := service.Create(ctx, insert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(7366055842027922437), resp.AgentID)
|
||||
assert.Equal(t, "你是谁?", resp.Content)
|
||||
}
|
||||
|
||||
func TestEditMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 123,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 124,
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
imageInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/image",
|
||||
Name: "test_img",
|
||||
}
|
||||
fileInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
{
|
||||
Type: message.InputTypeText,
|
||||
Text: "解析图片中的内容",
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeImage,
|
||||
FileData: []*message.FileData{
|
||||
imageInput,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeFile,
|
||||
FileData: []*message.FileData{
|
||||
fileInput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := NewService(components).Edit(ctx, &entity.Message{
|
||||
ID: 2,
|
||||
Content: "test edit message",
|
||||
MultiContent: content,
|
||||
})
|
||||
_ = resp
|
||||
|
||||
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(124), msOne[0].RunID)
|
||||
}
|
||||
|
||||
func TestGetByRunIDs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 123,
|
||||
Content: "test content123",
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 3,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
)
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, resp, 2)
|
||||
}
|
||||
Reference in New Issue
Block a user