feat: manually mirror opencoze's code from bytedance

Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
fanlv
2025-07-20 17:36:12 +08:00
commit 890153324f
14811 changed files with 1923430 additions and 0 deletions

View File

@@ -0,0 +1,85 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
const ConversationTurnsDefault int32 = 100
type RunStatus string
const (
RunStatusCreated RunStatus = "created"
RunStatusInProgress RunStatus = "in_progress"
RunStatusCompleted RunStatus = "completed"
RunStatusFailed RunStatus = "failed"
RunStatusExpired RunStatus = "expired"
RunStatusCancelled RunStatus = "cancelled"
RunStatusRequiredAction RunStatus = "required_action"
RunStatusDeleted RunStatus = "deleted"
)
type RunEvent string
const (
RunEventCreated RunEvent = "conversation.run.created"
RunEventInProgress RunEvent = "conversation.run.in_progress"
RunEventCompleted RunEvent = "conversation.run.completed"
RunEventFailed RunEvent = "conversation.run.failed"
RunEventExpired RunEvent = "conversation.run.expired"
RunEventCancelled RunEvent = "conversation.run.cancelled"
RunEventRequiredAction RunEvent = "conversation.run.required_action"
RunEventMessageDelta RunEvent = "conversation.message.delta"
RunEventMessageCompleted RunEvent = "conversation.message.completed"
RunEventAck = "conversation.ack"
RunEventError RunEvent = "conversation.error"
RunEventStreamDone RunEvent = "conversation.stream.done"
)
type ReplyType int64
const (
ReplyTypeAnswer ReplyType = 1
ReplyTypeSuggest ReplyType = 2
ReplyTypeLLMOutput ReplyType = 3
ReplyTypeToolOutput ReplyType = 4
ReplyTypeVerbose ReplyType = 100
ReplyTypePlaceHolder ReplyType = 101
)
type MetaType int64
const (
MetaTypeKnowledgeCard MetaType = 4
)
type RoleType string
const (
RoleTypeSystem RoleType = "system"
RoleTypeUser RoleType = "user"
RoleTypeAssistant RoleType = "assistant"
RoleTypeTool RoleType = "tool"
)
type MessageSubType string
const (
MessageSubTypeKnowledgeCall MessageSubType = "knowledge_recall"
MessageSubTypeGenerateFinish MessageSubType = "generate_answer_finish"
MessageSubTypeInterrupt MessageSubType = "interrupt"
)

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type EventType int64
const (
EventType_LocalPlugin EventType = 1
EventType_Question EventType = 2
EventType_RequireInfos EventType = 3
EventType_SceneChat EventType = 4
EventType_InputNode EventType = 5
EventType_WorkflowLocalPlugin EventType = 6
EventType_OauthPlugin EventType = 7
EventType_WorkflowLLM EventType = 100
)

View File

@@ -0,0 +1,158 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
message2 "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
)
type RunRecord = model.RunRecord
type RunRecordMeta struct {
ID int64 `json:"id"`
ConversationID int64 `json:"conversation_id"`
SectionID int64 `json:"section_id"`
AgentID int64 `json:"agent_id"`
Status RunStatus `json:"status"`
Error *RunError `json:"error"`
Usage *agentrun.Usage `json:"usage"`
Ext string `json:"ext"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
ChatRequest *string `json:"chat_message"`
CompletedAt int64 `json:"completed_at"`
FailedAt int64 `json:"failed_at"`
}
type ChunkRunItem = RunRecordMeta
type ChunkMessageItem struct {
ID int64 `json:"id"`
ConversationID int64 `json:"conversation_id"`
SectionID int64 `json:"section_id"`
RunID int64 `json:"run_id"`
AgentID int64 `json:"agent_id"`
Role RoleType `json:"role"`
Type message.MessageType `json:"type"`
Content string `json:"content"`
ContentType message.ContentType `json:"content_type"`
MessageType message.MessageType `json:"message_type"`
ReplyID int64 `json:"reply_id"`
Ext map[string]string `json:"ext"`
ReasoningContent *string `json:"reasoning_content"`
Index int64 `json:"index"`
RequiredAction *message2.RequiredAction `json:"required_action"`
SeqID int64 `json:"seq_id"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
IsFinish bool `json:"is_finish"`
}
type RunError struct {
Code int64 `json:"code"`
Msg string `json:"msg"`
}
type CustomerConfig struct {
ModelConfig *ModelConfig `json:"model_config"`
AgentConfig *AgentConfig `json:"agent_config"`
}
type ModelConfig struct {
ModelId *int64 `json:"model_id,omitempty"`
}
type AgentConfig struct {
Prompt *string `json:"prompt"`
}
type Tool = agentrun.Tool
type AnswerFinshContent struct {
MsgType MessageSubType `json:"msg_type"`
Data string `json:"data"`
FromUnit string `json:"from_unit"`
}
type Data struct {
FinishReason int32 `json:"finish_reason"`
FinData string `json:"fin_data"`
}
type MetaInfo struct {
Type MetaType `json:"type"`
Info string `json:"info"`
}
type AgentRunMeta struct {
ConversationID int64 `json:"conversation_id"`
ConnectorID int64 `json:"connector_id"`
SpaceID int64 `json:"space_id"`
Scene common.Scene `json:"scene"`
SectionID int64 `json:"section_id"`
Name string `json:"name"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
ContentType message.ContentType `json:"content_type"`
Content []*message.InputMetaData `json:"content"`
PreRetrieveTools []*Tool `json:"tools"`
IsDraft bool `json:"is_draft"`
CustomerConfig *CustomerConfig `json:"customer_config"`
DisplayContent string `json:"display_content"`
CustomVariables map[string]string `json:"custom_variables"`
Version string `json:"version"`
Ext map[string]string `json:"ext"`
}
type UpdateMeta struct {
Status RunStatus
LastError *RunError
Usage *agentrun.Usage
UpdatedAt int64
CompletedAt int64
FailedAt int64
}
type AgentRunResponse struct {
Event RunEvent `json:"event"`
ChunkRunItem *ChunkRunItem `json:"run_record_item"`
ChunkMessageItem *ChunkMessageItem `json:"message_item"`
Error *RunError `json:"error"`
}
type AgentRespEvent struct {
EventType message.MessageType
ModelAnswer *schema.StreamReader[*schema.Message]
ToolsMessage []*schema.Message
FuncCall *schema.Message
Suggest *schema.Message
Knowledge []*schema.Document
Interrupt *singleagent.InterruptInfo
Err error
}
type ModelAnswerEvent struct {
Message *schema.Message
Err error
}

View File

@@ -0,0 +1,167 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dal
import (
"context"
"encoding/json"
"time"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
type RunRecordDAO struct {
db *gorm.DB
query *query.Query
idGen idgen.IDGenerator
}
func NewRunRecordDAO(db *gorm.DB, idGen idgen.IDGenerator) *RunRecordDAO {
return &RunRecordDAO{
db: db,
idGen: idGen,
query: query.Use(db),
}
}
func (dao *RunRecordDAO) Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error) {
createPO, err := dao.buildCreatePO(ctx, runMeta)
if err != nil {
return nil, err
}
createErr := dao.query.RunRecord.WithContext(ctx).Create(createPO)
if createErr != nil {
return nil, createErr
}
return dao.buildPo2Do(createPO), nil
}
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*model.RunRecord, error) {
return dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
}
func (dao *RunRecordDAO) UpdateByID(ctx context.Context, id int64, updateMeta *entity.UpdateMeta) error {
po := &model.RunRecord{
ID: id,
}
if updateMeta.Status != "" {
po.Status = string(updateMeta.Status)
}
if updateMeta.LastError != nil {
errString, err := json.Marshal(updateMeta.LastError)
if err != nil {
return err
}
po.LastError = string(errString)
}
if updateMeta.CompletedAt != 0 {
po.CompletedAt = updateMeta.CompletedAt
}
if updateMeta.FailedAt != 0 {
po.FailedAt = updateMeta.FailedAt
}
if updateMeta.Usage != nil {
po.Usage = updateMeta.Usage
}
po.UpdatedAt = time.Now().UnixMilli()
_, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).Updates(po)
return err
}
func (dao *RunRecordDAO) Delete(ctx context.Context, id []int64) error {
_, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.In(id...)).UpdateColumns(map[string]interface{}{
"updated_at": time.Now().UnixMilli(),
"status": entity.RunStatusDeleted,
})
return err
}
func (dao *RunRecordDAO) List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error) {
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", conversationID, sectionID, limit)
m := dao.query.RunRecord
do := m.WithContext(ctx).Where(m.ConversationID.Eq(conversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
if sectionID > 0 {
do = do.Where(m.SectionID.Eq(sectionID))
}
if limit > 0 {
do = do.Limit(int(limit))
}
runRecords, err := do.Order(m.CreatedAt.Desc()).Find()
return runRecords, err
}
func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.AgentRunMeta) (*model.RunRecord, error) {
runID, err := dao.idGen.GenID(ctx)
if err != nil {
return nil, err
}
reqOrigin, err := json.Marshal(runMeta)
if err != nil {
return nil, err
}
timeNow := time.Now().UnixMilli()
return &model.RunRecord{
ID: runID,
ConversationID: runMeta.ConversationID,
SectionID: runMeta.SectionID,
AgentID: runMeta.AgentID,
Status: string(entity.RunStatusCreated),
ChatRequest: string(reqOrigin),
UserID: runMeta.UserID,
CreatedAt: timeNow,
}, nil
}
func (dao *RunRecordDAO) buildPo2Do(po *model.RunRecord) *entity.RunRecordMeta {
runMeta := &entity.RunRecordMeta{
ID: po.ID,
ConversationID: po.ConversationID,
SectionID: po.SectionID,
AgentID: po.AgentID,
Status: entity.RunStatus(po.Status),
Ext: po.Ext,
CreatedAt: po.CreatedAt,
UpdatedAt: po.UpdatedAt,
CompletedAt: po.CompletedAt,
FailedAt: po.FailedAt,
Usage: po.Usage,
}
return runMeta
}

View File

@@ -0,0 +1,34 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
const TableNameRunRecord = "run_record"
// RunRecord 执行记录表
type RunRecord struct {
ID int64 `gorm:"column:id;primaryKey;comment:主键ID" json:"id"` // 主键ID
ConversationID int64 `gorm:"column:conversation_id;not null;comment:会话 ID" json:"conversation_id"` // 会话 ID
SectionID int64 `gorm:"column:section_id;not null;comment:section ID" json:"section_id"` // section ID
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
UserID string `gorm:"column:user_id;not null;comment:user id" json:"user_id"` // user id
Source int32 `gorm:"column:source;not null;comment:执行来源 0 API," json:"source"` // 执行来源 0 API,
Status string `gorm:"column:status;not null;comment:状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction" json:"status"` // 状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction
CreatorID int64 `gorm:"column:creator_id;not null;comment:创建者标识" json:"creator_id"` // 创建者标识
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
FailedAt int64 `gorm:"column:failed_at;not null;comment:失败时间" json:"failed_at"` // 失败时间
LastError string `gorm:"column:last_error;comment:error message" json:"last_error"` // error message
CompletedAt int64 `gorm:"column:completed_at;not null;comment:结束时间" json:"completed_at"` // 结束时间
ChatRequest string `gorm:"column:chat_request;comment:保存原始请求的部分字段" json:"chat_request"` // 保存原始请求的部分字段
Ext string `gorm:"column:ext;comment:扩展字段" json:"ext"` // 扩展字段
Usage *agentrun.Usage `gorm:"column:usage;comment:usage;serializer:json" json:"usage"` // usage
}
// TableName RunRecord's table name
func (*RunRecord) TableName() string {
return TableNameRunRecord
}

View File

@@ -0,0 +1,103 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"database/sql"
"gorm.io/gorm"
"gorm.io/gen"
"gorm.io/plugin/dbresolver"
)
var (
Q = new(Query)
RunRecord *runRecord
)
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
*Q = *Use(db, opts...)
RunRecord = &Q.RunRecord
}
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
return &Query{
db: db,
RunRecord: newRunRecord(db, opts...),
}
}
type Query struct {
db *gorm.DB
RunRecord runRecord
}
func (q *Query) Available() bool { return q.db != nil }
func (q *Query) clone(db *gorm.DB) *Query {
return &Query{
db: db,
RunRecord: q.RunRecord.clone(db),
}
}
func (q *Query) ReadDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
}
func (q *Query) WriteDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
}
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
return &Query{
db: db,
RunRecord: q.RunRecord.replaceDB(db),
}
}
type queryCtx struct {
RunRecord IRunRecordDo
}
func (q *Query) WithContext(ctx context.Context) *queryCtx {
return &queryCtx{
RunRecord: q.RunRecord.WithContext(ctx),
}
}
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
tx := q.db.Begin(opts...)
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
}
type QueryTx struct {
*Query
Error error
}
func (q *QueryTx) Commit() error {
return q.db.Commit().Error
}
func (q *QueryTx) Rollback() error {
return q.db.Rollback().Error
}
func (q *QueryTx) SavePoint(name string) error {
return q.db.SavePoint(name).Error
}
func (q *QueryTx) RollbackTo(name string) error {
return q.db.RollbackTo(name).Error
}

View File

@@ -0,0 +1,441 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gen"
"gorm.io/gen/field"
"gorm.io/plugin/dbresolver"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
)
func newRunRecord(db *gorm.DB, opts ...gen.DOOption) runRecord {
_runRecord := runRecord{}
_runRecord.runRecordDo.UseDB(db, opts...)
_runRecord.runRecordDo.UseModel(&model.RunRecord{})
tableName := _runRecord.runRecordDo.TableName()
_runRecord.ALL = field.NewAsterisk(tableName)
_runRecord.ID = field.NewInt64(tableName, "id")
_runRecord.ConversationID = field.NewInt64(tableName, "conversation_id")
_runRecord.SectionID = field.NewInt64(tableName, "section_id")
_runRecord.AgentID = field.NewInt64(tableName, "agent_id")
_runRecord.UserID = field.NewString(tableName, "user_id")
_runRecord.Source = field.NewInt32(tableName, "source")
_runRecord.Status = field.NewString(tableName, "status")
_runRecord.CreatorID = field.NewInt64(tableName, "creator_id")
_runRecord.CreatedAt = field.NewInt64(tableName, "created_at")
_runRecord.UpdatedAt = field.NewInt64(tableName, "updated_at")
_runRecord.FailedAt = field.NewInt64(tableName, "failed_at")
_runRecord.LastError = field.NewString(tableName, "last_error")
_runRecord.CompletedAt = field.NewInt64(tableName, "completed_at")
_runRecord.ChatRequest = field.NewString(tableName, "chat_request")
_runRecord.Ext = field.NewString(tableName, "ext")
_runRecord.Usage = field.NewField(tableName, "usage")
_runRecord.fillFieldMap()
return _runRecord
}
// runRecord 执行记录表
type runRecord struct {
runRecordDo
ALL field.Asterisk
ID field.Int64 // 主键ID
ConversationID field.Int64 // 会话 ID
SectionID field.Int64 // section ID
AgentID field.Int64 // agent_id
UserID field.String // user id
Source field.Int32 // 执行来源 0 API,
Status field.String // 状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction
CreatorID field.Int64 // 创建者标识
CreatedAt field.Int64 // 创建时间
UpdatedAt field.Int64 // 更新时间
FailedAt field.Int64 // 失败时间
LastError field.String // error message
CompletedAt field.Int64 // 结束时间
ChatRequest field.String // 保存原始请求的部分字段
Ext field.String // 扩展字段
Usage field.Field // usage
fieldMap map[string]field.Expr
}
func (r runRecord) Table(newTableName string) *runRecord {
r.runRecordDo.UseTable(newTableName)
return r.updateTableName(newTableName)
}
func (r runRecord) As(alias string) *runRecord {
r.runRecordDo.DO = *(r.runRecordDo.As(alias).(*gen.DO))
return r.updateTableName(alias)
}
func (r *runRecord) updateTableName(table string) *runRecord {
r.ALL = field.NewAsterisk(table)
r.ID = field.NewInt64(table, "id")
r.ConversationID = field.NewInt64(table, "conversation_id")
r.SectionID = field.NewInt64(table, "section_id")
r.AgentID = field.NewInt64(table, "agent_id")
r.UserID = field.NewString(table, "user_id")
r.Source = field.NewInt32(table, "source")
r.Status = field.NewString(table, "status")
r.CreatorID = field.NewInt64(table, "creator_id")
r.CreatedAt = field.NewInt64(table, "created_at")
r.UpdatedAt = field.NewInt64(table, "updated_at")
r.FailedAt = field.NewInt64(table, "failed_at")
r.LastError = field.NewString(table, "last_error")
r.CompletedAt = field.NewInt64(table, "completed_at")
r.ChatRequest = field.NewString(table, "chat_request")
r.Ext = field.NewString(table, "ext")
r.Usage = field.NewField(table, "usage")
r.fillFieldMap()
return r
}
func (r *runRecord) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
_f, ok := r.fieldMap[fieldName]
if !ok || _f == nil {
return nil, false
}
_oe, ok := _f.(field.OrderExpr)
return _oe, ok
}
func (r *runRecord) fillFieldMap() {
r.fieldMap = make(map[string]field.Expr, 16)
r.fieldMap["id"] = r.ID
r.fieldMap["conversation_id"] = r.ConversationID
r.fieldMap["section_id"] = r.SectionID
r.fieldMap["agent_id"] = r.AgentID
r.fieldMap["user_id"] = r.UserID
r.fieldMap["source"] = r.Source
r.fieldMap["status"] = r.Status
r.fieldMap["creator_id"] = r.CreatorID
r.fieldMap["created_at"] = r.CreatedAt
r.fieldMap["updated_at"] = r.UpdatedAt
r.fieldMap["failed_at"] = r.FailedAt
r.fieldMap["last_error"] = r.LastError
r.fieldMap["completed_at"] = r.CompletedAt
r.fieldMap["chat_request"] = r.ChatRequest
r.fieldMap["ext"] = r.Ext
r.fieldMap["usage"] = r.Usage
}
func (r runRecord) clone(db *gorm.DB) runRecord {
r.runRecordDo.ReplaceConnPool(db.Statement.ConnPool)
return r
}
func (r runRecord) replaceDB(db *gorm.DB) runRecord {
r.runRecordDo.ReplaceDB(db)
return r
}
type runRecordDo struct{ gen.DO }
type IRunRecordDo interface {
gen.SubQuery
Debug() IRunRecordDo
WithContext(ctx context.Context) IRunRecordDo
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
ReplaceDB(db *gorm.DB)
ReadDB() IRunRecordDo
WriteDB() IRunRecordDo
As(alias string) gen.Dao
Session(config *gorm.Session) IRunRecordDo
Columns(cols ...field.Expr) gen.Columns
Clauses(conds ...clause.Expression) IRunRecordDo
Not(conds ...gen.Condition) IRunRecordDo
Or(conds ...gen.Condition) IRunRecordDo
Select(conds ...field.Expr) IRunRecordDo
Where(conds ...gen.Condition) IRunRecordDo
Order(conds ...field.Expr) IRunRecordDo
Distinct(cols ...field.Expr) IRunRecordDo
Omit(cols ...field.Expr) IRunRecordDo
Join(table schema.Tabler, on ...field.Expr) IRunRecordDo
LeftJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo
RightJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo
Group(cols ...field.Expr) IRunRecordDo
Having(conds ...gen.Condition) IRunRecordDo
Limit(limit int) IRunRecordDo
Offset(offset int) IRunRecordDo
Count() (count int64, err error)
Scopes(funcs ...func(gen.Dao) gen.Dao) IRunRecordDo
Unscoped() IRunRecordDo
Create(values ...*model.RunRecord) error
CreateInBatches(values []*model.RunRecord, batchSize int) error
Save(values ...*model.RunRecord) error
First() (*model.RunRecord, error)
Take() (*model.RunRecord, error)
Last() (*model.RunRecord, error)
Find() ([]*model.RunRecord, error)
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.RunRecord, err error)
FindInBatches(result *[]*model.RunRecord, batchSize int, fc func(tx gen.Dao, batch int) error) error
Pluck(column field.Expr, dest interface{}) error
Delete(...*model.RunRecord) (info gen.ResultInfo, err error)
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
Updates(value interface{}) (info gen.ResultInfo, err error)
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
UpdateFrom(q gen.SubQuery) gen.Dao
Attrs(attrs ...field.AssignExpr) IRunRecordDo
Assign(attrs ...field.AssignExpr) IRunRecordDo
Joins(fields ...field.RelationField) IRunRecordDo
Preload(fields ...field.RelationField) IRunRecordDo
FirstOrInit() (*model.RunRecord, error)
FirstOrCreate() (*model.RunRecord, error)
FindByPage(offset int, limit int) (result []*model.RunRecord, count int64, err error)
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
Scan(result interface{}) (err error)
Returning(value interface{}, columns ...string) IRunRecordDo
UnderlyingDB() *gorm.DB
schema.Tabler
}
func (r runRecordDo) Debug() IRunRecordDo {
return r.withDO(r.DO.Debug())
}
func (r runRecordDo) WithContext(ctx context.Context) IRunRecordDo {
return r.withDO(r.DO.WithContext(ctx))
}
func (r runRecordDo) ReadDB() IRunRecordDo {
return r.Clauses(dbresolver.Read)
}
func (r runRecordDo) WriteDB() IRunRecordDo {
return r.Clauses(dbresolver.Write)
}
func (r runRecordDo) Session(config *gorm.Session) IRunRecordDo {
return r.withDO(r.DO.Session(config))
}
func (r runRecordDo) Clauses(conds ...clause.Expression) IRunRecordDo {
return r.withDO(r.DO.Clauses(conds...))
}
func (r runRecordDo) Returning(value interface{}, columns ...string) IRunRecordDo {
return r.withDO(r.DO.Returning(value, columns...))
}
func (r runRecordDo) Not(conds ...gen.Condition) IRunRecordDo {
return r.withDO(r.DO.Not(conds...))
}
func (r runRecordDo) Or(conds ...gen.Condition) IRunRecordDo {
return r.withDO(r.DO.Or(conds...))
}
func (r runRecordDo) Select(conds ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Select(conds...))
}
func (r runRecordDo) Where(conds ...gen.Condition) IRunRecordDo {
return r.withDO(r.DO.Where(conds...))
}
func (r runRecordDo) Order(conds ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Order(conds...))
}
func (r runRecordDo) Distinct(cols ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Distinct(cols...))
}
func (r runRecordDo) Omit(cols ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Omit(cols...))
}
func (r runRecordDo) Join(table schema.Tabler, on ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Join(table, on...))
}
func (r runRecordDo) LeftJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.LeftJoin(table, on...))
}
func (r runRecordDo) RightJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.RightJoin(table, on...))
}
func (r runRecordDo) Group(cols ...field.Expr) IRunRecordDo {
return r.withDO(r.DO.Group(cols...))
}
func (r runRecordDo) Having(conds ...gen.Condition) IRunRecordDo {
return r.withDO(r.DO.Having(conds...))
}
func (r runRecordDo) Limit(limit int) IRunRecordDo {
return r.withDO(r.DO.Limit(limit))
}
func (r runRecordDo) Offset(offset int) IRunRecordDo {
return r.withDO(r.DO.Offset(offset))
}
func (r runRecordDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IRunRecordDo {
return r.withDO(r.DO.Scopes(funcs...))
}
func (r runRecordDo) Unscoped() IRunRecordDo {
return r.withDO(r.DO.Unscoped())
}
func (r runRecordDo) Create(values ...*model.RunRecord) error {
if len(values) == 0 {
return nil
}
return r.DO.Create(values)
}
func (r runRecordDo) CreateInBatches(values []*model.RunRecord, batchSize int) error {
return r.DO.CreateInBatches(values, batchSize)
}
// Save : !!! underlying implementation is different with GORM
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
func (r runRecordDo) Save(values ...*model.RunRecord) error {
if len(values) == 0 {
return nil
}
return r.DO.Save(values)
}
func (r runRecordDo) First() (*model.RunRecord, error) {
if result, err := r.DO.First(); err != nil {
return nil, err
} else {
return result.(*model.RunRecord), nil
}
}
func (r runRecordDo) Take() (*model.RunRecord, error) {
if result, err := r.DO.Take(); err != nil {
return nil, err
} else {
return result.(*model.RunRecord), nil
}
}
func (r runRecordDo) Last() (*model.RunRecord, error) {
if result, err := r.DO.Last(); err != nil {
return nil, err
} else {
return result.(*model.RunRecord), nil
}
}
func (r runRecordDo) Find() ([]*model.RunRecord, error) {
result, err := r.DO.Find()
return result.([]*model.RunRecord), err
}
func (r runRecordDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.RunRecord, err error) {
buf := make([]*model.RunRecord, 0, batchSize)
err = r.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
defer func() { results = append(results, buf...) }()
return fc(tx, batch)
})
return results, err
}
func (r runRecordDo) FindInBatches(result *[]*model.RunRecord, batchSize int, fc func(tx gen.Dao, batch int) error) error {
return r.DO.FindInBatches(result, batchSize, fc)
}
func (r runRecordDo) Attrs(attrs ...field.AssignExpr) IRunRecordDo {
return r.withDO(r.DO.Attrs(attrs...))
}
func (r runRecordDo) Assign(attrs ...field.AssignExpr) IRunRecordDo {
return r.withDO(r.DO.Assign(attrs...))
}
func (r runRecordDo) Joins(fields ...field.RelationField) IRunRecordDo {
for _, _f := range fields {
r = *r.withDO(r.DO.Joins(_f))
}
return &r
}
func (r runRecordDo) Preload(fields ...field.RelationField) IRunRecordDo {
for _, _f := range fields {
r = *r.withDO(r.DO.Preload(_f))
}
return &r
}
func (r runRecordDo) FirstOrInit() (*model.RunRecord, error) {
if result, err := r.DO.FirstOrInit(); err != nil {
return nil, err
} else {
return result.(*model.RunRecord), nil
}
}
func (r runRecordDo) FirstOrCreate() (*model.RunRecord, error) {
if result, err := r.DO.FirstOrCreate(); err != nil {
return nil, err
} else {
return result.(*model.RunRecord), nil
}
}
func (r runRecordDo) FindByPage(offset int, limit int) (result []*model.RunRecord, count int64, err error) {
result, err = r.Offset(offset).Limit(limit).Find()
if err != nil {
return
}
if size := len(result); 0 < limit && 0 < size && size < limit {
count = int64(size + offset)
return
}
count, err = r.Offset(-1).Limit(-1).Count()
return
}
func (r runRecordDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
count, err = r.Count()
if err != nil {
return
}
err = r.Offset(offset).Limit(limit).Scan(result)
return
}
func (r runRecordDo) Scan(result interface{}) (err error) {
return r.DO.Scan(result)
}
func (r runRecordDo) Delete(models ...*model.RunRecord) (result gen.ResultInfo, err error) {
return r.DO.Delete(models)
}
func (r *runRecordDo) withDO(do gen.Dao) *runRecordDo {
r.DO = *do.(*gen.DO)
return r
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
)
type Event struct {
}
func NewEvent() *Event {
return &Event{}
}
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkMessageItem: chunkMsgItem,
}
}
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkRunItem: chunkRunItem,
}
}
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
Error: err,
}
}
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: entity.RunEventStreamDone,
}
}
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildRunEvent(runEvent, runItem)
sw.Send(resp, nil)
}
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildMessageEvent(runEvent, messageItem)
sw.Send(resp, nil)
}
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
resp := e.buildErrEvent(runEvent, err)
sw.Send(resp, nil)
}
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildStreamDoneEvent()
sw.Send(resp, nil)
}

View File

@@ -0,0 +1,123 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"time"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type RunProcess struct {
event *Event
RunRecordRepo repository.RunRecordRepo
}
func NewRunProcess(runRecordRepo repository.RunRecordRepo) *RunProcess {
return &RunProcess{
RunRecordRepo: runRecordRepo,
}
}
func (r *RunProcess) StepToCreate(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
srRecord.Status = entity.RunStatusCreated
r.event.SendRunEvent(entity.RunEventCreated, srRecord, sw)
}
func (r *RunProcess) StepToInProgress(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) error {
srRecord.Status = entity.RunStatusInProgress
updateMeta := &entity.UpdateMeta{
Status: entity.RunStatusInProgress,
UpdatedAt: time.Now().UnixMilli(),
}
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
if err != nil {
return err
}
r.event.SendRunEvent(entity.RunEventInProgress, srRecord, sw)
return nil
}
func (r *RunProcess) StepToComplete(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *agentrun.Usage) {
completedAt := time.Now().UnixMilli()
updateMeta := &entity.UpdateMeta{
Status: entity.RunStatusCompleted,
Usage: usage,
CompletedAt: completedAt,
UpdatedAt: completedAt,
}
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
if err != nil {
logs.CtxErrorf(ctx, "RunRecordRepo.UpdateByID error: %v", err)
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: err.Error(),
})
return
}
srRecord.CompletedAt = completedAt
srRecord.Status = entity.RunStatusCompleted
r.event.SendRunEvent(entity.RunEventCompleted, srRecord, sw)
r.event.SendStreamDoneEvent(sw)
}
func (r *RunProcess) StepToFailed(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
nowTime := time.Now().UnixMilli()
updateMeta := &entity.UpdateMeta{
Status: entity.RunStatusFailed,
UpdatedAt: nowTime,
FailedAt: nowTime,
LastError: srRecord.Error,
}
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
if err != nil {
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: err.Error(),
})
logs.CtxErrorf(ctx, "update run record failed, err: %v", err)
return
}
srRecord.Status = entity.RunStatusFailed
srRecord.FailedAt = time.Now().UnixMilli()
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
Code: srRecord.Error.Code,
Msg: srRecord.Error.Msg,
})
return
}
func (r *RunProcess) StepToDone(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
r.event.SendStreamDoneEvent(sw)
}

View File

@@ -0,0 +1,41 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package repository
import (
"context"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
func NewRunRecordRepo(db *gorm.DB, idGen idgen.IDGenerator) RunRecordRepo {
return dal.NewRunRecordDAO(db, idGen)
}
type RunRecordRepo interface {
Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
GetByID(ctx context.Context, id int64) (*entity.RunRecord, error)
Delete(ctx context.Context, id []int64) error
UpdateByID(ctx context.Context, id int64, update *entity.UpdateMeta) error
List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error)
}

View File

@@ -0,0 +1,31 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package agentrun
import (
"context"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
)
type Run interface {
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
Delete(ctx context.Context, runID []int64) error
}

View File

@@ -0,0 +1,994 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package agentrun
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"runtime/debug"
"strconv"
"sync"
"time"
"github.com/cloudwego/eino/schema"
"github.com/mohae/deepcopy"
messageModel "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type runImpl struct {
Components
runProcess *internal.RunProcess
runEvent *internal.Event
}
type runtimeDependence struct {
runID int64
agentInfo *singleagent.SingleAgent
questionMsgID int64
runMeta *entity.AgentRunMeta
startTime time.Time
usage *agentrun.Usage
}
type Components struct {
RunRecordRepo repository.RunRecordRepo
}
func NewService(c *Components) Run {
return &runImpl{
Components: *c,
runEvent: internal.NewEvent(),
runProcess: internal.NewRunProcess(c.RunRecordRepo),
}
}
func (c *runImpl) AgentRun(ctx context.Context, arm *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error) {
sr, sw := schema.Pipe[*entity.AgentRunResponse](20)
defer func() {
if pe := recover(); pe != nil {
logs.CtxErrorf(ctx, "panic recover: %v\n, [stack]:%v", pe, string(debug.Stack()))
return
}
}()
rtDependence := &runtimeDependence{
runMeta: arm,
startTime: time.Now(),
}
safego.Go(ctx, func() {
defer sw.Close()
_ = c.run(ctx, sw, rtDependence)
})
return sr, nil
}
func (c *runImpl) run(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (err error) {
runRecord, err := c.createRunRecord(ctx, sw, rtDependence)
if err != nil {
return
}
rtDependence.runID = runRecord.ID
defer func() {
srRecord := c.buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted)
if err != nil {
srRecord.Error = &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: err.Error(),
}
c.runProcess.StepToFailed(ctx, srRecord, sw)
return
}
c.runProcess.StepToComplete(ctx, srRecord, sw, rtDependence.usage)
}()
agentInfo, err := c.handlerAgent(ctx, rtDependence)
if err != nil {
return
}
rtDependence.agentInfo = agentInfo
history, err := c.handlerHistory(ctx, rtDependence)
if err != nil {
return
}
input, err := c.handlerInput(ctx, sw, rtDependence)
if err != nil {
return
}
rtDependence.questionMsgID = input.ID
err = c.handlerStreamExecute(ctx, sw, history, input, rtDependence)
return
}
func (c *runImpl) handlerAgent(ctx context.Context, rtDependence *runtimeDependence) (*singleagent.SingleAgent, error) {
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
AgentID: rtDependence.runMeta.AgentID,
IsDraft: rtDependence.runMeta.IsDraft,
})
if err != nil {
return nil, err
}
return agentInfo, nil
}
func (c *runImpl) handlerStreamExecute(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], historyMsg []*msgEntity.Message, input *msgEntity.Message, rtDependence *runtimeDependence) (err error) {
mainChan := make(chan *entity.AgentRespEvent, 100)
ar := &singleagent.AgentRuntime{
AgentVersion: rtDependence.runMeta.Version,
SpaceID: rtDependence.runMeta.SpaceID,
IsDraft: rtDependence.runMeta.IsDraft,
ConnectorID: rtDependence.runMeta.ConnectorID,
PreRetrieveTools: rtDependence.runMeta.PreRetrieveTools,
}
streamer, err := crossagent.DefaultSVC().StreamExecute(ctx, historyMsg, input, ar)
if err != nil {
return errors.New(errorx.ErrorWithoutStack(err))
}
var wg sync.WaitGroup
wg.Add(2)
safego.Go(ctx, func() {
defer wg.Done()
c.pull(ctx, mainChan, streamer)
})
safego.Go(ctx, func() {
defer wg.Done()
c.push(ctx, mainChan, sw, rtDependence)
})
wg.Wait()
return err
}
func transformEventMap(eventType singleagent.EventType) (message.MessageType, error) {
var eType message.MessageType
switch eventType {
case singleagent.EventTypeOfFuncCall:
return message.MessageTypeFunctionCall, nil
case singleagent.EventTypeOfKnowledge:
return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer:
return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil
case singleagent.EventTypeOfInterrupt:
return message.MessageTypeInterrupt, nil
}
return eType, errorx.New(errno.ErrReplyUnknowEventType)
}
func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *runtimeDependence) *message.Message {
arm := rtDependence.runMeta
msg := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.runID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
MessageType: messageType,
}
buildExt := map[string]string{}
timeCost := fmt.Sprintf("%.1f", float64(time.Since(rtDependence.startTime).Milliseconds())/1000.00)
switch messageType {
case message.MessageTypeQuestion:
msg.Role = schema.User
msg.ContentType = arm.ContentType
for _, content := range arm.Content {
if content.Type == message.InputTypeText {
msg.Content = content.Text
break
}
}
msg.MultiContent = arm.Content
buildExt = arm.Ext
msg.DisplayContent = arm.DisplayContent
case message.MessageTypeAnswer:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
case message.MessageTypeToolResponse:
msg.Role = schema.Tool
msg.ContentType = message.ContentTypeText
msg.Content = chunk.ToolsMessage[0].Content
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.ToolsMessage[0]
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeKnowledge:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
knowledgeContent := c.buildKnowledge(ctx, arm, chunk)
if knowledgeContent != nil {
knInfo, err := json.Marshal(knowledgeContent)
if err == nil {
msg.Content = string(knInfo)
}
}
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.Knowledge
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeFunctionCall:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
if len(chunk.FuncCall.ToolCalls) > 0 {
toolCall := chunk.FuncCall.ToolCalls[0]
toolCalling, err := json.Marshal(toolCall)
if err == nil {
msg.Content = string(toolCalling)
}
buildExt[string(msgEntity.MessageExtKeyPlugin)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyToolName)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.FuncCall
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
}
case message.MessageTypeFlowUp:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
msg.Content = chunk.Suggest.Content
case message.MessageTypeVerbose:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
d := &entity.Data{
FinishReason: 0,
FinData: "",
}
dByte, _ := json.Marshal(d)
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeGenerateFinish,
Data: string(dByte),
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
case message.MessageTypeInterrupt:
msg.Role = schema.Assistant
msg.MessageType = message.MessageTypeVerbose
msg.ContentType = message.ContentTypeText
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeInterrupt,
Data: "",
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
// 添加 ext 用于保存到 context_message
interruptByte, err := json.Marshal(chunk.Interrupt)
if err == nil {
buildExt[string(msgEntity.ExtKeyResumeInfo)] = string(interruptByte)
}
buildExt[string(msgEntity.ExtKeyToolCallsIDs)] = chunk.Interrupt.ToolCallID
rc := &messageModel.RequiredAction{
Type: "submit_tool_outputs",
SubmitToolOutputs: &messageModel.SubmitToolOutputs{},
}
msg.RequiredAction = rc
rcExtByte, err := json.Marshal(rc)
if err == nil {
buildExt[string(msgEntity.ExtKeyRequiresAction)] = string(rcExtByte)
}
}
if messageType != message.MessageTypeQuestion {
botStateExt := c.buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err == nil {
buildExt[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
}
msg.Ext = buildExt
return msg
}
func (c *runImpl) handlerHistory(ctx context.Context, rtDependence *runtimeDependence) ([]*msgEntity.Message, error) {
conversationTurns := entity.ConversationTurnsDefault
if rtDependence.agentInfo != nil && rtDependence.agentInfo.ModelInfo != nil && rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy != nil && ptr.From(rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound) > 0 {
conversationTurns = ptr.From(rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound)
}
runRecordList, err := c.RunRecordRepo.List(ctx, rtDependence.runMeta.ConversationID, rtDependence.runMeta.SectionID, conversationTurns)
if err != nil {
return nil, err
}
if len(runRecordList) == 0 {
return nil, nil
}
runIDS := c.getRunID(runRecordList)
history, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, rtDependence.runMeta.ConversationID, runIDS)
if err != nil {
return nil, err
}
return history, nil
}
func (c *runImpl) getRunID(rr []*model.RunRecord) []int64 {
ids := make([]int64, 0, len(rr))
for _, c := range rr {
ids = append(ids, c.ID)
}
return ids
}
func (c *runImpl) createRunRecord(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (*entity.RunRecordMeta, error) {
runPoData, err := c.RunRecordRepo.Create(ctx, rtDependence.runMeta)
if err != nil {
logs.CtxErrorf(ctx, "RunRecordRepo.Create error: %v", err)
return nil, err
}
srRecord := c.buildSendRunRecord(ctx, runPoData, entity.RunStatusCreated)
c.runProcess.StepToCreate(ctx, srRecord, sw)
err = c.runProcess.StepToInProgress(ctx, srRecord, sw)
if err != nil {
logs.CtxErrorf(ctx, "runProcess.StepToInProgress error: %v", err)
return nil, err
}
return runPoData, nil
}
func (c *runImpl) handlerInput(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (*msgEntity.Message, error) {
msgMeta := c.buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
if err != nil {
return nil, err
}
ackErr := c.handlerAckMessage(ctx, cm, sw)
if ackErr != nil {
return msgMeta, ackErr
}
return cm, nil
}
func (c *runImpl) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, events *schema.StreamReader[*crossagent.AgentEvent]) {
defer func() {
close(mainChan)
}()
for {
rm, re := events.Recv()
if re != nil {
errChunk := &entity.AgentRespEvent{
Err: re,
}
mainChan <- errChunk
return
}
eventType, tErr := transformEventMap(rm.EventType)
if tErr != nil {
errChunk := &entity.AgentRespEvent{
Err: tErr,
}
mainChan <- errChunk
return
}
respChunk := &entity.AgentRespEvent{
EventType: eventType,
ModelAnswer: rm.ChatModelAnswer,
ToolsMessage: rm.ToolsMessage,
FuncCall: rm.FuncCall,
Knowledge: rm.Knowledge,
Suggest: rm.Suggest,
Interrupt: rm.Interrupt,
}
mainChan <- respChunk
}
}
func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) {
var err error
defer func() {
if err != nil {
logs.CtxErrorf(ctx, "run.push error: %v", err)
c.handlerErr(ctx, err, sw)
}
}()
for {
chunk, ok := <-mainChan
if !ok || chunk == nil {
return
}
logs.CtxInfof(ctx, "hanlder event:%v,err:%v", conv.DebugJsonToStr(chunk), chunk.Err)
if chunk.Err != nil {
if errors.Is(chunk.Err, io.EOF) {
return
}
c.handlerErr(ctx, chunk.Err, sw)
return
}
switch chunk.EventType {
case message.MessageTypeFunctionCall:
err = c.handlerFunctionCall(ctx, chunk, sw, rtDependence)
if err != nil {
return
}
case message.MessageTypeToolResponse:
err = c.handlerTooResponse(ctx, chunk, sw, rtDependence)
if err != nil {
return
}
case message.MessageTypeKnowledge:
err = c.handlerKnowledge(ctx, chunk, sw, rtDependence)
if err != nil {
return
}
case message.MessageTypeAnswer:
fullContent := bytes.NewBuffer([]byte{})
reasoningContent := bytes.NewBuffer([]byte{})
var preMsg *msgEntity.Message
var usage *msgEntity.UsageExt
var createPreMsg = true
var isToolCalls = false
for {
streamMsg, receErr := chunk.ModelAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
if isToolCalls && reasoningContent.String() == "" {
break
}
finalAnswer := c.buildSendMsg(ctx, preMsg, false, rtDependence)
finalAnswer.Content = fullContent.String()
finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String())
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence)
if hfErr != nil {
err = hfErr
return
}
finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if finishErr != nil {
err = finishErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && len(streamMsg.ToolCalls) > 0 {
isToolCalls = true
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = c.handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
continue
}
if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) {
preMsg, err = c.handlerPreAnswer(ctx, rtDependence)
if err != nil {
return
}
createPreMsg = false
}
sendMsg := c.buildSendMsg(ctx, preMsg, false, rtDependence)
reasoningContent.WriteString(streamMsg.ReasoningContent)
sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
fullContent.WriteString(streamMsg.Content)
sendMsg.Content = streamMsg.Content
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw)
}
case message.MessageTypeFlowUp:
err = c.handlerSuggest(ctx, chunk, sw, rtDependence)
if err != nil {
return
}
case message.MessageTypeInterrupt:
err = c.handlerInterrupt(ctx, chunk, sw, rtDependence)
if err != nil {
return
}
}
}
}
func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
interruptData, cType, err := c.parseInterruptData(ctx, chunk.Interrupt)
if err != nil {
return err
}
preMsg, err := c.handlerPreAnswer(ctx, rtDependence)
if err != nil {
return err
}
deltaAnswer := &entity.ChunkMessageItem{
ID: preMsg.ID,
ConversationID: preMsg.ConversationID,
SectionID: preMsg.SectionID,
RunID: preMsg.RunID,
AgentID: preMsg.AgentID,
Role: entity.RoleType(preMsg.Role),
Content: interruptData,
MessageType: preMsg.MessageType,
ContentType: cType,
ReplyID: preMsg.RunID,
Ext: preMsg.Ext,
IsFinish: false,
}
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence)
if err != nil {
return err
}
err = c.handlerInterruptVerbose(ctx, chunk, sw, rtDependence)
if err != nil {
return err
}
err = c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if err != nil {
return err
}
return nil
}
func (c *runImpl) parseInterruptData(_ context.Context, interruptData *singleagent.InterruptInfo) (string, message.ContentType, error) {
type msg struct {
Type string `json:"type,omitempty"`
ContentType string `json:"content_type"`
Content any `json:"content"` // either optionContent or string
ID string `json:"id,omitempty"`
}
defaultContentType := message.ContentTypeText
switch interruptData.InterruptType {
case singleagent.InterruptEventType_OauthPlugin:
data := interruptData.AllToolInterruptData[interruptData.ToolCallID].ToolNeedOAuth.Message
return data, defaultContentType, nil
case singleagent.InterruptEventType_Question:
var iData map[string][]*msg
err := json.Unmarshal([]byte(interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData), &iData)
if err != nil {
return "", defaultContentType, err
}
if len(iData["messages"]) == 0 {
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
}
interruptMsg := iData["messages"][0]
if interruptMsg.ContentType == "text" {
return interruptMsg.Content.(string), defaultContentType, nil
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
iMarshalData, err := json.Marshal(interruptMsg)
if err != nil {
return "", defaultContentType, err
}
return string(iMarshalData), message.ContentTypeCard, nil
}
case singleagent.InterruptEventType_InputNode:
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
return data, message.ContentTypeCard, nil
case singleagent.InterruptEventType_WorkflowLLM:
toolInterruptEvent := interruptData.AllWfInterruptData[interruptData.ToolCallID].ToolInterruptEvent
data := toolInterruptEvent.InterruptData
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_InputNode {
return data, message.ContentTypeCard, nil
}
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_Question {
var iData map[string][]*msg
err := json.Unmarshal([]byte(data), &iData)
if err != nil {
return "", defaultContentType, err
}
if len(iData["messages"]) == 0 {
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
}
interruptMsg := iData["messages"][0]
if interruptMsg.ContentType == "text" {
return interruptMsg.Content.(string), defaultContentType, nil
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
iMarshalData, err := json.Marshal(interruptMsg)
if err != nil {
return "", defaultContentType, err
}
return string(iMarshalData), message.ContentTypeCard, nil
}
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func (c *runImpl) handlerUsage(meta *schema.ResponseMeta) *msgEntity.UsageExt {
if meta == nil || meta.Usage == nil {
return nil
}
return &msgEntity.UsageExt{
TotalCount: int64(meta.Usage.TotalTokens),
InputTokens: int64(meta.Usage.PromptTokens),
OutputTokens: int64(meta.Usage.CompletionTokens),
}
}
func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
c.runEvent.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: errorx.ErrorWithoutStack(err),
})
}
func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
arm := rtDependence.runMeta
msgMeta := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.runID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
Role: schema.Assistant,
MessageType: message.MessageTypeAnswer,
ContentType: message.ContentTypeText,
Ext: arm.Ext,
}
if arm.Ext == nil {
msgMeta.Ext = map[string]string{}
}
botStateExt := c.buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err != nil {
return nil, err
}
if _, ok := msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)]; !ok {
msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
msgMeta.Ext = arm.Ext
return crossmessage.DefaultSVC().Create(ctx, msgMeta)
}
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil
}
msg.IsFinish = true
if msg.Ext == nil {
msg.Ext = map[string]string{}
}
if usage != nil {
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
rtDependence.usage = &agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
}
}
if _, ok := msg.Ext[string(msgEntity.MessageExtKeyTimeCost)]; !ok {
msg.Ext[string(msgEntity.MessageExtKeyTimeCost)] = fmt.Sprintf("%.1f", float64(time.Since(rtDependence.startTime).Milliseconds())/1000.00)
}
buildModelContent := &schema.Message{
Role: schema.Assistant,
Content: msg.Content,
}
mc, err := json.Marshal(buildModelContent)
if err != nil {
return err
}
editMsg := &msgEntity.Message{
ID: msg.ID,
Content: msg.Content,
ContentType: msg.ContentType,
ModelContent: string(mc),
ReasoningContent: ptr.From(msg.ReasoningContent),
Ext: msg.Ext,
}
_, err = crossmessage.DefaultSVC().Edit(ctx, editMsg)
if err != nil {
return err
}
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, sw)
return nil
}
func (c *runImpl) buildBotStateExt(arm *entity.AgentRunMeta) *msgEntity.BotStateExt {
agentID := strconv.FormatInt(arm.AgentID, 10)
botStateExt := &msgEntity.BotStateExt{
AgentID: agentID,
AgentName: arm.Name,
Awaiting: agentID,
BotID: agentID,
}
return botStateExt
}
func (c *runImpl) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) handlerAckMessage(_ context.Context, input *msgEntity.Message, sw *schema.StreamWriter[*entity.AgentRunResponse]) error {
sendMsg := &entity.ChunkMessageItem{
ID: input.ID,
ConversationID: input.ConversationID,
SectionID: input.SectionID,
AgentID: input.AgentID,
Role: entity.RoleType(input.Role),
MessageType: message.MessageTypeAck,
ReplyID: input.ID,
Content: input.Content,
ContentType: message.ContentTypeText,
IsFinish: true,
}
c.runEvent.SendMsgEvent(entity.RunEventAck, sendMsg, sw)
return nil
}
func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) buildKnowledge(_ context.Context, arm *entity.AgentRunMeta, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
var recallDatas []msgEntity.RecallDataInfo
for _, kOne := range chunk.Knowledge {
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{
Slice: kOne.Content,
Meta: msgEntity.MetaInfo{
Dataset: msgEntity.DatasetInfo{
ID: kOne.MetaData["dataset_id"].(string),
Name: kOne.MetaData["dataset_name"].(string),
},
Document: msgEntity.DocumentInfo{
ID: kOne.MetaData["document_id"].(string),
Name: kOne.MetaData["document_name"].(string),
},
},
Score: kOne.Score(),
})
}
verboseData := &msgEntity.VerboseData{
Chunks: recallDatas,
OriReq: "",
StatusCode: 0,
}
data, err := json.Marshal(verboseData)
if err != nil {
return nil
}
knowledgeInfo := &msgEntity.VerboseInfo{
MessageType: string(entity.MessageSubTypeKnowledgeCall),
Data: string(data),
}
return knowledgeInfo
}
func (c *runImpl) handlerFinalAnswerFinish(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
return nil
}
func (c *runImpl) buildSendMsg(_ context.Context, msg *msgEntity.Message, isFinish bool, rtDependence *runtimeDependence) *entity.ChunkMessageItem {
copyMap := make(map[string]string)
for k, v := range msg.Ext {
copyMap[k] = v
}
return &entity.ChunkMessageItem{
ID: msg.ID,
ConversationID: msg.ConversationID,
SectionID: msg.SectionID,
AgentID: msg.AgentID,
Content: msg.Content,
Role: entity.RoleTypeAssistant,
ContentType: msg.ContentType,
MessageType: msg.MessageType,
ReplyID: rtDependence.questionMsgID,
Type: msg.MessageType,
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
RunID: rtDependence.runID,
Ext: copyMap,
IsFinish: isFinish,
ReasoningContent: ptr.Of(msg.ReasoningContent),
}
}
func (c *runImpl) buildSendRunRecord(_ context.Context, runRecord *entity.RunRecordMeta, runStatus entity.RunStatus) *entity.ChunkRunItem {
return &entity.ChunkRunItem{
ID: runRecord.ID,
ConversationID: runRecord.ConversationID,
AgentID: runRecord.AgentID,
SectionID: runRecord.SectionID,
Status: runStatus,
CreatedAt: runRecord.CreatedAt,
}
}
func (c *runImpl) Delete(ctx context.Context, runID []int64) error {
return c.RunRecordRepo.Delete(ctx, runID)
}

View File

@@ -0,0 +1,99 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package agentrun
import (
"testing"
)
func TestAgentRun(t *testing.T) {
// ctx := context.Background()
//
// mockDB, err := mysql.New()
// assert.Nil(t, err)
// cacheCli := redis.New()
//
// idGen, err := idgen.New(cacheCli)
// ctrl := gomock.NewController(t)
// idGen := mock.NewMockIDGenerator(ctrl)
// // idGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{time.Now().UnixMilli(), time.Now().Add(time.Second).UnixMilli()}, nil).AnyTimes()
// idGen.EXPECT().GenID(gomock.Any()).Return(int64(time.Now().UnixMilli()), nil).AnyTimes()
//
// mockDBGen := orm.NewMockDB()
// mockDBGen.AddTable(&model.RunRecord{})
// mockDB, err := mockDBGen.DB()
//
// assert.NoError(t, err)
// components := &Components{
// DB: mockDB,
// IDGen: idGen,
// }
//
// imageInput := &entity.FileData{
// Url: "https://xxxxx.xxxx/image",
// Name: "test_img",
// }
// fileInput := &entity.FileData{
// Url: "https://xxxxx.xxxx/file",
// Name: "test_file",
// }
// content := []*entity.InputMetaData{
// {
// Type: entity.InputTypeText,
// Text: "你是谁",
// },
// {
// Type: entity.InputTypeImage,
// FileData: []*entity.FileData{
// imageInput,
// },
// },
// {
// Type: entity.InputTypeFile,
// FileData: []*entity.FileData{
// fileInput,
// },
// },
// }
// stream, err := NewService(components, nil).AgentRun(ctx, &entity.AgentRunMeta{
// ConversationID: 7503546991712960512,
// SpaceID: 666,
// SectionID: 7503546991712976896,
// UserID: 888,
// AgentID: 7501996002144944128,
// Content: content,
// ContentType: entity.ContentTypeMulti,
// })
// assert.NoError(t, err)
// t.Logf("------------stream: %+v; err:%v", stream, err)
//
// for {
// chunk, errRecv := stream.Recv()
// jsonStr, _ := json.Marshal(chunk)
// fmt.Println(string(jsonStr))
// if errRecv == io.EOF || chunk == nil || chunk.Event == entity.RunEventStreamDone {
// break
// }
// if errRecv != nil {
// assert.NoError(t, errRecv)
// break
// }
// }
// assert.NoError(t, err)
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
import (
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
)
type Conversation = conversation.Conversation
type CreateMeta struct {
AgentID int64 `json:"agent_id"`
UserID int64 `json:"user_id"`
ConnectorID int64 `json:"connector_id"`
Scene common.Scene `json:"scene"`
Ext string `json:"ext"`
}
type NewConversationCtxRequest struct {
ID int64 `json:"id"`
}
type NewConversationCtxResponse struct {
ID int64 `json:"id"`
SectionID int64 `json:"section_id"`
}
type GetCurrent = conversation.GetCurrent
type ListMeta struct {
UserID int64 `json:"user_id"`
ConnectorID int64 `json:"connector_id"`
Scene common.Scene `json:"scene"`
AgentID int64 `json:"agent_id"`
Limit int `json:"limit"`
Page int `json:"page"`
}

View File

@@ -0,0 +1,209 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dal
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
type ConversationDAO struct {
idgen idgen.IDGenerator
db *gorm.DB
query *query.Query
}
func NewConversationDAO(db *gorm.DB, generator idgen.IDGenerator) *ConversationDAO {
return &ConversationDAO{
idgen: generator,
db: db,
query: query.Use(db),
}
}
func (dao *ConversationDAO) Create(ctx context.Context, msg *entity.Conversation) (*entity.Conversation, error) {
poData := dao.conversationDO2PO(ctx, msg)
ids, err := dao.idgen.GenMultiIDs(ctx, 2)
if err != nil {
return nil, err
}
poData.ID = ids[0]
poData.SectionID = ids[1]
err = dao.query.Conversation.WithContext(ctx).Create(poData)
if err != nil {
return nil, err
}
return dao.conversationPO2DO(ctx, poData), nil
}
func (dao *ConversationDAO) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
poData, err := dao.query.Conversation.WithContext(ctx).Debug().Where(dao.query.Conversation.ID.Eq(id)).First()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return dao.conversationPO2DO(ctx, poData), nil
}
func (dao *ConversationDAO) UpdateSection(ctx context.Context, id int64) (int64, error) {
updateColumn := make(map[string]interface{})
table := dao.query.Conversation
newSectionID, err := dao.idgen.GenID(ctx)
if err != nil {
return 0, err
}
updateColumn[table.SectionID.ColumnName().String()] = newSectionID
updateColumn[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
_, err = dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(id)).UpdateColumns(updateColumn)
if err != nil {
return 0, err
}
return newSectionID, nil
}
func (dao *ConversationDAO) Delete(ctx context.Context, id int64) (int64, error) {
table := dao.query.Conversation
updateColumn := make(map[string]interface{})
updateColumn[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
updateColumn[table.Status.ColumnName().String()] = conversation.ConversationStatusDeleted
updateRes, err := dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(id)).UpdateColumns(updateColumn)
if err != nil {
return 0, err
}
return updateRes.RowsAffected, err
}
func (dao *ConversationDAO) Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error) {
po, err := dao.query.Conversation.WithContext(ctx).Debug().
Where(dao.query.Conversation.CreatorID.Eq(userID)).
Where(dao.query.Conversation.AgentID.Eq(agentID)).
Where(dao.query.Conversation.Scene.Eq(scene)).
Where(dao.query.Conversation.ConnectorID.Eq(connectorID)).
Where(dao.query.Conversation.Status.Eq(int32(conversation.ConversationStatusNormal))).
Order(dao.query.Conversation.CreatedAt.Desc()).
First()
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return dao.conversationPO2DO(ctx, po), nil
}
func (dao *ConversationDAO) List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error) {
var hasMore bool
do := dao.query.Conversation.WithContext(ctx).Debug()
do = do.Where(dao.query.Conversation.CreatorID.Eq(userID)).
Where(dao.query.Conversation.AgentID.Eq(agentID)).
Where(dao.query.Conversation.Scene.Eq(scene)).
Where(dao.query.Conversation.ConnectorID.Eq(connectorID))
do = do.Offset((page - 1) * limit)
if limit > 0 {
do = do.Limit(int(limit) + 1)
}
poList, err := do.Find()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, hasMore, nil
}
if err != nil {
return nil, hasMore, err
}
if len(poList) == 0 {
return nil, hasMore, nil
}
if len(poList) > limit {
hasMore = true
return dao.conversationBatchPO2DO(ctx, poList[:(len(poList)-1)]), hasMore, nil
}
return dao.conversationBatchPO2DO(ctx, poList), hasMore, nil
}
func (dao *ConversationDAO) conversationDO2PO(ctx context.Context, conversation *entity.Conversation) *model.Conversation {
return &model.Conversation{
ID: conversation.ID,
SectionID: conversation.SectionID,
ConnectorID: conversation.ConnectorID,
AgentID: conversation.AgentID,
CreatorID: conversation.CreatorID,
Scene: int32(conversation.Scene),
Status: int32(conversation.Status),
Ext: conversation.Ext,
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
}
}
func (dao *ConversationDAO) conversationPO2DO(ctx context.Context, c *model.Conversation) *entity.Conversation {
return &entity.Conversation{
ID: c.ID,
SectionID: c.SectionID,
ConnectorID: c.ConnectorID,
AgentID: c.AgentID,
CreatorID: c.CreatorID,
Scene: common.Scene(c.Scene),
Status: conversation.ConversationStatus(c.Status),
Ext: c.Ext,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
}
}
func (dao *ConversationDAO) conversationBatchPO2DO(ctx context.Context, conversations []*model.Conversation) []*entity.Conversation {
return slices.Transform(conversations, func(c *model.Conversation) *entity.Conversation {
return &entity.Conversation{
ID: c.ID,
SectionID: c.SectionID,
ConnectorID: c.ConnectorID,
AgentID: c.AgentID,
CreatorID: c.CreatorID,
Scene: common.Scene(c.Scene),
Status: conversation.ConversationStatus(c.Status),
Ext: c.Ext,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
}
})
}

View File

@@ -0,0 +1,26 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
const TableNameConversation = "conversation"
// Conversation 会话信息表
type Conversation struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID
ConnectorID int64 `gorm:"column:connector_id;not null;comment:业务线 ID" json:"connector_id"` // 业务线 ID
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
Scene int32 `gorm:"column:scene;not null;comment:会话场景" json:"scene"` // 会话场景
SectionID int64 `gorm:"column:section_id;not null;comment:最新section_id" json:"section_id"` // 最新section_id
CreatorID int64 `gorm:"column:creator_id;comment:创建者id" json:"creator_id"` // 创建者id
Ext string `gorm:"column:ext;comment:扩展字段" json:"ext"` // 扩展字段
Status int32 `gorm:"column:status;not null;default:1;comment:status: 1-normal 2-deleted" json:"status"` // status: 1-normal 2-deleted
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
}
// TableName Conversation's table name
func (*Conversation) TableName() string {
return TableNameConversation
}

View File

@@ -0,0 +1,417 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gen"
"gorm.io/gen/field"
"gorm.io/plugin/dbresolver"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
)
func newConversation(db *gorm.DB, opts ...gen.DOOption) conversation {
_conversation := conversation{}
_conversation.conversationDo.UseDB(db, opts...)
_conversation.conversationDo.UseModel(&model.Conversation{})
tableName := _conversation.conversationDo.TableName()
_conversation.ALL = field.NewAsterisk(tableName)
_conversation.ID = field.NewInt64(tableName, "id")
_conversation.ConnectorID = field.NewInt64(tableName, "connector_id")
_conversation.AgentID = field.NewInt64(tableName, "agent_id")
_conversation.Scene = field.NewInt32(tableName, "scene")
_conversation.SectionID = field.NewInt64(tableName, "section_id")
_conversation.CreatorID = field.NewInt64(tableName, "creator_id")
_conversation.Ext = field.NewString(tableName, "ext")
_conversation.Status = field.NewInt32(tableName, "status")
_conversation.CreatedAt = field.NewInt64(tableName, "created_at")
_conversation.UpdatedAt = field.NewInt64(tableName, "updated_at")
_conversation.fillFieldMap()
return _conversation
}
// conversation 会话信息表
type conversation struct {
conversationDo
ALL field.Asterisk
ID field.Int64 // 主键ID
ConnectorID field.Int64 // 业务线 ID
AgentID field.Int64 // agent_id
Scene field.Int32 // 会话场景
SectionID field.Int64 // 最新section_id
CreatorID field.Int64 // 创建者id
Ext field.String // 扩展字段
Status field.Int32 // status: 1-normal 2-deleted
CreatedAt field.Int64 // 创建时间
UpdatedAt field.Int64 // 更新时间
fieldMap map[string]field.Expr
}
func (c conversation) Table(newTableName string) *conversation {
c.conversationDo.UseTable(newTableName)
return c.updateTableName(newTableName)
}
func (c conversation) As(alias string) *conversation {
c.conversationDo.DO = *(c.conversationDo.As(alias).(*gen.DO))
return c.updateTableName(alias)
}
func (c *conversation) updateTableName(table string) *conversation {
c.ALL = field.NewAsterisk(table)
c.ID = field.NewInt64(table, "id")
c.ConnectorID = field.NewInt64(table, "connector_id")
c.AgentID = field.NewInt64(table, "agent_id")
c.Scene = field.NewInt32(table, "scene")
c.SectionID = field.NewInt64(table, "section_id")
c.CreatorID = field.NewInt64(table, "creator_id")
c.Ext = field.NewString(table, "ext")
c.Status = field.NewInt32(table, "status")
c.CreatedAt = field.NewInt64(table, "created_at")
c.UpdatedAt = field.NewInt64(table, "updated_at")
c.fillFieldMap()
return c
}
func (c *conversation) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
_f, ok := c.fieldMap[fieldName]
if !ok || _f == nil {
return nil, false
}
_oe, ok := _f.(field.OrderExpr)
return _oe, ok
}
func (c *conversation) fillFieldMap() {
c.fieldMap = make(map[string]field.Expr, 10)
c.fieldMap["id"] = c.ID
c.fieldMap["connector_id"] = c.ConnectorID
c.fieldMap["agent_id"] = c.AgentID
c.fieldMap["scene"] = c.Scene
c.fieldMap["section_id"] = c.SectionID
c.fieldMap["creator_id"] = c.CreatorID
c.fieldMap["ext"] = c.Ext
c.fieldMap["status"] = c.Status
c.fieldMap["created_at"] = c.CreatedAt
c.fieldMap["updated_at"] = c.UpdatedAt
}
func (c conversation) clone(db *gorm.DB) conversation {
c.conversationDo.ReplaceConnPool(db.Statement.ConnPool)
return c
}
func (c conversation) replaceDB(db *gorm.DB) conversation {
c.conversationDo.ReplaceDB(db)
return c
}
type conversationDo struct{ gen.DO }
type IConversationDo interface {
gen.SubQuery
Debug() IConversationDo
WithContext(ctx context.Context) IConversationDo
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
ReplaceDB(db *gorm.DB)
ReadDB() IConversationDo
WriteDB() IConversationDo
As(alias string) gen.Dao
Session(config *gorm.Session) IConversationDo
Columns(cols ...field.Expr) gen.Columns
Clauses(conds ...clause.Expression) IConversationDo
Not(conds ...gen.Condition) IConversationDo
Or(conds ...gen.Condition) IConversationDo
Select(conds ...field.Expr) IConversationDo
Where(conds ...gen.Condition) IConversationDo
Order(conds ...field.Expr) IConversationDo
Distinct(cols ...field.Expr) IConversationDo
Omit(cols ...field.Expr) IConversationDo
Join(table schema.Tabler, on ...field.Expr) IConversationDo
LeftJoin(table schema.Tabler, on ...field.Expr) IConversationDo
RightJoin(table schema.Tabler, on ...field.Expr) IConversationDo
Group(cols ...field.Expr) IConversationDo
Having(conds ...gen.Condition) IConversationDo
Limit(limit int) IConversationDo
Offset(offset int) IConversationDo
Count() (count int64, err error)
Scopes(funcs ...func(gen.Dao) gen.Dao) IConversationDo
Unscoped() IConversationDo
Create(values ...*model.Conversation) error
CreateInBatches(values []*model.Conversation, batchSize int) error
Save(values ...*model.Conversation) error
First() (*model.Conversation, error)
Take() (*model.Conversation, error)
Last() (*model.Conversation, error)
Find() ([]*model.Conversation, error)
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Conversation, err error)
FindInBatches(result *[]*model.Conversation, batchSize int, fc func(tx gen.Dao, batch int) error) error
Pluck(column field.Expr, dest interface{}) error
Delete(...*model.Conversation) (info gen.ResultInfo, err error)
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
Updates(value interface{}) (info gen.ResultInfo, err error)
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
UpdateFrom(q gen.SubQuery) gen.Dao
Attrs(attrs ...field.AssignExpr) IConversationDo
Assign(attrs ...field.AssignExpr) IConversationDo
Joins(fields ...field.RelationField) IConversationDo
Preload(fields ...field.RelationField) IConversationDo
FirstOrInit() (*model.Conversation, error)
FirstOrCreate() (*model.Conversation, error)
FindByPage(offset int, limit int) (result []*model.Conversation, count int64, err error)
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
Scan(result interface{}) (err error)
Returning(value interface{}, columns ...string) IConversationDo
UnderlyingDB() *gorm.DB
schema.Tabler
}
func (c conversationDo) Debug() IConversationDo {
return c.withDO(c.DO.Debug())
}
func (c conversationDo) WithContext(ctx context.Context) IConversationDo {
return c.withDO(c.DO.WithContext(ctx))
}
func (c conversationDo) ReadDB() IConversationDo {
return c.Clauses(dbresolver.Read)
}
func (c conversationDo) WriteDB() IConversationDo {
return c.Clauses(dbresolver.Write)
}
func (c conversationDo) Session(config *gorm.Session) IConversationDo {
return c.withDO(c.DO.Session(config))
}
func (c conversationDo) Clauses(conds ...clause.Expression) IConversationDo {
return c.withDO(c.DO.Clauses(conds...))
}
func (c conversationDo) Returning(value interface{}, columns ...string) IConversationDo {
return c.withDO(c.DO.Returning(value, columns...))
}
func (c conversationDo) Not(conds ...gen.Condition) IConversationDo {
return c.withDO(c.DO.Not(conds...))
}
func (c conversationDo) Or(conds ...gen.Condition) IConversationDo {
return c.withDO(c.DO.Or(conds...))
}
func (c conversationDo) Select(conds ...field.Expr) IConversationDo {
return c.withDO(c.DO.Select(conds...))
}
func (c conversationDo) Where(conds ...gen.Condition) IConversationDo {
return c.withDO(c.DO.Where(conds...))
}
func (c conversationDo) Order(conds ...field.Expr) IConversationDo {
return c.withDO(c.DO.Order(conds...))
}
func (c conversationDo) Distinct(cols ...field.Expr) IConversationDo {
return c.withDO(c.DO.Distinct(cols...))
}
func (c conversationDo) Omit(cols ...field.Expr) IConversationDo {
return c.withDO(c.DO.Omit(cols...))
}
func (c conversationDo) Join(table schema.Tabler, on ...field.Expr) IConversationDo {
return c.withDO(c.DO.Join(table, on...))
}
func (c conversationDo) LeftJoin(table schema.Tabler, on ...field.Expr) IConversationDo {
return c.withDO(c.DO.LeftJoin(table, on...))
}
func (c conversationDo) RightJoin(table schema.Tabler, on ...field.Expr) IConversationDo {
return c.withDO(c.DO.RightJoin(table, on...))
}
func (c conversationDo) Group(cols ...field.Expr) IConversationDo {
return c.withDO(c.DO.Group(cols...))
}
func (c conversationDo) Having(conds ...gen.Condition) IConversationDo {
return c.withDO(c.DO.Having(conds...))
}
func (c conversationDo) Limit(limit int) IConversationDo {
return c.withDO(c.DO.Limit(limit))
}
func (c conversationDo) Offset(offset int) IConversationDo {
return c.withDO(c.DO.Offset(offset))
}
func (c conversationDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IConversationDo {
return c.withDO(c.DO.Scopes(funcs...))
}
func (c conversationDo) Unscoped() IConversationDo {
return c.withDO(c.DO.Unscoped())
}
func (c conversationDo) Create(values ...*model.Conversation) error {
if len(values) == 0 {
return nil
}
return c.DO.Create(values)
}
func (c conversationDo) CreateInBatches(values []*model.Conversation, batchSize int) error {
return c.DO.CreateInBatches(values, batchSize)
}
// Save : !!! underlying implementation is different with GORM
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
func (c conversationDo) Save(values ...*model.Conversation) error {
if len(values) == 0 {
return nil
}
return c.DO.Save(values)
}
func (c conversationDo) First() (*model.Conversation, error) {
if result, err := c.DO.First(); err != nil {
return nil, err
} else {
return result.(*model.Conversation), nil
}
}
func (c conversationDo) Take() (*model.Conversation, error) {
if result, err := c.DO.Take(); err != nil {
return nil, err
} else {
return result.(*model.Conversation), nil
}
}
func (c conversationDo) Last() (*model.Conversation, error) {
if result, err := c.DO.Last(); err != nil {
return nil, err
} else {
return result.(*model.Conversation), nil
}
}
func (c conversationDo) Find() ([]*model.Conversation, error) {
result, err := c.DO.Find()
return result.([]*model.Conversation), err
}
func (c conversationDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Conversation, err error) {
buf := make([]*model.Conversation, 0, batchSize)
err = c.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
defer func() { results = append(results, buf...) }()
return fc(tx, batch)
})
return results, err
}
func (c conversationDo) FindInBatches(result *[]*model.Conversation, batchSize int, fc func(tx gen.Dao, batch int) error) error {
return c.DO.FindInBatches(result, batchSize, fc)
}
func (c conversationDo) Attrs(attrs ...field.AssignExpr) IConversationDo {
return c.withDO(c.DO.Attrs(attrs...))
}
func (c conversationDo) Assign(attrs ...field.AssignExpr) IConversationDo {
return c.withDO(c.DO.Assign(attrs...))
}
func (c conversationDo) Joins(fields ...field.RelationField) IConversationDo {
for _, _f := range fields {
c = *c.withDO(c.DO.Joins(_f))
}
return &c
}
func (c conversationDo) Preload(fields ...field.RelationField) IConversationDo {
for _, _f := range fields {
c = *c.withDO(c.DO.Preload(_f))
}
return &c
}
func (c conversationDo) FirstOrInit() (*model.Conversation, error) {
if result, err := c.DO.FirstOrInit(); err != nil {
return nil, err
} else {
return result.(*model.Conversation), nil
}
}
func (c conversationDo) FirstOrCreate() (*model.Conversation, error) {
if result, err := c.DO.FirstOrCreate(); err != nil {
return nil, err
} else {
return result.(*model.Conversation), nil
}
}
func (c conversationDo) FindByPage(offset int, limit int) (result []*model.Conversation, count int64, err error) {
result, err = c.Offset(offset).Limit(limit).Find()
if err != nil {
return
}
if size := len(result); 0 < limit && 0 < size && size < limit {
count = int64(size + offset)
return
}
count, err = c.Offset(-1).Limit(-1).Count()
return
}
func (c conversationDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
count, err = c.Count()
if err != nil {
return
}
err = c.Offset(offset).Limit(limit).Scan(result)
return
}
func (c conversationDo) Scan(result interface{}) (err error) {
return c.DO.Scan(result)
}
func (c conversationDo) Delete(models ...*model.Conversation) (result gen.ResultInfo, err error) {
return c.DO.Delete(models)
}
func (c *conversationDo) withDO(do gen.Dao) *conversationDo {
c.DO = *do.(*gen.DO)
return c
}

View File

@@ -0,0 +1,103 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"database/sql"
"gorm.io/gorm"
"gorm.io/gen"
"gorm.io/plugin/dbresolver"
)
var (
Q = new(Query)
Conversation *conversation
)
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
*Q = *Use(db, opts...)
Conversation = &Q.Conversation
}
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
return &Query{
db: db,
Conversation: newConversation(db, opts...),
}
}
type Query struct {
db *gorm.DB
Conversation conversation
}
func (q *Query) Available() bool { return q.db != nil }
func (q *Query) clone(db *gorm.DB) *Query {
return &Query{
db: db,
Conversation: q.Conversation.clone(db),
}
}
func (q *Query) ReadDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
}
func (q *Query) WriteDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
}
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
return &Query{
db: db,
Conversation: q.Conversation.replaceDB(db),
}
}
type queryCtx struct {
Conversation IConversationDo
}
func (q *Query) WithContext(ctx context.Context) *queryCtx {
return &queryCtx{
Conversation: q.Conversation.WithContext(ctx),
}
}
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
tx := q.db.Begin(opts...)
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
}
type QueryTx struct {
*Query
Error error
}
func (q *QueryTx) Commit() error {
return q.db.Commit().Error
}
func (q *QueryTx) Rollback() error {
return q.db.Rollback().Error
}
func (q *QueryTx) SavePoint(name string) error {
return q.db.SavePoint(name).Error
}
func (q *QueryTx) RollbackTo(name string) error {
return q.db.RollbackTo(name).Error
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package repository
import (
"context"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
func NewConversationRepo(db *gorm.DB, idGen idgen.IDGenerator) ConversationRepo {
return dal.NewConversationDAO(db, idGen)
}
type ConversationRepo interface {
Create(ctx context.Context, msg *entity.Conversation) (*entity.Conversation, error)
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
UpdateSection(ctx context.Context, id int64) (int64, error)
Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error)
Delete(ctx context.Context, id int64) (int64, error)
List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error)
}

View File

@@ -0,0 +1,32 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
)
type Conversation interface {
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error)
GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error)
Delete(ctx context.Context, id int64) error
List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error)
}

View File

@@ -0,0 +1,112 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
)
type conversationImpl struct {
Components
}
type Components struct {
ConversationRepo repository.ConversationRepo
}
func NewService(c *Components) Conversation {
return &conversationImpl{
Components: *c,
}
}
func (c *conversationImpl) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
var resp *entity.Conversation
doData := &entity.Conversation{
CreatorID: req.UserID,
AgentID: req.AgentID,
Scene: req.Scene,
ConnectorID: req.ConnectorID,
Ext: req.Ext,
}
resp, err := c.ConversationRepo.Create(ctx, doData)
if err != nil {
return resp, err
}
return resp, nil
}
func (c *conversationImpl) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
resp := &entity.Conversation{}
// get conversation
resp, err := c.ConversationRepo.GetByID(ctx, id)
if err != nil {
return resp, err
}
return resp, nil
}
func (c *conversationImpl) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
resp := &entity.NewConversationCtxResponse{}
newSectionID, err := c.ConversationRepo.UpdateSection(ctx, req.ID)
if err != nil {
return resp, err
}
if newSectionID != 0 {
resp.ID = req.ID
resp.SectionID = newSectionID
}
return resp, nil
}
func (c *conversationImpl) GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error) {
// get conversation
conversation, err := c.ConversationRepo.Get(ctx, req.UserID, req.AgentID, int32(req.Scene), req.ConnectorID)
if err != nil {
return nil, err
}
// build data
return conversation, nil
}
func (c *conversationImpl) Delete(ctx context.Context, id int64) error {
_, err := c.ConversationRepo.Delete(ctx, id)
if err != nil {
return err
}
return nil
}
func (c *conversationImpl) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
conversationList, hasMore, err := c.ConversationRepo.List(ctx, req.UserID, req.AgentID, req.ConnectorID, int32(req.Scene), req.Limit, req.Page)
if err != nil {
return nil, hasMore, err
}
return conversationList, hasMore, nil
}

View File

@@ -0,0 +1,171 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package conversation
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
)
// Test_NewListMessage tests the NewListMessage function
func TestCreateConversation(t *testing.T) {
ctx := context.Background()
// mockDB, _ := mysql.New()
// redisCli := redis.New()
// idGen, err := idgen.New(redisCli)
ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{
1, 2,
}, nil).AnyTimes()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Conversation{})
mockDB, err := mockDBGen.DB()
components := &Components{
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
}
createData, err := NewService(components).Create(ctx, &entity.CreateMeta{
AgentID: 100000,
UserID: 222222,
ConnectorID: 100001,
Scene: common.Scene_Playground,
Ext: "debug ext9999",
})
assert.NotNil(t, createData)
t.Logf("create conversation result: %v; err:%v", createData, err)
assert.Nil(t, err)
assert.Equal(t, "debug ext9999", createData.Ext)
}
func TestGetById(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Conversation{})
mockDBGen.AddTable(&model.Conversation{}).
AddRows(
&model.Conversation{
ID: 7494574457319587840,
AgentID: 8888,
SectionID: 100001,
ConnectorID: 100001,
CreatorID: 1111,
Ext: "debug ext1111",
},
)
mockDB, err := mockDBGen.DB()
ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenID(gomock.Any()).Return(time.Now().UnixMilli(), nil).AnyTimes()
components := &Components{
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
}
cd, err := NewService(components).GetByID(ctx, 7494574457319587840)
assert.NoError(t, err)
t.Logf("conversation result: %v; err:%v", cd, err)
assert.Equal(t, "debug ext1111", cd.Ext)
}
func TestNewConversationCtx(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenID(gomock.Any()).Return(int64(123456), nil).Times(1)
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Conversation{})
mockDBGen.AddTable(&model.Conversation{}).
AddRows(
&model.Conversation{
ID: 7494574457319587840,
AgentID: 8888,
SectionID: 100001,
ConnectorID: 100001,
CreatorID: 1111,
},
)
mockDB, err := mockDBGen.DB()
assert.Nil(t, err)
components := &Components{
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
}
res, err := NewService(components).NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
ID: 7494574457319587840,
})
t.Logf("conversation result: %v; err:%v", res, err)
assert.Equal(t, int64(123456), res.SectionID)
}
func TestConversationImpl_Delete(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Conversation{})
mockDBGen.AddTable(&model.Conversation{}).
AddRows(
&model.Conversation{
ID: 7494574457319587840,
AgentID: 9999,
SectionID: 100001,
ConnectorID: 100001,
CreatorID: 1111,
Status: int32(conversation.ConversationStatusNormal),
},
)
mockDB, err := mockDBGen.DB()
assert.Nil(t, err)
components := &Components{
ConversationRepo: repository.NewConversationRepo(mockDB, nil),
}
err = NewService(components).Delete(ctx, 7494574457319587840)
t.Logf("delete err:%v", err)
assert.Nil(t, err)
currentConversation, err := NewService(components).GetByID(ctx, 7494574457319587840)
assert.NotNil(t, currentConversation)
t.Logf("conversation result: %v; err:%v", currentConversation, err)
assert.Nil(t, err)
assert.Equal(t, conversation.ConversationStatusDeleted, currentConversation.Status)
}

View File

@@ -0,0 +1,32 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type ScrollPageDirection string
const (
ScrollPageDirectionPrev ScrollPageDirection = "up"
ScrollPageDirectionNext ScrollPageDirection = "down"
)
type MessageStatus int32
const (
MessageStatusAvailable MessageStatus = 1
MessageStatusDeleted MessageStatus = 2
MessageStatusBroken MessageStatus = 4
)

View File

@@ -0,0 +1,64 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type VerboseInfo struct {
MessageType string `json:"msg_type"`
Data string `json:"data"`
}
type VerboseData struct {
Chunks []RecallDataInfo `json:"chunks"`
OriReq string `json:"ori_req"`
StatusCode int `json:"status_code"`
}
type RecallDataInfo struct {
Slice string `json:"slice"`
Score float64 `json:"score"`
Meta MetaInfo `json:"meta"`
}
type MetaInfo struct {
Dataset DatasetInfo `json:"dataset"`
Document DocumentInfo `json:"document"`
Link LinkInfo `json:"link"`
Card CardInfo `json:"card"`
}
type DatasetInfo struct {
ID string `json:"id"`
Name string `json:"name"`
}
type DocumentInfo struct {
ID string `json:"id"`
Name string `json:"name"`
FormatType int32 `json:"format_type"`
SourceType int32 `json:"source_type"`
}
type LinkInfo struct {
Title string `json:"title"`
URL string `json:"url"`
}
type CardInfo struct {
Title string `json:"title"`
Con string `json:"con"`
Index string `json:"index"`
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
type Message = message.Message
type ListMeta struct {
ConversationID int64 `json:"conversation_id"`
RunID []*int64 `json:"run_id"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
OrderBy *string `json:"order_by"`
Limit int `json:"limit"`
Cursor int64 `json:"cursor"` // message id
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
}
type ListResult struct {
Messages []*Message `json:"messages"`
PrevCursor int64 `json:"prev_cursor"`
NextCursor int64 `json:"next_cursor"`
HasMore bool `json:"has_more"`
Direction ScrollPageDirection `json:"direction"`
}
type GetByRunIDsRequest struct {
ConversationID int64 `json:"conversation_id"`
RunID []int64 `json:"run_id"`
}
type DeleteMeta struct {
MessageIDs []int64 `json:"message_ids"`
RunIDs []int64 `json:"run_ids"`
}
type BrokenMeta struct {
ID int64 `json:"id"`
Position *int32 `json:"position"`
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package entity
type MessageExtKey string
const (
MessageExtKeyInputTokens MessageExtKey = "input_tokens"
MessageExtKeyOutputTokens MessageExtKey = "output_tokens"
MessageExtKeyToken MessageExtKey = "token"
MessageExtKeyPluginStatus MessageExtKey = "plugin_status"
MessageExtKeyTimeCost MessageExtKey = "time_cost"
MessageExtKeyWorkflowTokens MessageExtKey = "workflow_tokens"
MessageExtKeyBotState MessageExtKey = "bot_state"
MessageExtKeyPluginRequest MessageExtKey = "plugin_request"
MessageExtKeyToolName MessageExtKey = "tool_name"
MessageExtKeyPlugin MessageExtKey = "plugin"
MessageExtKeyMockHitInfo MessageExtKey = "mock_hit_info"
MessageExtKeyMessageTitle MessageExtKey = "message_title"
MessageExtKeyStreamPluginRunning MessageExtKey = "stream_plugin_running"
MessageExtKeyExecuteDisplayName MessageExtKey = "execute_display_name"
MessageExtKeyTaskType MessageExtKey = "task_type"
MessageExtKeyCallID MessageExtKey = "call_id"
ExtKeyResumeInfo MessageExtKey = "resume_info"
ExtKeyBreakPoint MessageExtKey = "break_point"
ExtKeyToolCallsIDs MessageExtKey = "tool_calls_ids"
ExtKeyRequiresAction MessageExtKey = "requires_action"
)
type BotStateExt struct {
BotID string `json:"bot_id"`
AgentName string `json:"agent_name"`
AgentID string `json:"agent_id"`
Awaiting string `json:"awaiting"`
}
type UsageExt struct {
TotalCount int64 `json:"total_count"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
}

View File

@@ -0,0 +1,365 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dal
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/cloudwego/eino/schema"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type MessageDAO struct {
query *query.Query
idgen idgen.IDGenerator
}
func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
return &MessageDAO{
query: query.Use(db),
idgen: idgen,
}
}
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
poData, err := dao.messageDO2PO(ctx, msg)
if err != nil {
return nil, err
}
do := dao.query.Message.WithContext(ctx).Debug()
cErr := do.Create(poData)
if cErr != nil {
return nil, cErr
}
return dao.messagePO2DO(poData), nil
}
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
if messageType != nil {
do = do.Where(m.MessageType.Eq(string(*messageType)))
}
if limit > 0 {
do = do.Limit(int(limit) + 1)
}
if cursor > 0 {
if direction == entity.ScrollPageDirectionPrev {
do = do.Where(m.CreatedAt.Lt(cursor))
} else {
do = do.Where(m.CreatedAt.Gt(cursor))
}
}
do = do.Order(m.CreatedAt.Desc())
messageList, err := do.Find()
var hasMore bool
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, hasMore, nil
}
if err != nil {
return nil, false, err
}
if len(messageList) > limit {
hasMore = true
messageList = messageList[:limit]
}
return dao.batchMessagePO2DO(messageList), hasMore, nil
}
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
if orderBy == "DESC" {
do = do.Order(m.CreatedAt.Desc())
} else {
do = do.Order(m.CreatedAt.Asc())
}
poList, err := do.Find()
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return dao.batchMessagePO2DO(poList), nil
}
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
m := dao.query.Message
columns := dao.buildEditColumns(msg)
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
if err != nil {
return 0, err
}
return do.RowsAffected, nil
}
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
columns := make(map[string]interface{})
table := dao.query.Message
if msg.Content != "" {
columns[table.Content.ColumnName().String()] = msg.Content
}
if msg.MessageType != "" {
columns[table.MessageType.ColumnName().String()] = msg.MessageType
}
if msg.ContentType != "" {
columns[table.ContentType.ColumnName().String()] = msg.ContentType
}
if len(msg.ReasoningContent) > 0 {
columns[table.ReasoningContent.ColumnName().String()] = msg.ReasoningContent
}
if msg.Position > 0 {
columns[table.BrokenPosition.ColumnName().String()] = msg.Position
}
if msg.Status > 0 {
columns[table.Status.ColumnName().String()] = msg.Status
}
if len(msg.ModelContent) > 0 {
columns[table.ModelContent.ColumnName().String()] = msg.ModelContent
}
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
if msg.Ext != nil {
ext, err := sonic.MarshalString(msg.Ext)
if err == nil {
columns[table.Ext.ColumnName().String()] = ext
}
}
return columns
}
func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Message, error) {
m := dao.query.Message
do := m.WithContext(ctx).Where(m.ID.Eq(msgID))
po, err := do.First()
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
return dao.messagePO2DO(po), nil
}
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
if len(msgIDs) == 0 && len(runIDs) == 0 {
return nil
}
updateColumns := make(map[string]interface{})
updateColumns["status"] = int32(entity.MessageStatusDeleted)
m := dao.query.Message
do := m.WithContext(ctx)
if len(runIDs) > 0 {
do = do.Where(m.RunID.In(runIDs...))
}
if len(msgIDs) > 0 {
do = do.Where(m.ID.In(msgIDs...))
}
_, err := do.UpdateColumns(&updateColumns)
return err
}
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
id, gErr := dao.idgen.GenID(ctx)
if gErr != nil {
return nil, gErr
}
msgPO := &model.Message{
ID: id,
ConversationID: msgDo.ConversationID,
RunID: msgDo.RunID,
AgentID: msgDo.AgentID,
SectionID: msgDo.SectionID,
UserID: msgDo.UserID,
Role: string(msgDo.Role),
ContentType: string(msgDo.ContentType),
MessageType: string(msgDo.MessageType),
DisplayContent: msgDo.DisplayContent,
Content: msgDo.Content,
BrokenPosition: msgDo.Position,
Status: int32(entity.MessageStatusAvailable),
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
}
mc, err := dao.buildModelContent(msgDo)
if err != nil {
return nil, err
}
msgPO.ModelContent = mc
ext, err := json.Marshal(msgDo.Ext)
if err != nil {
return nil, errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
}
msgPO.Ext = string(ext)
return msgPO, nil
}
func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) {
modelContent := msgDO.ModelContent
if modelContent != "" {
return modelContent, nil
}
modelContentObj := &schema.Message{
Role: msgDO.Role,
Name: msgDO.Name,
}
if msgDO.Content == "" && len(msgDO.MultiContent) == 0 {
return "", nil
}
var multiContent []schema.ChatMessagePart
for _, contentData := range msgDO.MultiContent {
if contentData.Type == message.InputTypeText {
continue
}
one := schema.ChatMessagePart{}
switch contentData.Type {
case message.InputTypeImage:
one.Type = schema.ChatMessagePartTypeImageURL
one.ImageURL = &schema.ChatMessageImageURL{
URL: contentData.FileData[0].Url,
}
case message.InputTypeFile:
one.Type = schema.ChatMessagePartTypeFileURL
one.FileURL = &schema.ChatMessageFileURL{
URL: contentData.FileData[0].Url,
}
case message.InputTypeVideo:
one.Type = schema.ChatMessagePartTypeVideoURL
one.VideoURL = &schema.ChatMessageVideoURL{
URL: contentData.FileData[0].Url,
}
case message.InputTypeAudio:
one.Type = schema.ChatMessagePartTypeFileURL
one.AudioURL = &schema.ChatMessageAudioURL{
URL: contentData.FileData[0].Url,
}
}
multiContent = append(multiContent, one)
}
if len(multiContent) > 0 {
multiContent = append(multiContent, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: msgDO.Content,
})
} else {
modelContentObj.Content = msgDO.Content
}
modelContentObj.MultiContent = multiContent
mcObjByte, err := json.Marshal(modelContentObj)
if err != nil {
return "", errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
}
return string(mcObjByte), nil
}
func (dao *MessageDAO) batchMessagePO2DO(msgPOs []*model.Message) []*entity.Message {
return slices.Transform(msgPOs, func(msgPO *model.Message) *entity.Message {
msgDO := &entity.Message{
ID: msgPO.ID,
AgentID: msgPO.AgentID,
ConversationID: msgPO.ConversationID,
SectionID: msgPO.SectionID,
UserID: msgPO.UserID,
RunID: msgPO.RunID,
Role: schema.RoleType(msgPO.Role),
ContentType: message.ContentType(msgPO.ContentType),
MessageType: message.MessageType(msgPO.MessageType),
Position: msgPO.BrokenPosition,
ModelContent: msgPO.ModelContent,
Content: msgPO.Content,
Status: message.MessageStatus(msgPO.Status),
DisplayContent: msgPO.DisplayContent,
CreatedAt: msgPO.CreatedAt,
UpdatedAt: msgPO.UpdatedAt,
ReasoningContent: msgPO.ReasoningContent,
}
var ext map[string]string
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
if err == nil {
msgDO.Ext = ext
}
return msgDO
})
}
func (dao *MessageDAO) messagePO2DO(msgPO *model.Message) *entity.Message {
msgDO := &entity.Message{
ID: msgPO.ID,
AgentID: msgPO.AgentID,
ConversationID: msgPO.ConversationID,
SectionID: msgPO.SectionID,
UserID: msgPO.UserID,
RunID: msgPO.RunID,
Role: schema.RoleType(msgPO.Role),
ContentType: message.ContentType(msgPO.ContentType),
MessageType: message.MessageType(msgPO.MessageType),
ModelContent: msgPO.ModelContent,
Content: msgPO.Content,
DisplayContent: msgPO.DisplayContent,
CreatedAt: msgPO.CreatedAt,
UpdatedAt: msgPO.UpdatedAt,
}
var ext map[string]string
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
if err == nil {
msgDO.Ext = ext
}
return msgDO
}

View File

@@ -0,0 +1,35 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model
const TableNameMessage = "message"
// Message 消息表
type Message struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID
RunID int64 `gorm:"column:run_id;not null;comment:对应的run_id" json:"run_id"` // 对应的run_id
ConversationID int64 `gorm:"column:conversation_id;not null;comment:conversation id" json:"conversation_id"` // conversation id
UserID string `gorm:"column:user_id;not null;comment:user id" json:"user_id"` // user id
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
Role string `gorm:"column:role;not null;comment:角色: user、assistant、system" json:"role"` // 角色: user、assistant、system
ContentType string `gorm:"column:content_type;not null;comment:内容类型 1 text" json:"content_type"` // 内容类型 1 text
Content string `gorm:"column:content;comment:内容" json:"content"` // 内容
MessageType string `gorm:"column:message_type;not null;comment:消息类型:" json:"message_type"` // 消息类型:
DisplayContent string `gorm:"column:display_content;comment:展示内容" json:"display_content"` // 展示内容
Ext string `gorm:"column:ext;comment:message 扩展字段" json:"ext"` // message 扩展字段
SectionID int64 `gorm:"column:section_id;comment:段落id" json:"section_id"` // 段落id
BrokenPosition int32 `gorm:"column:broken_position;default:-1;comment:打断位置" json:"broken_position"` // 打断位置
Status int32 `gorm:"column:status;not null;comment:消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending" json:"status"` // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
ModelContent string `gorm:"column:model_content;comment:模型输入内容" json:"model_content"` // 模型输入内容
MetaInfo string `gorm:"column:meta_info;comment:引用、高亮等文本标记信息" json:"meta_info"` // 引用、高亮等文本标记信息
ReasoningContent string `gorm:"column:reasoning_content;comment:思考内容" json:"reasoning_content"` // 思考内容
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
}
// TableName Message's table name
func (*Message) TableName() string {
return TableNameMessage
}

View File

@@ -0,0 +1,103 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"database/sql"
"gorm.io/gorm"
"gorm.io/gen"
"gorm.io/plugin/dbresolver"
)
var (
Q = new(Query)
Message *message
)
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
*Q = *Use(db, opts...)
Message = &Q.Message
}
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
return &Query{
db: db,
Message: newMessage(db, opts...),
}
}
type Query struct {
db *gorm.DB
Message message
}
func (q *Query) Available() bool { return q.db != nil }
func (q *Query) clone(db *gorm.DB) *Query {
return &Query{
db: db,
Message: q.Message.clone(db),
}
}
func (q *Query) ReadDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
}
func (q *Query) WriteDB() *Query {
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
}
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
return &Query{
db: db,
Message: q.Message.replaceDB(db),
}
}
type queryCtx struct {
Message IMessageDo
}
func (q *Query) WithContext(ctx context.Context) *queryCtx {
return &queryCtx{
Message: q.Message.WithContext(ctx),
}
}
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
tx := q.db.Begin(opts...)
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
}
type QueryTx struct {
*Query
Error error
}
func (q *QueryTx) Commit() error {
return q.db.Commit().Error
}
func (q *QueryTx) Rollback() error {
return q.db.Rollback().Error
}
func (q *QueryTx) SavePoint(name string) error {
return q.db.SavePoint(name).Error
}
func (q *QueryTx) RollbackTo(name string) error {
return q.db.RollbackTo(name).Error
}

View File

@@ -0,0 +1,453 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package query
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gen"
"gorm.io/gen/field"
"gorm.io/plugin/dbresolver"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
)
func newMessage(db *gorm.DB, opts ...gen.DOOption) message {
_message := message{}
_message.messageDo.UseDB(db, opts...)
_message.messageDo.UseModel(&model.Message{})
tableName := _message.messageDo.TableName()
_message.ALL = field.NewAsterisk(tableName)
_message.ID = field.NewInt64(tableName, "id")
_message.RunID = field.NewInt64(tableName, "run_id")
_message.ConversationID = field.NewInt64(tableName, "conversation_id")
_message.UserID = field.NewString(tableName, "user_id")
_message.AgentID = field.NewInt64(tableName, "agent_id")
_message.Role = field.NewString(tableName, "role")
_message.ContentType = field.NewString(tableName, "content_type")
_message.Content = field.NewString(tableName, "content")
_message.MessageType = field.NewString(tableName, "message_type")
_message.DisplayContent = field.NewString(tableName, "display_content")
_message.Ext = field.NewString(tableName, "ext")
_message.SectionID = field.NewInt64(tableName, "section_id")
_message.BrokenPosition = field.NewInt32(tableName, "broken_position")
_message.Status = field.NewInt32(tableName, "status")
_message.ModelContent = field.NewString(tableName, "model_content")
_message.MetaInfo = field.NewString(tableName, "meta_info")
_message.ReasoningContent = field.NewString(tableName, "reasoning_content")
_message.CreatedAt = field.NewInt64(tableName, "created_at")
_message.UpdatedAt = field.NewInt64(tableName, "updated_at")
_message.fillFieldMap()
return _message
}
// message 消息表
type message struct {
messageDo
ALL field.Asterisk
ID field.Int64 // 主键ID
RunID field.Int64 // 对应的run_id
ConversationID field.Int64 // conversation id
UserID field.String // user id
AgentID field.Int64 // agent_id
Role field.String // 角色: user、assistant、system
ContentType field.String // 内容类型 1 text
Content field.String // 内容
MessageType field.String // 消息类型:
DisplayContent field.String // 展示内容
Ext field.String // message 扩展字段
SectionID field.Int64 // 段落id
BrokenPosition field.Int32 // 打断位置
Status field.Int32 // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
ModelContent field.String // 模型输入内容
MetaInfo field.String // 引用、高亮等文本标记信息
ReasoningContent field.String // 思考内容
CreatedAt field.Int64 // 创建时间
UpdatedAt field.Int64 // 更新时间
fieldMap map[string]field.Expr
}
func (m message) Table(newTableName string) *message {
m.messageDo.UseTable(newTableName)
return m.updateTableName(newTableName)
}
func (m message) As(alias string) *message {
m.messageDo.DO = *(m.messageDo.As(alias).(*gen.DO))
return m.updateTableName(alias)
}
func (m *message) updateTableName(table string) *message {
m.ALL = field.NewAsterisk(table)
m.ID = field.NewInt64(table, "id")
m.RunID = field.NewInt64(table, "run_id")
m.ConversationID = field.NewInt64(table, "conversation_id")
m.UserID = field.NewString(table, "user_id")
m.AgentID = field.NewInt64(table, "agent_id")
m.Role = field.NewString(table, "role")
m.ContentType = field.NewString(table, "content_type")
m.Content = field.NewString(table, "content")
m.MessageType = field.NewString(table, "message_type")
m.DisplayContent = field.NewString(table, "display_content")
m.Ext = field.NewString(table, "ext")
m.SectionID = field.NewInt64(table, "section_id")
m.BrokenPosition = field.NewInt32(table, "broken_position")
m.Status = field.NewInt32(table, "status")
m.ModelContent = field.NewString(table, "model_content")
m.MetaInfo = field.NewString(table, "meta_info")
m.ReasoningContent = field.NewString(table, "reasoning_content")
m.CreatedAt = field.NewInt64(table, "created_at")
m.UpdatedAt = field.NewInt64(table, "updated_at")
m.fillFieldMap()
return m
}
func (m *message) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
_f, ok := m.fieldMap[fieldName]
if !ok || _f == nil {
return nil, false
}
_oe, ok := _f.(field.OrderExpr)
return _oe, ok
}
func (m *message) fillFieldMap() {
m.fieldMap = make(map[string]field.Expr, 19)
m.fieldMap["id"] = m.ID
m.fieldMap["run_id"] = m.RunID
m.fieldMap["conversation_id"] = m.ConversationID
m.fieldMap["user_id"] = m.UserID
m.fieldMap["agent_id"] = m.AgentID
m.fieldMap["role"] = m.Role
m.fieldMap["content_type"] = m.ContentType
m.fieldMap["content"] = m.Content
m.fieldMap["message_type"] = m.MessageType
m.fieldMap["display_content"] = m.DisplayContent
m.fieldMap["ext"] = m.Ext
m.fieldMap["section_id"] = m.SectionID
m.fieldMap["broken_position"] = m.BrokenPosition
m.fieldMap["status"] = m.Status
m.fieldMap["model_content"] = m.ModelContent
m.fieldMap["meta_info"] = m.MetaInfo
m.fieldMap["reasoning_content"] = m.ReasoningContent
m.fieldMap["created_at"] = m.CreatedAt
m.fieldMap["updated_at"] = m.UpdatedAt
}
func (m message) clone(db *gorm.DB) message {
m.messageDo.ReplaceConnPool(db.Statement.ConnPool)
return m
}
func (m message) replaceDB(db *gorm.DB) message {
m.messageDo.ReplaceDB(db)
return m
}
type messageDo struct{ gen.DO }
type IMessageDo interface {
gen.SubQuery
Debug() IMessageDo
WithContext(ctx context.Context) IMessageDo
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
ReplaceDB(db *gorm.DB)
ReadDB() IMessageDo
WriteDB() IMessageDo
As(alias string) gen.Dao
Session(config *gorm.Session) IMessageDo
Columns(cols ...field.Expr) gen.Columns
Clauses(conds ...clause.Expression) IMessageDo
Not(conds ...gen.Condition) IMessageDo
Or(conds ...gen.Condition) IMessageDo
Select(conds ...field.Expr) IMessageDo
Where(conds ...gen.Condition) IMessageDo
Order(conds ...field.Expr) IMessageDo
Distinct(cols ...field.Expr) IMessageDo
Omit(cols ...field.Expr) IMessageDo
Join(table schema.Tabler, on ...field.Expr) IMessageDo
LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo
RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo
Group(cols ...field.Expr) IMessageDo
Having(conds ...gen.Condition) IMessageDo
Limit(limit int) IMessageDo
Offset(offset int) IMessageDo
Count() (count int64, err error)
Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo
Unscoped() IMessageDo
Create(values ...*model.Message) error
CreateInBatches(values []*model.Message, batchSize int) error
Save(values ...*model.Message) error
First() (*model.Message, error)
Take() (*model.Message, error)
Last() (*model.Message, error)
Find() ([]*model.Message, error)
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error)
FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error
Pluck(column field.Expr, dest interface{}) error
Delete(...*model.Message) (info gen.ResultInfo, err error)
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
Updates(value interface{}) (info gen.ResultInfo, err error)
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
UpdateFrom(q gen.SubQuery) gen.Dao
Attrs(attrs ...field.AssignExpr) IMessageDo
Assign(attrs ...field.AssignExpr) IMessageDo
Joins(fields ...field.RelationField) IMessageDo
Preload(fields ...field.RelationField) IMessageDo
FirstOrInit() (*model.Message, error)
FirstOrCreate() (*model.Message, error)
FindByPage(offset int, limit int) (result []*model.Message, count int64, err error)
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
Scan(result interface{}) (err error)
Returning(value interface{}, columns ...string) IMessageDo
UnderlyingDB() *gorm.DB
schema.Tabler
}
func (m messageDo) Debug() IMessageDo {
return m.withDO(m.DO.Debug())
}
func (m messageDo) WithContext(ctx context.Context) IMessageDo {
return m.withDO(m.DO.WithContext(ctx))
}
func (m messageDo) ReadDB() IMessageDo {
return m.Clauses(dbresolver.Read)
}
func (m messageDo) WriteDB() IMessageDo {
return m.Clauses(dbresolver.Write)
}
func (m messageDo) Session(config *gorm.Session) IMessageDo {
return m.withDO(m.DO.Session(config))
}
func (m messageDo) Clauses(conds ...clause.Expression) IMessageDo {
return m.withDO(m.DO.Clauses(conds...))
}
func (m messageDo) Returning(value interface{}, columns ...string) IMessageDo {
return m.withDO(m.DO.Returning(value, columns...))
}
func (m messageDo) Not(conds ...gen.Condition) IMessageDo {
return m.withDO(m.DO.Not(conds...))
}
func (m messageDo) Or(conds ...gen.Condition) IMessageDo {
return m.withDO(m.DO.Or(conds...))
}
func (m messageDo) Select(conds ...field.Expr) IMessageDo {
return m.withDO(m.DO.Select(conds...))
}
func (m messageDo) Where(conds ...gen.Condition) IMessageDo {
return m.withDO(m.DO.Where(conds...))
}
func (m messageDo) Order(conds ...field.Expr) IMessageDo {
return m.withDO(m.DO.Order(conds...))
}
func (m messageDo) Distinct(cols ...field.Expr) IMessageDo {
return m.withDO(m.DO.Distinct(cols...))
}
func (m messageDo) Omit(cols ...field.Expr) IMessageDo {
return m.withDO(m.DO.Omit(cols...))
}
func (m messageDo) Join(table schema.Tabler, on ...field.Expr) IMessageDo {
return m.withDO(m.DO.Join(table, on...))
}
func (m messageDo) LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
return m.withDO(m.DO.LeftJoin(table, on...))
}
func (m messageDo) RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
return m.withDO(m.DO.RightJoin(table, on...))
}
func (m messageDo) Group(cols ...field.Expr) IMessageDo {
return m.withDO(m.DO.Group(cols...))
}
func (m messageDo) Having(conds ...gen.Condition) IMessageDo {
return m.withDO(m.DO.Having(conds...))
}
func (m messageDo) Limit(limit int) IMessageDo {
return m.withDO(m.DO.Limit(limit))
}
func (m messageDo) Offset(offset int) IMessageDo {
return m.withDO(m.DO.Offset(offset))
}
func (m messageDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo {
return m.withDO(m.DO.Scopes(funcs...))
}
func (m messageDo) Unscoped() IMessageDo {
return m.withDO(m.DO.Unscoped())
}
func (m messageDo) Create(values ...*model.Message) error {
if len(values) == 0 {
return nil
}
return m.DO.Create(values)
}
func (m messageDo) CreateInBatches(values []*model.Message, batchSize int) error {
return m.DO.CreateInBatches(values, batchSize)
}
// Save : !!! underlying implementation is different with GORM
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
func (m messageDo) Save(values ...*model.Message) error {
if len(values) == 0 {
return nil
}
return m.DO.Save(values)
}
func (m messageDo) First() (*model.Message, error) {
if result, err := m.DO.First(); err != nil {
return nil, err
} else {
return result.(*model.Message), nil
}
}
func (m messageDo) Take() (*model.Message, error) {
if result, err := m.DO.Take(); err != nil {
return nil, err
} else {
return result.(*model.Message), nil
}
}
func (m messageDo) Last() (*model.Message, error) {
if result, err := m.DO.Last(); err != nil {
return nil, err
} else {
return result.(*model.Message), nil
}
}
func (m messageDo) Find() ([]*model.Message, error) {
result, err := m.DO.Find()
return result.([]*model.Message), err
}
func (m messageDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error) {
buf := make([]*model.Message, 0, batchSize)
err = m.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
defer func() { results = append(results, buf...) }()
return fc(tx, batch)
})
return results, err
}
func (m messageDo) FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error {
return m.DO.FindInBatches(result, batchSize, fc)
}
func (m messageDo) Attrs(attrs ...field.AssignExpr) IMessageDo {
return m.withDO(m.DO.Attrs(attrs...))
}
func (m messageDo) Assign(attrs ...field.AssignExpr) IMessageDo {
return m.withDO(m.DO.Assign(attrs...))
}
func (m messageDo) Joins(fields ...field.RelationField) IMessageDo {
for _, _f := range fields {
m = *m.withDO(m.DO.Joins(_f))
}
return &m
}
func (m messageDo) Preload(fields ...field.RelationField) IMessageDo {
for _, _f := range fields {
m = *m.withDO(m.DO.Preload(_f))
}
return &m
}
func (m messageDo) FirstOrInit() (*model.Message, error) {
if result, err := m.DO.FirstOrInit(); err != nil {
return nil, err
} else {
return result.(*model.Message), nil
}
}
func (m messageDo) FirstOrCreate() (*model.Message, error) {
if result, err := m.DO.FirstOrCreate(); err != nil {
return nil, err
} else {
return result.(*model.Message), nil
}
}
func (m messageDo) FindByPage(offset int, limit int) (result []*model.Message, count int64, err error) {
result, err = m.Offset(offset).Limit(limit).Find()
if err != nil {
return
}
if size := len(result); 0 < limit && 0 < size && size < limit {
count = int64(size + offset)
return
}
count, err = m.Offset(-1).Limit(-1).Count()
return
}
func (m messageDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
count, err = m.Count()
if err != nil {
return
}
err = m.Offset(offset).Limit(limit).Scan(result)
return
}
func (m messageDo) Scan(result interface{}) (err error) {
return m.DO.Scan(result)
}
func (m messageDo) Delete(models ...*model.Message) (result gen.ResultInfo, err error) {
return m.DO.Delete(models)
}
func (m *messageDo) withDO(do gen.Dao) *messageDo {
m.DO = *do.(*gen.DO)
return m
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package repository
import (
"context"
"gorm.io/gorm"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
return dal.NewMessageDAO(db, idGen)
}
type MessageRepo interface {
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
List(ctx context.Context, conversationID int64, limit int, cursor int64,
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
}

View File

@@ -0,0 +1,33 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package message
import (
"context"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
)
type Message interface {
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
GetByID(ctx context.Context, id int64) (*entity.Message, error)
Edit(ctx context.Context, req *entity.Message) (*entity.Message, error)
Delete(ctx context.Context, req *entity.DeleteMeta) error
Broken(ctx context.Context, req *entity.BrokenMeta) error
}

View File

@@ -0,0 +1,130 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package message
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
type messageImpl struct {
Components
}
type Components struct {
MessageRepo repository.MessageRepo
}
func NewService(c *Components) Message {
return &messageImpl{
Components: *c,
}
}
func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
// create message
msg, err := m.MessageRepo.Create(ctx, msg)
if err != nil {
return nil, err
}
return msg, nil
}
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
// get message with query
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
if err != nil {
return resp, err
}
resp.Direction = req.Direction
resp.HasMore = hasMore
if len(messageList) > 0 {
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
resp.NextCursor = messageList[0].CreatedAt
var runIDs []int64
for _, m := range messageList {
runIDs = append(runIDs, m.RunID)
}
orderBy := "DESC"
if req.OrderBy != nil {
orderBy = *req.OrderBy
}
allMessageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, orderBy)
if err != nil {
return resp, err
}
resp.Messages = allMessageList
}
return resp, nil
}
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
if err != nil {
return nil, err
}
return messageList, nil
}
func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) {
_, err := m.MessageRepo.Edit(ctx, req.ID, req)
if err != nil {
return nil, err
}
return req, nil
}
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
if err != nil {
return err
}
return nil
}
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
msg, err := m.MessageRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return msg, nil
}
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
_, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{
Status: message.MessageStatusBroken,
Position: ptr.From(req.Position),
})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,247 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package message
import (
"context"
"testing"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
)
// Test_NewListMessage tests the NewListMessage function
func TestListMessage(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
resp, err := NewService(components).List(ctx, &entity.ListMeta{
ConversationID: 1,
Limit: 1,
UserID: "1",
})
assert.NoError(t, err)
assert.Len(t, resp.Messages, 0)
}
// Test_NewListMessage tests the NewListMessage function
func TestCreateMessage(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1)
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB()
// redisCli := redis.New()
// idGen, _ := idgen.New(redisCli)
// mockDB, err := mysql.New()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, idGen),
}
imageInput := &message.FileData{
Url: "https://xxxxx.xxxx/image",
Name: "test_img",
}
fileInput := &message.FileData{
Url: "https://xxxxx.xxxx/file",
Name: "test_file",
}
content := []*message.InputMetaData{
{
Type: message.InputTypeText,
Text: "解析图片中的内容",
},
{
Type: message.InputTypeImage,
FileData: []*message.FileData{
imageInput,
},
},
{
Type: message.InputTypeFile,
FileData: []*message.FileData{
fileInput,
},
},
}
service := NewService(components)
insert := &entity.Message{
ID: 7498710126354759680,
ConversationID: 7496795464885338112,
AgentID: 7366055842027922437,
UserID: "6666666",
RunID: 7498710102375923712,
Content: "你是谁?",
MultiContent: content,
Role: schema.Assistant,
MessageType: message.MessageTypeFunctionCall,
SectionID: 7496795464897921024,
ModelContent: "{\"role\":\"tool\",\"content\":\"tool call\"}",
ContentType: message.ContentTypeMix,
}
resp, err := service.Create(ctx, insert)
assert.NoError(t, err)
assert.Equal(t, int64(7366055842027922437), resp.AgentID)
assert.Equal(t, "你是谁?", resp.Content)
}
func TestEditMessage(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
RunID: 123,
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
RunID: 124,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
imageInput := &message.FileData{
Url: "https://xxxxx.xxxx/image",
Name: "test_img",
}
fileInput := &message.FileData{
Url: "https://xxxxx.xxxx/file",
Name: "test_file",
}
content := []*message.InputMetaData{
{
Type: message.InputTypeText,
Text: "解析图片中的内容",
},
{
Type: message.InputTypeImage,
FileData: []*message.FileData{
imageInput,
},
},
{
Type: message.InputTypeFile,
FileData: []*message.FileData{
fileInput,
},
},
}
resp, err := NewService(components).Edit(ctx, &entity.Message{
ID: 2,
Content: "test edit message",
MultiContent: content,
})
_ = resp
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
assert.NoError(t, err)
assert.Equal(t, int64(124), msOne[0].RunID)
}
func TestGetByRunIDs(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
RunID: 123,
Content: "test content123",
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
&model.Message{
ID: 3,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
assert.NoError(t, err)
assert.Len(t, resp, 2)
}