feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
85
backend/domain/conversation/agentrun/entity/const.go
Normal file
85
backend/domain/conversation/agentrun/entity/const.go
Normal file
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
const ConversationTurnsDefault int32 = 100
|
||||
|
||||
type RunStatus string
|
||||
|
||||
const (
|
||||
RunStatusCreated RunStatus = "created"
|
||||
RunStatusInProgress RunStatus = "in_progress"
|
||||
RunStatusCompleted RunStatus = "completed"
|
||||
RunStatusFailed RunStatus = "failed"
|
||||
RunStatusExpired RunStatus = "expired"
|
||||
RunStatusCancelled RunStatus = "cancelled"
|
||||
RunStatusRequiredAction RunStatus = "required_action"
|
||||
RunStatusDeleted RunStatus = "deleted"
|
||||
)
|
||||
|
||||
type RunEvent string
|
||||
|
||||
const (
|
||||
RunEventCreated RunEvent = "conversation.run.created"
|
||||
RunEventInProgress RunEvent = "conversation.run.in_progress"
|
||||
RunEventCompleted RunEvent = "conversation.run.completed"
|
||||
RunEventFailed RunEvent = "conversation.run.failed"
|
||||
RunEventExpired RunEvent = "conversation.run.expired"
|
||||
RunEventCancelled RunEvent = "conversation.run.cancelled"
|
||||
RunEventRequiredAction RunEvent = "conversation.run.required_action"
|
||||
|
||||
RunEventMessageDelta RunEvent = "conversation.message.delta"
|
||||
RunEventMessageCompleted RunEvent = "conversation.message.completed"
|
||||
|
||||
RunEventAck = "conversation.ack"
|
||||
RunEventError RunEvent = "conversation.error"
|
||||
RunEventStreamDone RunEvent = "conversation.stream.done"
|
||||
)
|
||||
|
||||
type ReplyType int64
|
||||
|
||||
const (
|
||||
ReplyTypeAnswer ReplyType = 1
|
||||
ReplyTypeSuggest ReplyType = 2
|
||||
ReplyTypeLLMOutput ReplyType = 3
|
||||
ReplyTypeToolOutput ReplyType = 4
|
||||
ReplyTypeVerbose ReplyType = 100
|
||||
ReplyTypePlaceHolder ReplyType = 101
|
||||
)
|
||||
|
||||
type MetaType int64
|
||||
|
||||
const (
|
||||
MetaTypeKnowledgeCard MetaType = 4
|
||||
)
|
||||
|
||||
type RoleType string
|
||||
|
||||
const (
|
||||
RoleTypeSystem RoleType = "system"
|
||||
RoleTypeUser RoleType = "user"
|
||||
RoleTypeAssistant RoleType = "assistant"
|
||||
RoleTypeTool RoleType = "tool"
|
||||
)
|
||||
|
||||
type MessageSubType string
|
||||
|
||||
const (
|
||||
MessageSubTypeKnowledgeCall MessageSubType = "knowledge_recall"
|
||||
MessageSubTypeGenerateFinish MessageSubType = "generate_answer_finish"
|
||||
MessageSubTypeInterrupt MessageSubType = "interrupt"
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type EventType int64
|
||||
|
||||
const (
|
||||
EventType_LocalPlugin EventType = 1
|
||||
EventType_Question EventType = 2
|
||||
EventType_RequireInfos EventType = 3
|
||||
EventType_SceneChat EventType = 4
|
||||
EventType_InputNode EventType = 5
|
||||
EventType_WorkflowLocalPlugin EventType = 6
|
||||
EventType_OauthPlugin EventType = 7
|
||||
EventType_WorkflowLLM EventType = 100
|
||||
)
|
||||
158
backend/domain/conversation/agentrun/entity/run_record.go
Normal file
158
backend/domain/conversation/agentrun/entity/run_record.go
Normal file
@@ -0,0 +1,158 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
message2 "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
)
|
||||
|
||||
type RunRecord = model.RunRecord
|
||||
|
||||
type RunRecordMeta struct {
|
||||
ID int64 `json:"id"`
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
Status RunStatus `json:"status"`
|
||||
Error *RunError `json:"error"`
|
||||
Usage *agentrun.Usage `json:"usage"`
|
||||
Ext string `json:"ext"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
ChatRequest *string `json:"chat_message"`
|
||||
CompletedAt int64 `json:"completed_at"`
|
||||
FailedAt int64 `json:"failed_at"`
|
||||
}
|
||||
|
||||
type ChunkRunItem = RunRecordMeta
|
||||
|
||||
type ChunkMessageItem struct {
|
||||
ID int64 `json:"id"`
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
RunID int64 `json:"run_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
Role RoleType `json:"role"`
|
||||
Type message.MessageType `json:"type"`
|
||||
Content string `json:"content"`
|
||||
ContentType message.ContentType `json:"content_type"`
|
||||
MessageType message.MessageType `json:"message_type"`
|
||||
ReplyID int64 `json:"reply_id"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
ReasoningContent *string `json:"reasoning_content"`
|
||||
Index int64 `json:"index"`
|
||||
RequiredAction *message2.RequiredAction `json:"required_action"`
|
||||
SeqID int64 `json:"seq_id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
IsFinish bool `json:"is_finish"`
|
||||
}
|
||||
|
||||
type RunError struct {
|
||||
Code int64 `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
|
||||
type CustomerConfig struct {
|
||||
ModelConfig *ModelConfig `json:"model_config"`
|
||||
AgentConfig *AgentConfig `json:"agent_config"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
ModelId *int64 `json:"model_id,omitempty"`
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
Prompt *string `json:"prompt"`
|
||||
}
|
||||
|
||||
type Tool = agentrun.Tool
|
||||
|
||||
type AnswerFinshContent struct {
|
||||
MsgType MessageSubType `json:"msg_type"`
|
||||
Data string `json:"data"`
|
||||
FromUnit string `json:"from_unit"`
|
||||
}
|
||||
type Data struct {
|
||||
FinishReason int32 `json:"finish_reason"`
|
||||
FinData string `json:"fin_data"`
|
||||
}
|
||||
|
||||
type MetaInfo struct {
|
||||
Type MetaType `json:"type"`
|
||||
Info string `json:"info"`
|
||||
}
|
||||
|
||||
type AgentRunMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
SpaceID int64 `json:"space_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
Name string `json:"name"`
|
||||
UserID string `json:"user_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
ContentType message.ContentType `json:"content_type"`
|
||||
Content []*message.InputMetaData `json:"content"`
|
||||
PreRetrieveTools []*Tool `json:"tools"`
|
||||
IsDraft bool `json:"is_draft"`
|
||||
CustomerConfig *CustomerConfig `json:"customer_config"`
|
||||
DisplayContent string `json:"display_content"`
|
||||
CustomVariables map[string]string `json:"custom_variables"`
|
||||
Version string `json:"version"`
|
||||
Ext map[string]string `json:"ext"`
|
||||
}
|
||||
|
||||
type UpdateMeta struct {
|
||||
Status RunStatus
|
||||
LastError *RunError
|
||||
Usage *agentrun.Usage
|
||||
UpdatedAt int64
|
||||
CompletedAt int64
|
||||
FailedAt int64
|
||||
}
|
||||
|
||||
type AgentRunResponse struct {
|
||||
Event RunEvent `json:"event"`
|
||||
ChunkRunItem *ChunkRunItem `json:"run_record_item"`
|
||||
ChunkMessageItem *ChunkMessageItem `json:"message_item"`
|
||||
Error *RunError `json:"error"`
|
||||
}
|
||||
|
||||
type AgentRespEvent struct {
|
||||
EventType message.MessageType
|
||||
|
||||
ModelAnswer *schema.StreamReader[*schema.Message]
|
||||
ToolsMessage []*schema.Message
|
||||
FuncCall *schema.Message
|
||||
Suggest *schema.Message
|
||||
Knowledge []*schema.Document
|
||||
Interrupt *singleagent.InterruptInfo
|
||||
Err error
|
||||
}
|
||||
|
||||
type ModelAnswerEvent struct {
|
||||
Message *schema.Message
|
||||
Err error
|
||||
}
|
||||
167
backend/domain/conversation/agentrun/internal/dal/dao.go
Normal file
167
backend/domain/conversation/agentrun/internal/dal/dao.go
Normal file
@@ -0,0 +1,167 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type RunRecordDAO struct {
|
||||
db *gorm.DB
|
||||
query *query.Query
|
||||
idGen idgen.IDGenerator
|
||||
}
|
||||
|
||||
func NewRunRecordDAO(db *gorm.DB, idGen idgen.IDGenerator) *RunRecordDAO {
|
||||
return &RunRecordDAO{
|
||||
db: db,
|
||||
idGen: idGen,
|
||||
query: query.Use(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error) {
|
||||
|
||||
createPO, err := dao.buildCreatePO(ctx, runMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
createErr := dao.query.RunRecord.WithContext(ctx).Create(createPO)
|
||||
if createErr != nil {
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
return dao.buildPo2Do(createPO), nil
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*model.RunRecord, error) {
|
||||
return dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) UpdateByID(ctx context.Context, id int64, updateMeta *entity.UpdateMeta) error {
|
||||
po := &model.RunRecord{
|
||||
ID: id,
|
||||
}
|
||||
if updateMeta.Status != "" {
|
||||
|
||||
po.Status = string(updateMeta.Status)
|
||||
}
|
||||
if updateMeta.LastError != nil {
|
||||
errString, err := json.Marshal(updateMeta.LastError)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
po.LastError = string(errString)
|
||||
}
|
||||
if updateMeta.CompletedAt != 0 {
|
||||
|
||||
po.CompletedAt = updateMeta.CompletedAt
|
||||
}
|
||||
if updateMeta.FailedAt != 0 {
|
||||
|
||||
po.FailedAt = updateMeta.FailedAt
|
||||
}
|
||||
if updateMeta.Usage != nil {
|
||||
|
||||
po.Usage = updateMeta.Usage
|
||||
}
|
||||
po.UpdatedAt = time.Now().UnixMilli()
|
||||
|
||||
_, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).Updates(po)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) Delete(ctx context.Context, id []int64) error {
|
||||
|
||||
_, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.In(id...)).UpdateColumns(map[string]interface{}{
|
||||
"updated_at": time.Now().UnixMilli(),
|
||||
"status": entity.RunStatusDeleted,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error) {
|
||||
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", conversationID, sectionID, limit)
|
||||
m := dao.query.RunRecord
|
||||
do := m.WithContext(ctx).Where(m.ConversationID.Eq(conversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
|
||||
|
||||
if sectionID > 0 {
|
||||
do = do.Where(m.SectionID.Eq(sectionID))
|
||||
}
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit))
|
||||
}
|
||||
|
||||
runRecords, err := do.Order(m.CreatedAt.Desc()).Find()
|
||||
return runRecords, err
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.AgentRunMeta) (*model.RunRecord, error) {
|
||||
|
||||
runID, err := dao.idGen.GenID(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqOrigin, err := json.Marshal(runMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeNow := time.Now().UnixMilli()
|
||||
|
||||
return &model.RunRecord{
|
||||
ID: runID,
|
||||
ConversationID: runMeta.ConversationID,
|
||||
SectionID: runMeta.SectionID,
|
||||
AgentID: runMeta.AgentID,
|
||||
Status: string(entity.RunStatusCreated),
|
||||
ChatRequest: string(reqOrigin),
|
||||
UserID: runMeta.UserID,
|
||||
CreatedAt: timeNow,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) buildPo2Do(po *model.RunRecord) *entity.RunRecordMeta {
|
||||
runMeta := &entity.RunRecordMeta{
|
||||
ID: po.ID,
|
||||
ConversationID: po.ConversationID,
|
||||
SectionID: po.SectionID,
|
||||
AgentID: po.AgentID,
|
||||
Status: entity.RunStatus(po.Status),
|
||||
Ext: po.Ext,
|
||||
CreatedAt: po.CreatedAt,
|
||||
UpdatedAt: po.UpdatedAt,
|
||||
CompletedAt: po.CompletedAt,
|
||||
FailedAt: po.FailedAt,
|
||||
Usage: po.Usage,
|
||||
}
|
||||
|
||||
return runMeta
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
|
||||
const TableNameRunRecord = "run_record"
|
||||
|
||||
// RunRecord 执行记录表
|
||||
type RunRecord struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;comment:主键ID" json:"id"` // 主键ID
|
||||
ConversationID int64 `gorm:"column:conversation_id;not null;comment:会话 ID" json:"conversation_id"` // 会话 ID
|
||||
SectionID int64 `gorm:"column:section_id;not null;comment:section ID" json:"section_id"` // section ID
|
||||
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
|
||||
UserID string `gorm:"column:user_id;not null;comment:user id" json:"user_id"` // user id
|
||||
Source int32 `gorm:"column:source;not null;comment:执行来源 0 API," json:"source"` // 执行来源 0 API,
|
||||
Status string `gorm:"column:status;not null;comment:状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction" json:"status"` // 状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction
|
||||
CreatorID int64 `gorm:"column:creator_id;not null;comment:创建者标识" json:"creator_id"` // 创建者标识
|
||||
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
FailedAt int64 `gorm:"column:failed_at;not null;comment:失败时间" json:"failed_at"` // 失败时间
|
||||
LastError string `gorm:"column:last_error;comment:error message" json:"last_error"` // error message
|
||||
CompletedAt int64 `gorm:"column:completed_at;not null;comment:结束时间" json:"completed_at"` // 结束时间
|
||||
ChatRequest string `gorm:"column:chat_request;comment:保存原始请求的部分字段" json:"chat_request"` // 保存原始请求的部分字段
|
||||
Ext string `gorm:"column:ext;comment:扩展字段" json:"ext"` // 扩展字段
|
||||
Usage *agentrun.Usage `gorm:"column:usage;comment:usage;serializer:json" json:"usage"` // usage
|
||||
}
|
||||
|
||||
// TableName RunRecord's table name
|
||||
func (*RunRecord) TableName() string {
|
||||
return TableNameRunRecord
|
||||
}
|
||||
103
backend/domain/conversation/agentrun/internal/dal/query/gen.go
Normal file
103
backend/domain/conversation/agentrun/internal/dal/query/gen.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gorm.io/gen"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
)
|
||||
|
||||
var (
|
||||
Q = new(Query)
|
||||
RunRecord *runRecord
|
||||
)
|
||||
|
||||
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
|
||||
*Q = *Use(db, opts...)
|
||||
RunRecord = &Q.RunRecord
|
||||
}
|
||||
|
||||
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
RunRecord: newRunRecord(db, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
db *gorm.DB
|
||||
|
||||
RunRecord runRecord
|
||||
}
|
||||
|
||||
func (q *Query) Available() bool { return q.db != nil }
|
||||
|
||||
func (q *Query) clone(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
RunRecord: q.RunRecord.clone(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) ReadDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
|
||||
}
|
||||
|
||||
func (q *Query) WriteDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
|
||||
}
|
||||
|
||||
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
RunRecord: q.RunRecord.replaceDB(db),
|
||||
}
|
||||
}
|
||||
|
||||
type queryCtx struct {
|
||||
RunRecord IRunRecordDo
|
||||
}
|
||||
|
||||
func (q *Query) WithContext(ctx context.Context) *queryCtx {
|
||||
return &queryCtx{
|
||||
RunRecord: q.RunRecord.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
|
||||
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
|
||||
}
|
||||
|
||||
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
|
||||
tx := q.db.Begin(opts...)
|
||||
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
|
||||
}
|
||||
|
||||
type QueryTx struct {
|
||||
*Query
|
||||
Error error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Commit() error {
|
||||
return q.db.Commit().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Rollback() error {
|
||||
return q.db.Rollback().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) SavePoint(name string) error {
|
||||
return q.db.SavePoint(name).Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) RollbackTo(name string) error {
|
||||
return q.db.RollbackTo(name).Error
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
)
|
||||
|
||||
func newRunRecord(db *gorm.DB, opts ...gen.DOOption) runRecord {
|
||||
_runRecord := runRecord{}
|
||||
|
||||
_runRecord.runRecordDo.UseDB(db, opts...)
|
||||
_runRecord.runRecordDo.UseModel(&model.RunRecord{})
|
||||
|
||||
tableName := _runRecord.runRecordDo.TableName()
|
||||
_runRecord.ALL = field.NewAsterisk(tableName)
|
||||
_runRecord.ID = field.NewInt64(tableName, "id")
|
||||
_runRecord.ConversationID = field.NewInt64(tableName, "conversation_id")
|
||||
_runRecord.SectionID = field.NewInt64(tableName, "section_id")
|
||||
_runRecord.AgentID = field.NewInt64(tableName, "agent_id")
|
||||
_runRecord.UserID = field.NewString(tableName, "user_id")
|
||||
_runRecord.Source = field.NewInt32(tableName, "source")
|
||||
_runRecord.Status = field.NewString(tableName, "status")
|
||||
_runRecord.CreatorID = field.NewInt64(tableName, "creator_id")
|
||||
_runRecord.CreatedAt = field.NewInt64(tableName, "created_at")
|
||||
_runRecord.UpdatedAt = field.NewInt64(tableName, "updated_at")
|
||||
_runRecord.FailedAt = field.NewInt64(tableName, "failed_at")
|
||||
_runRecord.LastError = field.NewString(tableName, "last_error")
|
||||
_runRecord.CompletedAt = field.NewInt64(tableName, "completed_at")
|
||||
_runRecord.ChatRequest = field.NewString(tableName, "chat_request")
|
||||
_runRecord.Ext = field.NewString(tableName, "ext")
|
||||
_runRecord.Usage = field.NewField(tableName, "usage")
|
||||
|
||||
_runRecord.fillFieldMap()
|
||||
|
||||
return _runRecord
|
||||
}
|
||||
|
||||
// runRecord 执行记录表
|
||||
type runRecord struct {
|
||||
runRecordDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int64 // 主键ID
|
||||
ConversationID field.Int64 // 会话 ID
|
||||
SectionID field.Int64 // section ID
|
||||
AgentID field.Int64 // agent_id
|
||||
UserID field.String // user id
|
||||
Source field.Int32 // 执行来源 0 API,
|
||||
Status field.String // 状态,0 Unknown, 1-Created,2-InProgress,3-Completed,4-Failed,5-Expired,6-Cancelled,7-RequiresAction
|
||||
CreatorID field.Int64 // 创建者标识
|
||||
CreatedAt field.Int64 // 创建时间
|
||||
UpdatedAt field.Int64 // 更新时间
|
||||
FailedAt field.Int64 // 失败时间
|
||||
LastError field.String // error message
|
||||
CompletedAt field.Int64 // 结束时间
|
||||
ChatRequest field.String // 保存原始请求的部分字段
|
||||
Ext field.String // 扩展字段
|
||||
Usage field.Field // usage
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
|
||||
func (r runRecord) Table(newTableName string) *runRecord {
|
||||
r.runRecordDo.UseTable(newTableName)
|
||||
return r.updateTableName(newTableName)
|
||||
}
|
||||
|
||||
func (r runRecord) As(alias string) *runRecord {
|
||||
r.runRecordDo.DO = *(r.runRecordDo.As(alias).(*gen.DO))
|
||||
return r.updateTableName(alias)
|
||||
}
|
||||
|
||||
func (r *runRecord) updateTableName(table string) *runRecord {
|
||||
r.ALL = field.NewAsterisk(table)
|
||||
r.ID = field.NewInt64(table, "id")
|
||||
r.ConversationID = field.NewInt64(table, "conversation_id")
|
||||
r.SectionID = field.NewInt64(table, "section_id")
|
||||
r.AgentID = field.NewInt64(table, "agent_id")
|
||||
r.UserID = field.NewString(table, "user_id")
|
||||
r.Source = field.NewInt32(table, "source")
|
||||
r.Status = field.NewString(table, "status")
|
||||
r.CreatorID = field.NewInt64(table, "creator_id")
|
||||
r.CreatedAt = field.NewInt64(table, "created_at")
|
||||
r.UpdatedAt = field.NewInt64(table, "updated_at")
|
||||
r.FailedAt = field.NewInt64(table, "failed_at")
|
||||
r.LastError = field.NewString(table, "last_error")
|
||||
r.CompletedAt = field.NewInt64(table, "completed_at")
|
||||
r.ChatRequest = field.NewString(table, "chat_request")
|
||||
r.Ext = field.NewString(table, "ext")
|
||||
r.Usage = field.NewField(table, "usage")
|
||||
|
||||
r.fillFieldMap()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *runRecord) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
_f, ok := r.fieldMap[fieldName]
|
||||
if !ok || _f == nil {
|
||||
return nil, false
|
||||
}
|
||||
_oe, ok := _f.(field.OrderExpr)
|
||||
return _oe, ok
|
||||
}
|
||||
|
||||
func (r *runRecord) fillFieldMap() {
|
||||
r.fieldMap = make(map[string]field.Expr, 16)
|
||||
r.fieldMap["id"] = r.ID
|
||||
r.fieldMap["conversation_id"] = r.ConversationID
|
||||
r.fieldMap["section_id"] = r.SectionID
|
||||
r.fieldMap["agent_id"] = r.AgentID
|
||||
r.fieldMap["user_id"] = r.UserID
|
||||
r.fieldMap["source"] = r.Source
|
||||
r.fieldMap["status"] = r.Status
|
||||
r.fieldMap["creator_id"] = r.CreatorID
|
||||
r.fieldMap["created_at"] = r.CreatedAt
|
||||
r.fieldMap["updated_at"] = r.UpdatedAt
|
||||
r.fieldMap["failed_at"] = r.FailedAt
|
||||
r.fieldMap["last_error"] = r.LastError
|
||||
r.fieldMap["completed_at"] = r.CompletedAt
|
||||
r.fieldMap["chat_request"] = r.ChatRequest
|
||||
r.fieldMap["ext"] = r.Ext
|
||||
r.fieldMap["usage"] = r.Usage
|
||||
}
|
||||
|
||||
func (r runRecord) clone(db *gorm.DB) runRecord {
|
||||
r.runRecordDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r runRecord) replaceDB(db *gorm.DB) runRecord {
|
||||
r.runRecordDo.ReplaceDB(db)
|
||||
return r
|
||||
}
|
||||
|
||||
type runRecordDo struct{ gen.DO }
|
||||
|
||||
type IRunRecordDo interface {
|
||||
gen.SubQuery
|
||||
Debug() IRunRecordDo
|
||||
WithContext(ctx context.Context) IRunRecordDo
|
||||
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
|
||||
ReplaceDB(db *gorm.DB)
|
||||
ReadDB() IRunRecordDo
|
||||
WriteDB() IRunRecordDo
|
||||
As(alias string) gen.Dao
|
||||
Session(config *gorm.Session) IRunRecordDo
|
||||
Columns(cols ...field.Expr) gen.Columns
|
||||
Clauses(conds ...clause.Expression) IRunRecordDo
|
||||
Not(conds ...gen.Condition) IRunRecordDo
|
||||
Or(conds ...gen.Condition) IRunRecordDo
|
||||
Select(conds ...field.Expr) IRunRecordDo
|
||||
Where(conds ...gen.Condition) IRunRecordDo
|
||||
Order(conds ...field.Expr) IRunRecordDo
|
||||
Distinct(cols ...field.Expr) IRunRecordDo
|
||||
Omit(cols ...field.Expr) IRunRecordDo
|
||||
Join(table schema.Tabler, on ...field.Expr) IRunRecordDo
|
||||
LeftJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo
|
||||
RightJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo
|
||||
Group(cols ...field.Expr) IRunRecordDo
|
||||
Having(conds ...gen.Condition) IRunRecordDo
|
||||
Limit(limit int) IRunRecordDo
|
||||
Offset(offset int) IRunRecordDo
|
||||
Count() (count int64, err error)
|
||||
Scopes(funcs ...func(gen.Dao) gen.Dao) IRunRecordDo
|
||||
Unscoped() IRunRecordDo
|
||||
Create(values ...*model.RunRecord) error
|
||||
CreateInBatches(values []*model.RunRecord, batchSize int) error
|
||||
Save(values ...*model.RunRecord) error
|
||||
First() (*model.RunRecord, error)
|
||||
Take() (*model.RunRecord, error)
|
||||
Last() (*model.RunRecord, error)
|
||||
Find() ([]*model.RunRecord, error)
|
||||
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.RunRecord, err error)
|
||||
FindInBatches(result *[]*model.RunRecord, batchSize int, fc func(tx gen.Dao, batch int) error) error
|
||||
Pluck(column field.Expr, dest interface{}) error
|
||||
Delete(...*model.RunRecord) (info gen.ResultInfo, err error)
|
||||
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
Updates(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateFrom(q gen.SubQuery) gen.Dao
|
||||
Attrs(attrs ...field.AssignExpr) IRunRecordDo
|
||||
Assign(attrs ...field.AssignExpr) IRunRecordDo
|
||||
Joins(fields ...field.RelationField) IRunRecordDo
|
||||
Preload(fields ...field.RelationField) IRunRecordDo
|
||||
FirstOrInit() (*model.RunRecord, error)
|
||||
FirstOrCreate() (*model.RunRecord, error)
|
||||
FindByPage(offset int, limit int) (result []*model.RunRecord, count int64, err error)
|
||||
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
|
||||
Scan(result interface{}) (err error)
|
||||
Returning(value interface{}, columns ...string) IRunRecordDo
|
||||
UnderlyingDB() *gorm.DB
|
||||
schema.Tabler
|
||||
}
|
||||
|
||||
func (r runRecordDo) Debug() IRunRecordDo {
|
||||
return r.withDO(r.DO.Debug())
|
||||
}
|
||||
|
||||
func (r runRecordDo) WithContext(ctx context.Context) IRunRecordDo {
|
||||
return r.withDO(r.DO.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (r runRecordDo) ReadDB() IRunRecordDo {
|
||||
return r.Clauses(dbresolver.Read)
|
||||
}
|
||||
|
||||
func (r runRecordDo) WriteDB() IRunRecordDo {
|
||||
return r.Clauses(dbresolver.Write)
|
||||
}
|
||||
|
||||
func (r runRecordDo) Session(config *gorm.Session) IRunRecordDo {
|
||||
return r.withDO(r.DO.Session(config))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Clauses(conds ...clause.Expression) IRunRecordDo {
|
||||
return r.withDO(r.DO.Clauses(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Returning(value interface{}, columns ...string) IRunRecordDo {
|
||||
return r.withDO(r.DO.Returning(value, columns...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Not(conds ...gen.Condition) IRunRecordDo {
|
||||
return r.withDO(r.DO.Not(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Or(conds ...gen.Condition) IRunRecordDo {
|
||||
return r.withDO(r.DO.Or(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Select(conds ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Select(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Where(conds ...gen.Condition) IRunRecordDo {
|
||||
return r.withDO(r.DO.Where(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Order(conds ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Order(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Distinct(cols ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Distinct(cols...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Omit(cols ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Omit(cols...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Join(table schema.Tabler, on ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Join(table, on...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) LeftJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.LeftJoin(table, on...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) RightJoin(table schema.Tabler, on ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.RightJoin(table, on...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Group(cols ...field.Expr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Group(cols...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Having(conds ...gen.Condition) IRunRecordDo {
|
||||
return r.withDO(r.DO.Having(conds...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Limit(limit int) IRunRecordDo {
|
||||
return r.withDO(r.DO.Limit(limit))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Offset(offset int) IRunRecordDo {
|
||||
return r.withDO(r.DO.Offset(offset))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IRunRecordDo {
|
||||
return r.withDO(r.DO.Scopes(funcs...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Unscoped() IRunRecordDo {
|
||||
return r.withDO(r.DO.Unscoped())
|
||||
}
|
||||
|
||||
func (r runRecordDo) Create(values ...*model.RunRecord) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.DO.Create(values)
|
||||
}
|
||||
|
||||
func (r runRecordDo) CreateInBatches(values []*model.RunRecord, batchSize int) error {
|
||||
return r.DO.CreateInBatches(values, batchSize)
|
||||
}
|
||||
|
||||
// Save : !!! underlying implementation is different with GORM
|
||||
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
|
||||
func (r runRecordDo) Save(values ...*model.RunRecord) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.DO.Save(values)
|
||||
}
|
||||
|
||||
func (r runRecordDo) First() (*model.RunRecord, error) {
|
||||
if result, err := r.DO.First(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.RunRecord), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r runRecordDo) Take() (*model.RunRecord, error) {
|
||||
if result, err := r.DO.Take(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.RunRecord), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r runRecordDo) Last() (*model.RunRecord, error) {
|
||||
if result, err := r.DO.Last(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.RunRecord), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r runRecordDo) Find() ([]*model.RunRecord, error) {
|
||||
result, err := r.DO.Find()
|
||||
return result.([]*model.RunRecord), err
|
||||
}
|
||||
|
||||
func (r runRecordDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.RunRecord, err error) {
|
||||
buf := make([]*model.RunRecord, 0, batchSize)
|
||||
err = r.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
|
||||
defer func() { results = append(results, buf...) }()
|
||||
return fc(tx, batch)
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (r runRecordDo) FindInBatches(result *[]*model.RunRecord, batchSize int, fc func(tx gen.Dao, batch int) error) error {
|
||||
return r.DO.FindInBatches(result, batchSize, fc)
|
||||
}
|
||||
|
||||
func (r runRecordDo) Attrs(attrs ...field.AssignExpr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Attrs(attrs...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Assign(attrs ...field.AssignExpr) IRunRecordDo {
|
||||
return r.withDO(r.DO.Assign(attrs...))
|
||||
}
|
||||
|
||||
func (r runRecordDo) Joins(fields ...field.RelationField) IRunRecordDo {
|
||||
for _, _f := range fields {
|
||||
r = *r.withDO(r.DO.Joins(_f))
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r runRecordDo) Preload(fields ...field.RelationField) IRunRecordDo {
|
||||
for _, _f := range fields {
|
||||
r = *r.withDO(r.DO.Preload(_f))
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r runRecordDo) FirstOrInit() (*model.RunRecord, error) {
|
||||
if result, err := r.DO.FirstOrInit(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.RunRecord), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r runRecordDo) FirstOrCreate() (*model.RunRecord, error) {
|
||||
if result, err := r.DO.FirstOrCreate(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.RunRecord), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r runRecordDo) FindByPage(offset int, limit int) (result []*model.RunRecord, count int64, err error) {
|
||||
result, err = r.Offset(offset).Limit(limit).Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if size := len(result); 0 < limit && 0 < size && size < limit {
|
||||
count = int64(size + offset)
|
||||
return
|
||||
}
|
||||
|
||||
count, err = r.Offset(-1).Limit(-1).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (r runRecordDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
|
||||
count, err = r.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = r.Offset(offset).Limit(limit).Scan(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (r runRecordDo) Scan(result interface{}) (err error) {
|
||||
return r.DO.Scan(result)
|
||||
}
|
||||
|
||||
func (r runRecordDo) Delete(models ...*model.RunRecord) (result gen.ResultInfo, err error) {
|
||||
return r.DO.Delete(models)
|
||||
}
|
||||
|
||||
func (r *runRecordDo) withDO(do gen.Dao) *runRecordDo {
|
||||
r.DO = *do.(*gen.DO)
|
||||
return r
|
||||
}
|
||||
78
backend/domain/conversation/agentrun/internal/event.go
Normal file
78
backend/domain/conversation/agentrun/internal/event.go
Normal file
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
}
|
||||
|
||||
func NewEvent() *Event {
|
||||
return &Event{}
|
||||
}
|
||||
|
||||
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkMessageItem: chunkMsgItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkRunItem: chunkRunItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
|
||||
|
||||
return &entity.AgentRunResponse{
|
||||
Event: entity.RunEventStreamDone,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildRunEvent(runEvent, runItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildMessageEvent(runEvent, messageItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
|
||||
resp := e.buildErrEvent(runEvent, err)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildStreamDoneEvent()
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
123
backend/domain/conversation/agentrun/internal/run_process.go
Normal file
123
backend/domain/conversation/agentrun/internal/run_process.go
Normal file
@@ -0,0 +1,123 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type RunProcess struct {
|
||||
event *Event
|
||||
|
||||
RunRecordRepo repository.RunRecordRepo
|
||||
}
|
||||
|
||||
func NewRunProcess(runRecordRepo repository.RunRecordRepo) *RunProcess {
|
||||
return &RunProcess{
|
||||
RunRecordRepo: runRecordRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RunProcess) StepToCreate(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
srRecord.Status = entity.RunStatusCreated
|
||||
r.event.SendRunEvent(entity.RunEventCreated, srRecord, sw)
|
||||
}
|
||||
func (r *RunProcess) StepToInProgress(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) error {
|
||||
srRecord.Status = entity.RunStatusInProgress
|
||||
|
||||
updateMeta := &entity.UpdateMeta{
|
||||
Status: entity.RunStatusInProgress,
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.event.SendRunEvent(entity.RunEventInProgress, srRecord, sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RunProcess) StepToComplete(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *agentrun.Usage) {
|
||||
|
||||
completedAt := time.Now().UnixMilli()
|
||||
|
||||
updateMeta := &entity.UpdateMeta{
|
||||
Status: entity.RunStatusCompleted,
|
||||
Usage: usage,
|
||||
CompletedAt: completedAt,
|
||||
UpdatedAt: completedAt,
|
||||
}
|
||||
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "RunRecordRepo.UpdateByID error: %v", err)
|
||||
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
|
||||
Code: errno.ErrConversationAgentRunError,
|
||||
Msg: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
srRecord.CompletedAt = completedAt
|
||||
srRecord.Status = entity.RunStatusCompleted
|
||||
|
||||
r.event.SendRunEvent(entity.RunEventCompleted, srRecord, sw)
|
||||
|
||||
r.event.SendStreamDoneEvent(sw)
|
||||
}
|
||||
func (r *RunProcess) StepToFailed(ctx context.Context, srRecord *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
|
||||
nowTime := time.Now().UnixMilli()
|
||||
updateMeta := &entity.UpdateMeta{
|
||||
Status: entity.RunStatusFailed,
|
||||
UpdatedAt: nowTime,
|
||||
FailedAt: nowTime,
|
||||
LastError: srRecord.Error,
|
||||
}
|
||||
|
||||
err := r.RunRecordRepo.UpdateByID(ctx, srRecord.ID, updateMeta)
|
||||
|
||||
if err != nil {
|
||||
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
|
||||
Code: errno.ErrConversationAgentRunError,
|
||||
Msg: err.Error(),
|
||||
})
|
||||
logs.CtxErrorf(ctx, "update run record failed, err: %v", err)
|
||||
return
|
||||
}
|
||||
srRecord.Status = entity.RunStatusFailed
|
||||
srRecord.FailedAt = time.Now().UnixMilli()
|
||||
r.event.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
|
||||
Code: srRecord.Error.Code,
|
||||
Msg: srRecord.Error.Msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (r *RunProcess) StepToDone(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
r.event.SendStreamDoneEvent(sw)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func NewRunRecordRepo(db *gorm.DB, idGen idgen.IDGenerator) RunRecordRepo {
|
||||
|
||||
return dal.NewRunRecordDAO(db, idGen)
|
||||
}
|
||||
|
||||
type RunRecordRepo interface {
|
||||
Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.RunRecord, error)
|
||||
Delete(ctx context.Context, id []int64) error
|
||||
UpdateByID(ctx context.Context, id int64, update *entity.UpdateMeta) error
|
||||
List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error)
|
||||
}
|
||||
31
backend/domain/conversation/agentrun/service/agent_run.go
Normal file
31
backend/domain/conversation/agentrun/service/agent_run.go
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentrun
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
)
|
||||
|
||||
type Run interface {
|
||||
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
|
||||
|
||||
Delete(ctx context.Context, runID []int64) error
|
||||
}
|
||||
994
backend/domain/conversation/agentrun/service/agent_run_impl.go
Normal file
994
backend/domain/conversation/agentrun/service/agent_run_impl.go
Normal file
@@ -0,0 +1,994 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentrun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mohae/deepcopy"
|
||||
|
||||
messageModel "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmessage"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type runImpl struct {
|
||||
Components
|
||||
|
||||
runProcess *internal.RunProcess
|
||||
runEvent *internal.Event
|
||||
}
|
||||
|
||||
type runtimeDependence struct {
|
||||
runID int64
|
||||
agentInfo *singleagent.SingleAgent
|
||||
questionMsgID int64
|
||||
runMeta *entity.AgentRunMeta
|
||||
startTime time.Time
|
||||
|
||||
usage *agentrun.Usage
|
||||
}
|
||||
|
||||
type Components struct {
|
||||
RunRecordRepo repository.RunRecordRepo
|
||||
}
|
||||
|
||||
func NewService(c *Components) Run {
|
||||
return &runImpl{
|
||||
Components: *c,
|
||||
runEvent: internal.NewEvent(),
|
||||
runProcess: internal.NewRunProcess(c.RunRecordRepo),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) AgentRun(ctx context.Context, arm *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error) {
|
||||
sr, sw := schema.Pipe[*entity.AgentRunResponse](20)
|
||||
|
||||
defer func() {
|
||||
if pe := recover(); pe != nil {
|
||||
logs.CtxErrorf(ctx, "panic recover: %v\n, [stack]:%v", pe, string(debug.Stack()))
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
rtDependence := &runtimeDependence{
|
||||
runMeta: arm,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer sw.Close()
|
||||
_ = c.run(ctx, sw, rtDependence)
|
||||
})
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (c *runImpl) run(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (err error) {
|
||||
runRecord, err := c.createRunRecord(ctx, sw, rtDependence)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rtDependence.runID = runRecord.ID
|
||||
defer func() {
|
||||
srRecord := c.buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted)
|
||||
if err != nil {
|
||||
srRecord.Error = &entity.RunError{
|
||||
Code: errno.ErrConversationAgentRunError,
|
||||
Msg: err.Error(),
|
||||
}
|
||||
c.runProcess.StepToFailed(ctx, srRecord, sw)
|
||||
return
|
||||
}
|
||||
c.runProcess.StepToComplete(ctx, srRecord, sw, rtDependence.usage)
|
||||
}()
|
||||
|
||||
agentInfo, err := c.handlerAgent(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rtDependence.agentInfo = agentInfo
|
||||
|
||||
history, err := c.handlerHistory(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
input, err := c.handlerInput(ctx, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rtDependence.questionMsgID = input.ID
|
||||
|
||||
err = c.handlerStreamExecute(ctx, sw, history, input, rtDependence)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerAgent(ctx context.Context, rtDependence *runtimeDependence) (*singleagent.SingleAgent, error) {
|
||||
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
|
||||
AgentID: rtDependence.runMeta.AgentID,
|
||||
IsDraft: rtDependence.runMeta.IsDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agentInfo, nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerStreamExecute(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], historyMsg []*msgEntity.Message, input *msgEntity.Message, rtDependence *runtimeDependence) (err error) {
|
||||
mainChan := make(chan *entity.AgentRespEvent, 100)
|
||||
|
||||
ar := &singleagent.AgentRuntime{
|
||||
AgentVersion: rtDependence.runMeta.Version,
|
||||
SpaceID: rtDependence.runMeta.SpaceID,
|
||||
IsDraft: rtDependence.runMeta.IsDraft,
|
||||
ConnectorID: rtDependence.runMeta.ConnectorID,
|
||||
PreRetrieveTools: rtDependence.runMeta.PreRetrieveTools,
|
||||
}
|
||||
|
||||
streamer, err := crossagent.DefaultSVC().StreamExecute(ctx, historyMsg, input, ar)
|
||||
if err != nil {
|
||||
return errors.New(errorx.ErrorWithoutStack(err))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
safego.Go(ctx, func() {
|
||||
defer wg.Done()
|
||||
c.pull(ctx, mainChan, streamer)
|
||||
})
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer wg.Done()
|
||||
c.push(ctx, mainChan, sw, rtDependence)
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func transformEventMap(eventType singleagent.EventType) (message.MessageType, error) {
|
||||
var eType message.MessageType
|
||||
switch eventType {
|
||||
case singleagent.EventTypeOfFuncCall:
|
||||
return message.MessageTypeFunctionCall, nil
|
||||
case singleagent.EventTypeOfKnowledge:
|
||||
return message.MessageTypeKnowledge, nil
|
||||
case singleagent.EventTypeOfToolsMessage:
|
||||
return message.MessageTypeToolResponse, nil
|
||||
case singleagent.EventTypeOfChatModelAnswer:
|
||||
return message.MessageTypeAnswer, nil
|
||||
case singleagent.EventTypeOfSuggest:
|
||||
return message.MessageTypeFlowUp, nil
|
||||
case singleagent.EventTypeOfInterrupt:
|
||||
return message.MessageTypeInterrupt, nil
|
||||
}
|
||||
return eType, errorx.New(errno.ErrReplyUnknowEventType)
|
||||
}
|
||||
|
||||
func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *runtimeDependence) *message.Message {
|
||||
arm := rtDependence.runMeta
|
||||
msg := &msgEntity.Message{
|
||||
ConversationID: arm.ConversationID,
|
||||
RunID: rtDependence.runID,
|
||||
AgentID: arm.AgentID,
|
||||
SectionID: arm.SectionID,
|
||||
UserID: arm.UserID,
|
||||
MessageType: messageType,
|
||||
}
|
||||
buildExt := map[string]string{}
|
||||
|
||||
timeCost := fmt.Sprintf("%.1f", float64(time.Since(rtDependence.startTime).Milliseconds())/1000.00)
|
||||
|
||||
switch messageType {
|
||||
case message.MessageTypeQuestion:
|
||||
msg.Role = schema.User
|
||||
msg.ContentType = arm.ContentType
|
||||
for _, content := range arm.Content {
|
||||
if content.Type == message.InputTypeText {
|
||||
msg.Content = content.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
msg.MultiContent = arm.Content
|
||||
buildExt = arm.Ext
|
||||
|
||||
msg.DisplayContent = arm.DisplayContent
|
||||
case message.MessageTypeAnswer:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
case message.MessageTypeToolResponse:
|
||||
msg.Role = schema.Tool
|
||||
msg.ContentType = message.ContentTypeText
|
||||
msg.Content = chunk.ToolsMessage[0].Content
|
||||
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
modelContent := chunk.ToolsMessage[0]
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
|
||||
case message.MessageTypeKnowledge:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
knowledgeContent := c.buildKnowledge(ctx, arm, chunk)
|
||||
if knowledgeContent != nil {
|
||||
knInfo, err := json.Marshal(knowledgeContent)
|
||||
if err == nil {
|
||||
msg.Content = string(knInfo)
|
||||
}
|
||||
}
|
||||
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
|
||||
modelContent := chunk.Knowledge
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
|
||||
case message.MessageTypeFunctionCall:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
if len(chunk.FuncCall.ToolCalls) > 0 {
|
||||
toolCall := chunk.FuncCall.ToolCalls[0]
|
||||
toolCalling, err := json.Marshal(toolCall)
|
||||
if err == nil {
|
||||
msg.Content = string(toolCalling)
|
||||
}
|
||||
buildExt[string(msgEntity.MessageExtKeyPlugin)] = toolCall.Function.Name
|
||||
buildExt[string(msgEntity.MessageExtKeyToolName)] = toolCall.Function.Name
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
|
||||
modelContent := chunk.FuncCall
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
}
|
||||
case message.MessageTypeFlowUp:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
msg.Content = chunk.Suggest.Content
|
||||
|
||||
case message.MessageTypeVerbose:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
d := &entity.Data{
|
||||
FinishReason: 0,
|
||||
FinData: "",
|
||||
}
|
||||
dByte, _ := json.Marshal(d)
|
||||
afc := &entity.AnswerFinshContent{
|
||||
MsgType: entity.MessageSubTypeGenerateFinish,
|
||||
Data: string(dByte),
|
||||
}
|
||||
afcMarshal, _ := json.Marshal(afc)
|
||||
msg.Content = string(afcMarshal)
|
||||
case message.MessageTypeInterrupt:
|
||||
msg.Role = schema.Assistant
|
||||
msg.MessageType = message.MessageTypeVerbose
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
afc := &entity.AnswerFinshContent{
|
||||
MsgType: entity.MessageSubTypeInterrupt,
|
||||
Data: "",
|
||||
}
|
||||
afcMarshal, _ := json.Marshal(afc)
|
||||
msg.Content = string(afcMarshal)
|
||||
|
||||
// 添加 ext 用于保存到 context_message
|
||||
interruptByte, err := json.Marshal(chunk.Interrupt)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.ExtKeyResumeInfo)] = string(interruptByte)
|
||||
}
|
||||
buildExt[string(msgEntity.ExtKeyToolCallsIDs)] = chunk.Interrupt.ToolCallID
|
||||
rc := &messageModel.RequiredAction{
|
||||
Type: "submit_tool_outputs",
|
||||
SubmitToolOutputs: &messageModel.SubmitToolOutputs{},
|
||||
}
|
||||
msg.RequiredAction = rc
|
||||
rcExtByte, err := json.Marshal(rc)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.ExtKeyRequiresAction)] = string(rcExtByte)
|
||||
}
|
||||
}
|
||||
|
||||
if messageType != message.MessageTypeQuestion {
|
||||
botStateExt := c.buildBotStateExt(arm)
|
||||
bseString, err := json.Marshal(botStateExt)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
|
||||
}
|
||||
}
|
||||
msg.Ext = buildExt
|
||||
return msg
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerHistory(ctx context.Context, rtDependence *runtimeDependence) ([]*msgEntity.Message, error) {
|
||||
|
||||
conversationTurns := entity.ConversationTurnsDefault
|
||||
|
||||
if rtDependence.agentInfo != nil && rtDependence.agentInfo.ModelInfo != nil && rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy != nil && ptr.From(rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound) > 0 {
|
||||
conversationTurns = ptr.From(rtDependence.agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound)
|
||||
}
|
||||
|
||||
runRecordList, err := c.RunRecordRepo.List(ctx, rtDependence.runMeta.ConversationID, rtDependence.runMeta.SectionID, conversationTurns)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(runRecordList) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
runIDS := c.getRunID(runRecordList)
|
||||
|
||||
history, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, rtDependence.runMeta.ConversationID, runIDS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return history, nil
|
||||
}
|
||||
|
||||
func (c *runImpl) getRunID(rr []*model.RunRecord) []int64 {
|
||||
ids := make([]int64, 0, len(rr))
|
||||
for _, c := range rr {
|
||||
ids = append(ids, c.ID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (c *runImpl) createRunRecord(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (*entity.RunRecordMeta, error) {
|
||||
runPoData, err := c.RunRecordRepo.Create(ctx, rtDependence.runMeta)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "RunRecordRepo.Create error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srRecord := c.buildSendRunRecord(ctx, runPoData, entity.RunStatusCreated)
|
||||
|
||||
c.runProcess.StepToCreate(ctx, srRecord, sw)
|
||||
|
||||
err = c.runProcess.StepToInProgress(ctx, srRecord, sw)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "runProcess.StepToInProgress error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runPoData, nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerInput(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) (*msgEntity.Message, error) {
|
||||
msgMeta := c.buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
|
||||
|
||||
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ackErr := c.handlerAckMessage(ctx, cm, sw)
|
||||
if ackErr != nil {
|
||||
return msgMeta, ackErr
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
func (c *runImpl) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, events *schema.StreamReader[*crossagent.AgentEvent]) {
|
||||
defer func() {
|
||||
close(mainChan)
|
||||
}()
|
||||
|
||||
for {
|
||||
rm, re := events.Recv()
|
||||
if re != nil {
|
||||
errChunk := &entity.AgentRespEvent{
|
||||
Err: re,
|
||||
}
|
||||
mainChan <- errChunk
|
||||
return
|
||||
}
|
||||
|
||||
eventType, tErr := transformEventMap(rm.EventType)
|
||||
|
||||
if tErr != nil {
|
||||
errChunk := &entity.AgentRespEvent{
|
||||
Err: tErr,
|
||||
}
|
||||
mainChan <- errChunk
|
||||
return
|
||||
}
|
||||
|
||||
respChunk := &entity.AgentRespEvent{
|
||||
EventType: eventType,
|
||||
ModelAnswer: rm.ChatModelAnswer,
|
||||
ToolsMessage: rm.ToolsMessage,
|
||||
FuncCall: rm.FuncCall,
|
||||
Knowledge: rm.Knowledge,
|
||||
Suggest: rm.Suggest,
|
||||
Interrupt: rm.Interrupt,
|
||||
}
|
||||
|
||||
mainChan <- respChunk
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) {
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "run.push error: %v", err)
|
||||
c.handlerErr(ctx, err, sw)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
chunk, ok := <-mainChan
|
||||
if !ok || chunk == nil {
|
||||
return
|
||||
}
|
||||
logs.CtxInfof(ctx, "hanlder event:%v,err:%v", conv.DebugJsonToStr(chunk), chunk.Err)
|
||||
if chunk.Err != nil {
|
||||
if errors.Is(chunk.Err, io.EOF) {
|
||||
return
|
||||
}
|
||||
c.handlerErr(ctx, chunk.Err, sw)
|
||||
return
|
||||
}
|
||||
|
||||
switch chunk.EventType {
|
||||
case message.MessageTypeFunctionCall:
|
||||
err = c.handlerFunctionCall(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case message.MessageTypeToolResponse:
|
||||
err = c.handlerTooResponse(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case message.MessageTypeKnowledge:
|
||||
err = c.handlerKnowledge(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case message.MessageTypeAnswer:
|
||||
fullContent := bytes.NewBuffer([]byte{})
|
||||
reasoningContent := bytes.NewBuffer([]byte{})
|
||||
|
||||
var preMsg *msgEntity.Message
|
||||
var usage *msgEntity.UsageExt
|
||||
var createPreMsg = true
|
||||
var isToolCalls = false
|
||||
|
||||
for {
|
||||
streamMsg, receErr := chunk.ModelAnswer.Recv()
|
||||
|
||||
if receErr != nil {
|
||||
if errors.Is(receErr, io.EOF) {
|
||||
|
||||
if isToolCalls && reasoningContent.String() == "" {
|
||||
break
|
||||
}
|
||||
|
||||
finalAnswer := c.buildSendMsg(ctx, preMsg, false, rtDependence)
|
||||
finalAnswer.Content = fullContent.String()
|
||||
finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String())
|
||||
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence)
|
||||
if hfErr != nil {
|
||||
err = hfErr
|
||||
return
|
||||
}
|
||||
finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
|
||||
if finishErr != nil {
|
||||
err = finishErr
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
err = receErr
|
||||
return
|
||||
}
|
||||
|
||||
if streamMsg != nil && len(streamMsg.ToolCalls) > 0 {
|
||||
isToolCalls = true
|
||||
}
|
||||
|
||||
if streamMsg != nil && streamMsg.ResponseMeta != nil {
|
||||
usage = c.handlerUsage(streamMsg.ResponseMeta)
|
||||
}
|
||||
|
||||
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) {
|
||||
preMsg, err = c.handlerPreAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
createPreMsg = false
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, preMsg, false, rtDependence)
|
||||
reasoningContent.WriteString(streamMsg.ReasoningContent)
|
||||
sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
|
||||
|
||||
fullContent.WriteString(streamMsg.Content)
|
||||
sendMsg.Content = streamMsg.Content
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw)
|
||||
}
|
||||
|
||||
case message.MessageTypeFlowUp:
|
||||
err = c.handlerSuggest(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case message.MessageTypeInterrupt:
|
||||
err = c.handlerInterrupt(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
interruptData, cType, err := c.parseInterruptData(ctx, chunk.Interrupt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
preMsg, err := c.handlerPreAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deltaAnswer := &entity.ChunkMessageItem{
|
||||
ID: preMsg.ID,
|
||||
ConversationID: preMsg.ConversationID,
|
||||
SectionID: preMsg.SectionID,
|
||||
RunID: preMsg.RunID,
|
||||
AgentID: preMsg.AgentID,
|
||||
Role: entity.RoleType(preMsg.Role),
|
||||
Content: interruptData,
|
||||
MessageType: preMsg.MessageType,
|
||||
ContentType: cType,
|
||||
ReplyID: preMsg.RunID,
|
||||
Ext: preMsg.Ext,
|
||||
IsFinish: false,
|
||||
}
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw)
|
||||
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
|
||||
|
||||
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.handlerInterruptVerbose(ctx, chunk, sw, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) parseInterruptData(_ context.Context, interruptData *singleagent.InterruptInfo) (string, message.ContentType, error) {
|
||||
|
||||
type msg struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ContentType string `json:"content_type"`
|
||||
Content any `json:"content"` // either optionContent or string
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
defaultContentType := message.ContentTypeText
|
||||
switch interruptData.InterruptType {
|
||||
case singleagent.InterruptEventType_OauthPlugin:
|
||||
data := interruptData.AllToolInterruptData[interruptData.ToolCallID].ToolNeedOAuth.Message
|
||||
return data, defaultContentType, nil
|
||||
case singleagent.InterruptEventType_Question:
|
||||
var iData map[string][]*msg
|
||||
err := json.Unmarshal([]byte(interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData), &iData)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
if len(iData["messages"]) == 0 {
|
||||
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
|
||||
}
|
||||
interruptMsg := iData["messages"][0]
|
||||
|
||||
if interruptMsg.ContentType == "text" {
|
||||
return interruptMsg.Content.(string), defaultContentType, nil
|
||||
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
|
||||
iMarshalData, err := json.Marshal(interruptMsg)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
return string(iMarshalData), message.ContentTypeCard, nil
|
||||
}
|
||||
case singleagent.InterruptEventType_InputNode:
|
||||
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
|
||||
return data, message.ContentTypeCard, nil
|
||||
case singleagent.InterruptEventType_WorkflowLLM:
|
||||
toolInterruptEvent := interruptData.AllWfInterruptData[interruptData.ToolCallID].ToolInterruptEvent
|
||||
data := toolInterruptEvent.InterruptData
|
||||
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_InputNode {
|
||||
return data, message.ContentTypeCard, nil
|
||||
}
|
||||
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_Question {
|
||||
var iData map[string][]*msg
|
||||
err := json.Unmarshal([]byte(data), &iData)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
if len(iData["messages"]) == 0 {
|
||||
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
|
||||
}
|
||||
interruptMsg := iData["messages"][0]
|
||||
|
||||
if interruptMsg.ContentType == "text" {
|
||||
return interruptMsg.Content.(string), defaultContentType, nil
|
||||
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
|
||||
iMarshalData, err := json.Marshal(interruptMsg)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
return string(iMarshalData), message.ContentTypeCard, nil
|
||||
}
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerUsage(meta *schema.ResponseMeta) *msgEntity.UsageExt {
|
||||
if meta == nil || meta.Usage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &msgEntity.UsageExt{
|
||||
TotalCount: int64(meta.Usage.TotalTokens),
|
||||
InputTokens: int64(meta.Usage.PromptTokens),
|
||||
OutputTokens: int64(meta.Usage.CompletionTokens),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
c.runEvent.SendErrEvent(entity.RunEventError, sw, &entity.RunError{
|
||||
Code: errno.ErrConversationAgentRunError,
|
||||
Msg: errorx.ErrorWithoutStack(err),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
|
||||
arm := rtDependence.runMeta
|
||||
msgMeta := &msgEntity.Message{
|
||||
ConversationID: arm.ConversationID,
|
||||
RunID: rtDependence.runID,
|
||||
AgentID: arm.AgentID,
|
||||
SectionID: arm.SectionID,
|
||||
UserID: arm.UserID,
|
||||
Role: schema.Assistant,
|
||||
MessageType: message.MessageTypeAnswer,
|
||||
ContentType: message.ContentTypeText,
|
||||
Ext: arm.Ext,
|
||||
}
|
||||
|
||||
if arm.Ext == nil {
|
||||
msgMeta.Ext = map[string]string{}
|
||||
}
|
||||
|
||||
botStateExt := c.buildBotStateExt(arm)
|
||||
bseString, err := json.Marshal(botStateExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, ok := msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)]; !ok {
|
||||
msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
|
||||
}
|
||||
|
||||
msgMeta.Ext = arm.Ext
|
||||
return crossmessage.DefaultSVC().Create(ctx, msgMeta)
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence) error {
|
||||
|
||||
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg.IsFinish = true
|
||||
|
||||
if msg.Ext == nil {
|
||||
msg.Ext = map[string]string{}
|
||||
}
|
||||
if usage != nil {
|
||||
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
|
||||
|
||||
rtDependence.usage = &agentrun.Usage{
|
||||
LlmPromptTokens: usage.InputTokens,
|
||||
LlmCompletionTokens: usage.OutputTokens,
|
||||
LlmTotalTokens: usage.TotalCount,
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := msg.Ext[string(msgEntity.MessageExtKeyTimeCost)]; !ok {
|
||||
msg.Ext[string(msgEntity.MessageExtKeyTimeCost)] = fmt.Sprintf("%.1f", float64(time.Since(rtDependence.startTime).Milliseconds())/1000.00)
|
||||
}
|
||||
|
||||
buildModelContent := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
mc, err := json.Marshal(buildModelContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
editMsg := &msgEntity.Message{
|
||||
ID: msg.ID,
|
||||
Content: msg.Content,
|
||||
ContentType: msg.ContentType,
|
||||
ModelContent: string(mc),
|
||||
ReasoningContent: ptr.From(msg.ReasoningContent),
|
||||
Ext: msg.Ext,
|
||||
}
|
||||
_, err = crossmessage.DefaultSVC().Edit(ctx, editMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) buildBotStateExt(arm *entity.AgentRunMeta) *msgEntity.BotStateExt {
|
||||
agentID := strconv.FormatInt(arm.AgentID, 10)
|
||||
botStateExt := &msgEntity.BotStateExt{
|
||||
AgentID: agentID,
|
||||
AgentName: arm.Name,
|
||||
Awaiting: agentID,
|
||||
BotID: agentID,
|
||||
}
|
||||
|
||||
return botStateExt
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerAckMessage(_ context.Context, input *msgEntity.Message, sw *schema.StreamWriter[*entity.AgentRunResponse]) error {
|
||||
sendMsg := &entity.ChunkMessageItem{
|
||||
ID: input.ID,
|
||||
ConversationID: input.ConversationID,
|
||||
SectionID: input.SectionID,
|
||||
AgentID: input.AgentID,
|
||||
Role: entity.RoleType(input.Role),
|
||||
MessageType: message.MessageTypeAck,
|
||||
ReplyID: input.ID,
|
||||
Content: input.Content,
|
||||
ContentType: message.ContentTypeText,
|
||||
IsFinish: true,
|
||||
}
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventAck, sendMsg, sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) buildKnowledge(_ context.Context, arm *entity.AgentRunMeta, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
|
||||
var recallDatas []msgEntity.RecallDataInfo
|
||||
for _, kOne := range chunk.Knowledge {
|
||||
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{
|
||||
Slice: kOne.Content,
|
||||
Meta: msgEntity.MetaInfo{
|
||||
Dataset: msgEntity.DatasetInfo{
|
||||
ID: kOne.MetaData["dataset_id"].(string),
|
||||
Name: kOne.MetaData["dataset_name"].(string),
|
||||
},
|
||||
Document: msgEntity.DocumentInfo{
|
||||
ID: kOne.MetaData["document_id"].(string),
|
||||
Name: kOne.MetaData["document_name"].(string),
|
||||
},
|
||||
},
|
||||
Score: kOne.Score(),
|
||||
})
|
||||
}
|
||||
|
||||
verboseData := &msgEntity.VerboseData{
|
||||
Chunks: recallDatas,
|
||||
OriReq: "",
|
||||
StatusCode: 0,
|
||||
}
|
||||
data, err := json.Marshal(verboseData)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
knowledgeInfo := &msgEntity.VerboseInfo{
|
||||
MessageType: string(entity.MessageSubTypeKnowledgeCall),
|
||||
Data: string(data),
|
||||
}
|
||||
return knowledgeInfo
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerFinalAnswerFinish(ctx context.Context, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error {
|
||||
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *runImpl) buildSendMsg(_ context.Context, msg *msgEntity.Message, isFinish bool, rtDependence *runtimeDependence) *entity.ChunkMessageItem {
|
||||
|
||||
copyMap := make(map[string]string)
|
||||
for k, v := range msg.Ext {
|
||||
copyMap[k] = v
|
||||
}
|
||||
|
||||
return &entity.ChunkMessageItem{
|
||||
ID: msg.ID,
|
||||
ConversationID: msg.ConversationID,
|
||||
SectionID: msg.SectionID,
|
||||
AgentID: msg.AgentID,
|
||||
Content: msg.Content,
|
||||
Role: entity.RoleTypeAssistant,
|
||||
ContentType: msg.ContentType,
|
||||
MessageType: msg.MessageType,
|
||||
ReplyID: rtDependence.questionMsgID,
|
||||
Type: msg.MessageType,
|
||||
CreatedAt: msg.CreatedAt,
|
||||
UpdatedAt: msg.UpdatedAt,
|
||||
RunID: rtDependence.runID,
|
||||
Ext: copyMap,
|
||||
IsFinish: isFinish,
|
||||
ReasoningContent: ptr.Of(msg.ReasoningContent),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) buildSendRunRecord(_ context.Context, runRecord *entity.RunRecordMeta, runStatus entity.RunStatus) *entity.ChunkRunItem {
|
||||
return &entity.ChunkRunItem{
|
||||
ID: runRecord.ID,
|
||||
ConversationID: runRecord.ConversationID,
|
||||
AgentID: runRecord.AgentID,
|
||||
SectionID: runRecord.SectionID,
|
||||
Status: runStatus,
|
||||
CreatedAt: runRecord.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *runImpl) Delete(ctx context.Context, runID []int64) error {
|
||||
return c.RunRecordRepo.Delete(ctx, runID)
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentrun
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAgentRun(t *testing.T) {
|
||||
// ctx := context.Background()
|
||||
//
|
||||
// mockDB, err := mysql.New()
|
||||
// assert.Nil(t, err)
|
||||
// cacheCli := redis.New()
|
||||
//
|
||||
// idGen, err := idgen.New(cacheCli)
|
||||
// ctrl := gomock.NewController(t)
|
||||
// idGen := mock.NewMockIDGenerator(ctrl)
|
||||
// // idGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{time.Now().UnixMilli(), time.Now().Add(time.Second).UnixMilli()}, nil).AnyTimes()
|
||||
// idGen.EXPECT().GenID(gomock.Any()).Return(int64(time.Now().UnixMilli()), nil).AnyTimes()
|
||||
//
|
||||
// mockDBGen := orm.NewMockDB()
|
||||
// mockDBGen.AddTable(&model.RunRecord{})
|
||||
// mockDB, err := mockDBGen.DB()
|
||||
//
|
||||
// assert.NoError(t, err)
|
||||
// components := &Components{
|
||||
// DB: mockDB,
|
||||
// IDGen: idGen,
|
||||
// }
|
||||
//
|
||||
// imageInput := &entity.FileData{
|
||||
// Url: "https://xxxxx.xxxx/image",
|
||||
// Name: "test_img",
|
||||
// }
|
||||
// fileInput := &entity.FileData{
|
||||
// Url: "https://xxxxx.xxxx/file",
|
||||
// Name: "test_file",
|
||||
// }
|
||||
// content := []*entity.InputMetaData{
|
||||
// {
|
||||
// Type: entity.InputTypeText,
|
||||
// Text: "你是谁",
|
||||
// },
|
||||
// {
|
||||
// Type: entity.InputTypeImage,
|
||||
// FileData: []*entity.FileData{
|
||||
// imageInput,
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Type: entity.InputTypeFile,
|
||||
// FileData: []*entity.FileData{
|
||||
// fileInput,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// stream, err := NewService(components, nil).AgentRun(ctx, &entity.AgentRunMeta{
|
||||
// ConversationID: 7503546991712960512,
|
||||
// SpaceID: 666,
|
||||
// SectionID: 7503546991712976896,
|
||||
// UserID: 888,
|
||||
// AgentID: 7501996002144944128,
|
||||
// Content: content,
|
||||
// ContentType: entity.ContentTypeMulti,
|
||||
// })
|
||||
// assert.NoError(t, err)
|
||||
// t.Logf("------------stream: %+v; err:%v", stream, err)
|
||||
//
|
||||
// for {
|
||||
// chunk, errRecv := stream.Recv()
|
||||
// jsonStr, _ := json.Marshal(chunk)
|
||||
// fmt.Println(string(jsonStr))
|
||||
// if errRecv == io.EOF || chunk == nil || chunk.Event == entity.RunEventStreamDone {
|
||||
// break
|
||||
// }
|
||||
// if errRecv != nil {
|
||||
// assert.NoError(t, errRecv)
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
|
||||
// assert.NoError(t, err)
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
)
|
||||
|
||||
type Conversation = conversation.Conversation
|
||||
|
||||
type CreateMeta struct {
|
||||
AgentID int64 `json:"agent_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
Ext string `json:"ext"`
|
||||
}
|
||||
|
||||
type NewConversationCtxRequest struct {
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
type NewConversationCtxResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
}
|
||||
|
||||
type GetCurrent = conversation.GetCurrent
|
||||
|
||||
type ListMeta struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
Scene common.Scene `json:"scene"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
Limit int `json:"limit"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
209
backend/domain/conversation/conversation/internal/dal/dao.go
Normal file
209
backend/domain/conversation/conversation/internal/dal/dao.go
Normal file
@@ -0,0 +1,209 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type ConversationDAO struct {
|
||||
idgen idgen.IDGenerator
|
||||
db *gorm.DB
|
||||
query *query.Query
|
||||
}
|
||||
|
||||
func NewConversationDAO(db *gorm.DB, generator idgen.IDGenerator) *ConversationDAO {
|
||||
return &ConversationDAO{
|
||||
idgen: generator,
|
||||
db: db,
|
||||
query: query.Use(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) Create(ctx context.Context, msg *entity.Conversation) (*entity.Conversation, error) {
|
||||
poData := dao.conversationDO2PO(ctx, msg)
|
||||
|
||||
ids, err := dao.idgen.GenMultiIDs(ctx, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
poData.ID = ids[0]
|
||||
poData.SectionID = ids[1]
|
||||
|
||||
err = dao.query.Conversation.WithContext(ctx).Create(poData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.conversationPO2DO(ctx, poData), nil
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
|
||||
poData, err := dao.query.Conversation.WithContext(ctx).Debug().Where(dao.query.Conversation.ID.Eq(id)).First()
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.conversationPO2DO(ctx, poData), nil
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) UpdateSection(ctx context.Context, id int64) (int64, error) {
|
||||
updateColumn := make(map[string]interface{})
|
||||
table := dao.query.Conversation
|
||||
newSectionID, err := dao.idgen.GenID(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
updateColumn[table.SectionID.ColumnName().String()] = newSectionID
|
||||
updateColumn[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
|
||||
_, err = dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(id)).UpdateColumns(updateColumn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return newSectionID, nil
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) Delete(ctx context.Context, id int64) (int64, error) {
|
||||
table := dao.query.Conversation
|
||||
|
||||
updateColumn := make(map[string]interface{})
|
||||
updateColumn[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
updateColumn[table.Status.ColumnName().String()] = conversation.ConversationStatusDeleted
|
||||
|
||||
updateRes, err := dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(id)).UpdateColumns(updateColumn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return updateRes.RowsAffected, err
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error) {
|
||||
po, err := dao.query.Conversation.WithContext(ctx).Debug().
|
||||
Where(dao.query.Conversation.CreatorID.Eq(userID)).
|
||||
Where(dao.query.Conversation.AgentID.Eq(agentID)).
|
||||
Where(dao.query.Conversation.Scene.Eq(scene)).
|
||||
Where(dao.query.Conversation.ConnectorID.Eq(connectorID)).
|
||||
Where(dao.query.Conversation.Status.Eq(int32(conversation.ConversationStatusNormal))).
|
||||
Order(dao.query.Conversation.CreatedAt.Desc()).
|
||||
First()
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.conversationPO2DO(ctx, po), nil
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error) {
|
||||
var hasMore bool
|
||||
|
||||
do := dao.query.Conversation.WithContext(ctx).Debug()
|
||||
do = do.Where(dao.query.Conversation.CreatorID.Eq(userID)).
|
||||
Where(dao.query.Conversation.AgentID.Eq(agentID)).
|
||||
Where(dao.query.Conversation.Scene.Eq(scene)).
|
||||
Where(dao.query.Conversation.ConnectorID.Eq(connectorID))
|
||||
|
||||
do = do.Offset((page - 1) * limit)
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
}
|
||||
|
||||
poList, err := do.Find()
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, hasMore, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, hasMore, err
|
||||
}
|
||||
|
||||
if len(poList) == 0 {
|
||||
return nil, hasMore, nil
|
||||
}
|
||||
if len(poList) > limit {
|
||||
hasMore = true
|
||||
return dao.conversationBatchPO2DO(ctx, poList[:(len(poList)-1)]), hasMore, nil
|
||||
|
||||
}
|
||||
return dao.conversationBatchPO2DO(ctx, poList), hasMore, nil
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) conversationDO2PO(ctx context.Context, conversation *entity.Conversation) *model.Conversation {
|
||||
return &model.Conversation{
|
||||
ID: conversation.ID,
|
||||
SectionID: conversation.SectionID,
|
||||
ConnectorID: conversation.ConnectorID,
|
||||
AgentID: conversation.AgentID,
|
||||
CreatorID: conversation.CreatorID,
|
||||
Scene: int32(conversation.Scene),
|
||||
Status: int32(conversation.Status),
|
||||
Ext: conversation.Ext,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) conversationPO2DO(ctx context.Context, c *model.Conversation) *entity.Conversation {
|
||||
return &entity.Conversation{
|
||||
ID: c.ID,
|
||||
SectionID: c.SectionID,
|
||||
ConnectorID: c.ConnectorID,
|
||||
AgentID: c.AgentID,
|
||||
CreatorID: c.CreatorID,
|
||||
Scene: common.Scene(c.Scene),
|
||||
Status: conversation.ConversationStatus(c.Status),
|
||||
Ext: c.Ext,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) conversationBatchPO2DO(ctx context.Context, conversations []*model.Conversation) []*entity.Conversation {
|
||||
return slices.Transform(conversations, func(c *model.Conversation) *entity.Conversation {
|
||||
return &entity.Conversation{
|
||||
ID: c.ID,
|
||||
SectionID: c.SectionID,
|
||||
ConnectorID: c.ConnectorID,
|
||||
AgentID: c.AgentID,
|
||||
CreatorID: c.CreatorID,
|
||||
Scene: common.Scene(c.Scene),
|
||||
Status: conversation.ConversationStatus(c.Status),
|
||||
Ext: c.Ext,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNameConversation = "conversation"
|
||||
|
||||
// Conversation 会话信息表
|
||||
type Conversation struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID
|
||||
ConnectorID int64 `gorm:"column:connector_id;not null;comment:业务线 ID" json:"connector_id"` // 业务线 ID
|
||||
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
|
||||
Scene int32 `gorm:"column:scene;not null;comment:会话场景" json:"scene"` // 会话场景
|
||||
SectionID int64 `gorm:"column:section_id;not null;comment:最新section_id" json:"section_id"` // 最新section_id
|
||||
CreatorID int64 `gorm:"column:creator_id;comment:创建者id" json:"creator_id"` // 创建者id
|
||||
Ext string `gorm:"column:ext;comment:扩展字段" json:"ext"` // 扩展字段
|
||||
Status int32 `gorm:"column:status;not null;default:1;comment:status: 1-normal 2-deleted" json:"status"` // status: 1-normal 2-deleted
|
||||
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
}
|
||||
|
||||
// TableName Conversation's table name
|
||||
func (*Conversation) TableName() string {
|
||||
return TableNameConversation
|
||||
}
|
||||
@@ -0,0 +1,417 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
|
||||
)
|
||||
|
||||
func newConversation(db *gorm.DB, opts ...gen.DOOption) conversation {
|
||||
_conversation := conversation{}
|
||||
|
||||
_conversation.conversationDo.UseDB(db, opts...)
|
||||
_conversation.conversationDo.UseModel(&model.Conversation{})
|
||||
|
||||
tableName := _conversation.conversationDo.TableName()
|
||||
_conversation.ALL = field.NewAsterisk(tableName)
|
||||
_conversation.ID = field.NewInt64(tableName, "id")
|
||||
_conversation.ConnectorID = field.NewInt64(tableName, "connector_id")
|
||||
_conversation.AgentID = field.NewInt64(tableName, "agent_id")
|
||||
_conversation.Scene = field.NewInt32(tableName, "scene")
|
||||
_conversation.SectionID = field.NewInt64(tableName, "section_id")
|
||||
_conversation.CreatorID = field.NewInt64(tableName, "creator_id")
|
||||
_conversation.Ext = field.NewString(tableName, "ext")
|
||||
_conversation.Status = field.NewInt32(tableName, "status")
|
||||
_conversation.CreatedAt = field.NewInt64(tableName, "created_at")
|
||||
_conversation.UpdatedAt = field.NewInt64(tableName, "updated_at")
|
||||
|
||||
_conversation.fillFieldMap()
|
||||
|
||||
return _conversation
|
||||
}
|
||||
|
||||
// conversation 会话信息表
|
||||
type conversation struct {
|
||||
conversationDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int64 // 主键ID
|
||||
ConnectorID field.Int64 // 业务线 ID
|
||||
AgentID field.Int64 // agent_id
|
||||
Scene field.Int32 // 会话场景
|
||||
SectionID field.Int64 // 最新section_id
|
||||
CreatorID field.Int64 // 创建者id
|
||||
Ext field.String // 扩展字段
|
||||
Status field.Int32 // status: 1-normal 2-deleted
|
||||
CreatedAt field.Int64 // 创建时间
|
||||
UpdatedAt field.Int64 // 更新时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
|
||||
func (c conversation) Table(newTableName string) *conversation {
|
||||
c.conversationDo.UseTable(newTableName)
|
||||
return c.updateTableName(newTableName)
|
||||
}
|
||||
|
||||
func (c conversation) As(alias string) *conversation {
|
||||
c.conversationDo.DO = *(c.conversationDo.As(alias).(*gen.DO))
|
||||
return c.updateTableName(alias)
|
||||
}
|
||||
|
||||
func (c *conversation) updateTableName(table string) *conversation {
|
||||
c.ALL = field.NewAsterisk(table)
|
||||
c.ID = field.NewInt64(table, "id")
|
||||
c.ConnectorID = field.NewInt64(table, "connector_id")
|
||||
c.AgentID = field.NewInt64(table, "agent_id")
|
||||
c.Scene = field.NewInt32(table, "scene")
|
||||
c.SectionID = field.NewInt64(table, "section_id")
|
||||
c.CreatorID = field.NewInt64(table, "creator_id")
|
||||
c.Ext = field.NewString(table, "ext")
|
||||
c.Status = field.NewInt32(table, "status")
|
||||
c.CreatedAt = field.NewInt64(table, "created_at")
|
||||
c.UpdatedAt = field.NewInt64(table, "updated_at")
|
||||
|
||||
c.fillFieldMap()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *conversation) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
_f, ok := c.fieldMap[fieldName]
|
||||
if !ok || _f == nil {
|
||||
return nil, false
|
||||
}
|
||||
_oe, ok := _f.(field.OrderExpr)
|
||||
return _oe, ok
|
||||
}
|
||||
|
||||
func (c *conversation) fillFieldMap() {
|
||||
c.fieldMap = make(map[string]field.Expr, 10)
|
||||
c.fieldMap["id"] = c.ID
|
||||
c.fieldMap["connector_id"] = c.ConnectorID
|
||||
c.fieldMap["agent_id"] = c.AgentID
|
||||
c.fieldMap["scene"] = c.Scene
|
||||
c.fieldMap["section_id"] = c.SectionID
|
||||
c.fieldMap["creator_id"] = c.CreatorID
|
||||
c.fieldMap["ext"] = c.Ext
|
||||
c.fieldMap["status"] = c.Status
|
||||
c.fieldMap["created_at"] = c.CreatedAt
|
||||
c.fieldMap["updated_at"] = c.UpdatedAt
|
||||
}
|
||||
|
||||
func (c conversation) clone(db *gorm.DB) conversation {
|
||||
c.conversationDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c conversation) replaceDB(db *gorm.DB) conversation {
|
||||
c.conversationDo.ReplaceDB(db)
|
||||
return c
|
||||
}
|
||||
|
||||
type conversationDo struct{ gen.DO }
|
||||
|
||||
type IConversationDo interface {
|
||||
gen.SubQuery
|
||||
Debug() IConversationDo
|
||||
WithContext(ctx context.Context) IConversationDo
|
||||
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
|
||||
ReplaceDB(db *gorm.DB)
|
||||
ReadDB() IConversationDo
|
||||
WriteDB() IConversationDo
|
||||
As(alias string) gen.Dao
|
||||
Session(config *gorm.Session) IConversationDo
|
||||
Columns(cols ...field.Expr) gen.Columns
|
||||
Clauses(conds ...clause.Expression) IConversationDo
|
||||
Not(conds ...gen.Condition) IConversationDo
|
||||
Or(conds ...gen.Condition) IConversationDo
|
||||
Select(conds ...field.Expr) IConversationDo
|
||||
Where(conds ...gen.Condition) IConversationDo
|
||||
Order(conds ...field.Expr) IConversationDo
|
||||
Distinct(cols ...field.Expr) IConversationDo
|
||||
Omit(cols ...field.Expr) IConversationDo
|
||||
Join(table schema.Tabler, on ...field.Expr) IConversationDo
|
||||
LeftJoin(table schema.Tabler, on ...field.Expr) IConversationDo
|
||||
RightJoin(table schema.Tabler, on ...field.Expr) IConversationDo
|
||||
Group(cols ...field.Expr) IConversationDo
|
||||
Having(conds ...gen.Condition) IConversationDo
|
||||
Limit(limit int) IConversationDo
|
||||
Offset(offset int) IConversationDo
|
||||
Count() (count int64, err error)
|
||||
Scopes(funcs ...func(gen.Dao) gen.Dao) IConversationDo
|
||||
Unscoped() IConversationDo
|
||||
Create(values ...*model.Conversation) error
|
||||
CreateInBatches(values []*model.Conversation, batchSize int) error
|
||||
Save(values ...*model.Conversation) error
|
||||
First() (*model.Conversation, error)
|
||||
Take() (*model.Conversation, error)
|
||||
Last() (*model.Conversation, error)
|
||||
Find() ([]*model.Conversation, error)
|
||||
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Conversation, err error)
|
||||
FindInBatches(result *[]*model.Conversation, batchSize int, fc func(tx gen.Dao, batch int) error) error
|
||||
Pluck(column field.Expr, dest interface{}) error
|
||||
Delete(...*model.Conversation) (info gen.ResultInfo, err error)
|
||||
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
Updates(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateFrom(q gen.SubQuery) gen.Dao
|
||||
Attrs(attrs ...field.AssignExpr) IConversationDo
|
||||
Assign(attrs ...field.AssignExpr) IConversationDo
|
||||
Joins(fields ...field.RelationField) IConversationDo
|
||||
Preload(fields ...field.RelationField) IConversationDo
|
||||
FirstOrInit() (*model.Conversation, error)
|
||||
FirstOrCreate() (*model.Conversation, error)
|
||||
FindByPage(offset int, limit int) (result []*model.Conversation, count int64, err error)
|
||||
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
|
||||
Scan(result interface{}) (err error)
|
||||
Returning(value interface{}, columns ...string) IConversationDo
|
||||
UnderlyingDB() *gorm.DB
|
||||
schema.Tabler
|
||||
}
|
||||
|
||||
func (c conversationDo) Debug() IConversationDo {
|
||||
return c.withDO(c.DO.Debug())
|
||||
}
|
||||
|
||||
func (c conversationDo) WithContext(ctx context.Context) IConversationDo {
|
||||
return c.withDO(c.DO.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (c conversationDo) ReadDB() IConversationDo {
|
||||
return c.Clauses(dbresolver.Read)
|
||||
}
|
||||
|
||||
func (c conversationDo) WriteDB() IConversationDo {
|
||||
return c.Clauses(dbresolver.Write)
|
||||
}
|
||||
|
||||
func (c conversationDo) Session(config *gorm.Session) IConversationDo {
|
||||
return c.withDO(c.DO.Session(config))
|
||||
}
|
||||
|
||||
func (c conversationDo) Clauses(conds ...clause.Expression) IConversationDo {
|
||||
return c.withDO(c.DO.Clauses(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Returning(value interface{}, columns ...string) IConversationDo {
|
||||
return c.withDO(c.DO.Returning(value, columns...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Not(conds ...gen.Condition) IConversationDo {
|
||||
return c.withDO(c.DO.Not(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Or(conds ...gen.Condition) IConversationDo {
|
||||
return c.withDO(c.DO.Or(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Select(conds ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Select(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Where(conds ...gen.Condition) IConversationDo {
|
||||
return c.withDO(c.DO.Where(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Order(conds ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Order(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Distinct(cols ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Distinct(cols...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Omit(cols ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Omit(cols...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Join(table schema.Tabler, on ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Join(table, on...))
|
||||
}
|
||||
|
||||
func (c conversationDo) LeftJoin(table schema.Tabler, on ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.LeftJoin(table, on...))
|
||||
}
|
||||
|
||||
func (c conversationDo) RightJoin(table schema.Tabler, on ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.RightJoin(table, on...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Group(cols ...field.Expr) IConversationDo {
|
||||
return c.withDO(c.DO.Group(cols...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Having(conds ...gen.Condition) IConversationDo {
|
||||
return c.withDO(c.DO.Having(conds...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Limit(limit int) IConversationDo {
|
||||
return c.withDO(c.DO.Limit(limit))
|
||||
}
|
||||
|
||||
func (c conversationDo) Offset(offset int) IConversationDo {
|
||||
return c.withDO(c.DO.Offset(offset))
|
||||
}
|
||||
|
||||
func (c conversationDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IConversationDo {
|
||||
return c.withDO(c.DO.Scopes(funcs...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Unscoped() IConversationDo {
|
||||
return c.withDO(c.DO.Unscoped())
|
||||
}
|
||||
|
||||
func (c conversationDo) Create(values ...*model.Conversation) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.DO.Create(values)
|
||||
}
|
||||
|
||||
func (c conversationDo) CreateInBatches(values []*model.Conversation, batchSize int) error {
|
||||
return c.DO.CreateInBatches(values, batchSize)
|
||||
}
|
||||
|
||||
// Save : !!! underlying implementation is different with GORM
|
||||
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
|
||||
func (c conversationDo) Save(values ...*model.Conversation) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.DO.Save(values)
|
||||
}
|
||||
|
||||
func (c conversationDo) First() (*model.Conversation, error) {
|
||||
if result, err := c.DO.First(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Conversation), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c conversationDo) Take() (*model.Conversation, error) {
|
||||
if result, err := c.DO.Take(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Conversation), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c conversationDo) Last() (*model.Conversation, error) {
|
||||
if result, err := c.DO.Last(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Conversation), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c conversationDo) Find() ([]*model.Conversation, error) {
|
||||
result, err := c.DO.Find()
|
||||
return result.([]*model.Conversation), err
|
||||
}
|
||||
|
||||
func (c conversationDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Conversation, err error) {
|
||||
buf := make([]*model.Conversation, 0, batchSize)
|
||||
err = c.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
|
||||
defer func() { results = append(results, buf...) }()
|
||||
return fc(tx, batch)
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (c conversationDo) FindInBatches(result *[]*model.Conversation, batchSize int, fc func(tx gen.Dao, batch int) error) error {
|
||||
return c.DO.FindInBatches(result, batchSize, fc)
|
||||
}
|
||||
|
||||
func (c conversationDo) Attrs(attrs ...field.AssignExpr) IConversationDo {
|
||||
return c.withDO(c.DO.Attrs(attrs...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Assign(attrs ...field.AssignExpr) IConversationDo {
|
||||
return c.withDO(c.DO.Assign(attrs...))
|
||||
}
|
||||
|
||||
func (c conversationDo) Joins(fields ...field.RelationField) IConversationDo {
|
||||
for _, _f := range fields {
|
||||
c = *c.withDO(c.DO.Joins(_f))
|
||||
}
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c conversationDo) Preload(fields ...field.RelationField) IConversationDo {
|
||||
for _, _f := range fields {
|
||||
c = *c.withDO(c.DO.Preload(_f))
|
||||
}
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c conversationDo) FirstOrInit() (*model.Conversation, error) {
|
||||
if result, err := c.DO.FirstOrInit(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Conversation), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c conversationDo) FirstOrCreate() (*model.Conversation, error) {
|
||||
if result, err := c.DO.FirstOrCreate(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Conversation), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c conversationDo) FindByPage(offset int, limit int) (result []*model.Conversation, count int64, err error) {
|
||||
result, err = c.Offset(offset).Limit(limit).Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if size := len(result); 0 < limit && 0 < size && size < limit {
|
||||
count = int64(size + offset)
|
||||
return
|
||||
}
|
||||
|
||||
count, err = c.Offset(-1).Limit(-1).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (c conversationDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
|
||||
count, err = c.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.Offset(offset).Limit(limit).Scan(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (c conversationDo) Scan(result interface{}) (err error) {
|
||||
return c.DO.Scan(result)
|
||||
}
|
||||
|
||||
func (c conversationDo) Delete(models ...*model.Conversation) (result gen.ResultInfo, err error) {
|
||||
return c.DO.Delete(models)
|
||||
}
|
||||
|
||||
func (c *conversationDo) withDO(do gen.Dao) *conversationDo {
|
||||
c.DO = *do.(*gen.DO)
|
||||
return c
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gorm.io/gen"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
)
|
||||
|
||||
var (
|
||||
Q = new(Query)
|
||||
Conversation *conversation
|
||||
)
|
||||
|
||||
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
|
||||
*Q = *Use(db, opts...)
|
||||
Conversation = &Q.Conversation
|
||||
}
|
||||
|
||||
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Conversation: newConversation(db, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
db *gorm.DB
|
||||
|
||||
Conversation conversation
|
||||
}
|
||||
|
||||
func (q *Query) Available() bool { return q.db != nil }
|
||||
|
||||
func (q *Query) clone(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Conversation: q.Conversation.clone(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) ReadDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
|
||||
}
|
||||
|
||||
func (q *Query) WriteDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
|
||||
}
|
||||
|
||||
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Conversation: q.Conversation.replaceDB(db),
|
||||
}
|
||||
}
|
||||
|
||||
type queryCtx struct {
|
||||
Conversation IConversationDo
|
||||
}
|
||||
|
||||
func (q *Query) WithContext(ctx context.Context) *queryCtx {
|
||||
return &queryCtx{
|
||||
Conversation: q.Conversation.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
|
||||
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
|
||||
}
|
||||
|
||||
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
|
||||
tx := q.db.Begin(opts...)
|
||||
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
|
||||
}
|
||||
|
||||
type QueryTx struct {
|
||||
*Query
|
||||
Error error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Commit() error {
|
||||
return q.db.Commit().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Rollback() error {
|
||||
return q.db.Rollback().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) SavePoint(name string) error {
|
||||
return q.db.SavePoint(name).Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) RollbackTo(name string) error {
|
||||
return q.db.RollbackTo(name).Error
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func NewConversationRepo(db *gorm.DB, idGen idgen.IDGenerator) ConversationRepo {
|
||||
return dal.NewConversationDAO(db, idGen)
|
||||
}
|
||||
|
||||
type ConversationRepo interface {
|
||||
Create(ctx context.Context, msg *entity.Conversation) (*entity.Conversation, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
UpdateSection(ctx context.Context, id int64) (int64, error)
|
||||
Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error)
|
||||
Delete(ctx context.Context, id int64) (int64, error)
|
||||
List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error)
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
)
|
||||
|
||||
type Conversation interface {
|
||||
Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error)
|
||||
GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
|
||||
)
|
||||
|
||||
type conversationImpl struct {
|
||||
Components
|
||||
}
|
||||
type Components struct {
|
||||
ConversationRepo repository.ConversationRepo
|
||||
}
|
||||
|
||||
func NewService(c *Components) Conversation {
|
||||
return &conversationImpl{
|
||||
Components: *c,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conversationImpl) Create(ctx context.Context, req *entity.CreateMeta) (*entity.Conversation, error) {
|
||||
var resp *entity.Conversation
|
||||
|
||||
doData := &entity.Conversation{
|
||||
CreatorID: req.UserID,
|
||||
AgentID: req.AgentID,
|
||||
Scene: req.Scene,
|
||||
ConnectorID: req.ConnectorID,
|
||||
Ext: req.Ext,
|
||||
}
|
||||
|
||||
resp, err := c.ConversationRepo.Create(ctx, doData)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetByID(ctx context.Context, id int64) (*entity.Conversation, error) {
|
||||
resp := &entity.Conversation{}
|
||||
// get conversation
|
||||
resp, err := c.ConversationRepo.GetByID(ctx, id)
|
||||
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) NewConversationCtx(ctx context.Context, req *entity.NewConversationCtxRequest) (*entity.NewConversationCtxResponse, error) {
|
||||
resp := &entity.NewConversationCtxResponse{}
|
||||
|
||||
newSectionID, err := c.ConversationRepo.UpdateSection(ctx, req.ID)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if newSectionID != 0 {
|
||||
resp.ID = req.ID
|
||||
resp.SectionID = newSectionID
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error) {
|
||||
// get conversation
|
||||
conversation, err := c.ConversationRepo.Get(ctx, req.UserID, req.AgentID, int32(req.Scene), req.ConnectorID)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// build data
|
||||
return conversation, nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
_, err := c.ConversationRepo.Delete(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
|
||||
conversationList, hasMore, err := c.ConversationRepo.List(ctx, req.UserID, req.AgentID, req.ConnectorID, int32(req.Scene), req.Limit, req.Page)
|
||||
|
||||
if err != nil {
|
||||
return nil, hasMore, err
|
||||
}
|
||||
|
||||
return conversationList, hasMore, nil
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/conversation"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/conversation/repository"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
|
||||
)
|
||||
|
||||
// Test_NewListMessage tests the NewListMessage function
|
||||
func TestCreateConversation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// mockDB, _ := mysql.New()
|
||||
// redisCli := redis.New()
|
||||
// idGen, err := idgen.New(redisCli)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenMultiIDs(gomock.Any(), 2).Return([]int64{
|
||||
1, 2,
|
||||
}, nil).AnyTimes()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Conversation{})
|
||||
mockDB, err := mockDBGen.DB()
|
||||
|
||||
components := &Components{
|
||||
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
|
||||
}
|
||||
|
||||
createData, err := NewService(components).Create(ctx, &entity.CreateMeta{
|
||||
AgentID: 100000,
|
||||
UserID: 222222,
|
||||
ConnectorID: 100001,
|
||||
Scene: common.Scene_Playground,
|
||||
Ext: "debug ext9999",
|
||||
})
|
||||
assert.NotNil(t, createData)
|
||||
|
||||
t.Logf("create conversation result: %v; err:%v", createData, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "debug ext9999", createData.Ext)
|
||||
}
|
||||
|
||||
func TestGetById(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Conversation{})
|
||||
|
||||
mockDBGen.AddTable(&model.Conversation{}).
|
||||
AddRows(
|
||||
&model.Conversation{
|
||||
ID: 7494574457319587840,
|
||||
AgentID: 8888,
|
||||
SectionID: 100001,
|
||||
ConnectorID: 100001,
|
||||
CreatorID: 1111,
|
||||
Ext: "debug ext1111",
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(time.Now().UnixMilli(), nil).AnyTimes()
|
||||
|
||||
components := &Components{
|
||||
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
|
||||
}
|
||||
|
||||
cd, err := NewService(components).GetByID(ctx, 7494574457319587840)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Logf("conversation result: %v; err:%v", cd, err)
|
||||
|
||||
assert.Equal(t, "debug ext1111", cd.Ext)
|
||||
}
|
||||
|
||||
func TestNewConversationCtx(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(int64(123456), nil).Times(1)
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Conversation{})
|
||||
mockDBGen.AddTable(&model.Conversation{}).
|
||||
AddRows(
|
||||
&model.Conversation{
|
||||
ID: 7494574457319587840,
|
||||
AgentID: 8888,
|
||||
SectionID: 100001,
|
||||
ConnectorID: 100001,
|
||||
CreatorID: 1111,
|
||||
},
|
||||
)
|
||||
mockDB, err := mockDBGen.DB()
|
||||
|
||||
assert.Nil(t, err)
|
||||
components := &Components{
|
||||
ConversationRepo: repository.NewConversationRepo(mockDB, idGen),
|
||||
}
|
||||
res, err := NewService(components).NewConversationCtx(ctx, &entity.NewConversationCtxRequest{
|
||||
ID: 7494574457319587840,
|
||||
})
|
||||
|
||||
t.Logf("conversation result: %v; err:%v", res, err)
|
||||
assert.Equal(t, int64(123456), res.SectionID)
|
||||
}
|
||||
|
||||
func TestConversationImpl_Delete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Conversation{})
|
||||
mockDBGen.AddTable(&model.Conversation{}).
|
||||
AddRows(
|
||||
&model.Conversation{
|
||||
ID: 7494574457319587840,
|
||||
AgentID: 9999,
|
||||
SectionID: 100001,
|
||||
ConnectorID: 100001,
|
||||
CreatorID: 1111,
|
||||
Status: int32(conversation.ConversationStatusNormal),
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.Nil(t, err)
|
||||
components := &Components{
|
||||
ConversationRepo: repository.NewConversationRepo(mockDB, nil),
|
||||
}
|
||||
err = NewService(components).Delete(ctx, 7494574457319587840)
|
||||
t.Logf("delete err:%v", err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
currentConversation, err := NewService(components).GetByID(ctx, 7494574457319587840)
|
||||
|
||||
assert.NotNil(t, currentConversation)
|
||||
|
||||
t.Logf("conversation result: %v; err:%v", currentConversation, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, conversation.ConversationStatusDeleted, currentConversation.Status)
|
||||
}
|
||||
32
backend/domain/conversation/message/entity/const.go
Normal file
32
backend/domain/conversation/message/entity/const.go
Normal file
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type ScrollPageDirection string
|
||||
|
||||
const (
|
||||
ScrollPageDirectionPrev ScrollPageDirection = "up"
|
||||
ScrollPageDirectionNext ScrollPageDirection = "down"
|
||||
)
|
||||
|
||||
type MessageStatus int32
|
||||
|
||||
const (
|
||||
MessageStatusAvailable MessageStatus = 1
|
||||
MessageStatusDeleted MessageStatus = 2
|
||||
MessageStatusBroken MessageStatus = 4
|
||||
)
|
||||
64
backend/domain/conversation/message/entity/knowledge.go
Normal file
64
backend/domain/conversation/message/entity/knowledge.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type VerboseInfo struct {
|
||||
MessageType string `json:"msg_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type VerboseData struct {
|
||||
Chunks []RecallDataInfo `json:"chunks"`
|
||||
OriReq string `json:"ori_req"`
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type RecallDataInfo struct {
|
||||
Slice string `json:"slice"`
|
||||
Score float64 `json:"score"`
|
||||
Meta MetaInfo `json:"meta"`
|
||||
}
|
||||
|
||||
type MetaInfo struct {
|
||||
Dataset DatasetInfo `json:"dataset"`
|
||||
Document DocumentInfo `json:"document"`
|
||||
Link LinkInfo `json:"link"`
|
||||
Card CardInfo `json:"card"`
|
||||
}
|
||||
|
||||
type DatasetInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type DocumentInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
FormatType int32 `json:"format_type"`
|
||||
SourceType int32 `json:"source_type"`
|
||||
}
|
||||
|
||||
type LinkInfo struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type CardInfo struct {
|
||||
Title string `json:"title"`
|
||||
Con string `json:"con"`
|
||||
Index string `json:"index"`
|
||||
}
|
||||
55
backend/domain/conversation/message/entity/message.go
Normal file
55
backend/domain/conversation/message/entity/message.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
|
||||
type Message = message.Message
|
||||
|
||||
type ListMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []*int64 `json:"run_id"`
|
||||
UserID string `json:"user_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
OrderBy *string `json:"order_by"`
|
||||
Limit int `json:"limit"`
|
||||
Cursor int64 `json:"cursor"` // message id
|
||||
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
|
||||
}
|
||||
|
||||
type ListResult struct {
|
||||
Messages []*Message `json:"messages"`
|
||||
PrevCursor int64 `json:"prev_cursor"`
|
||||
NextCursor int64 `json:"next_cursor"`
|
||||
HasMore bool `json:"has_more"`
|
||||
Direction ScrollPageDirection `json:"direction"`
|
||||
}
|
||||
|
||||
type GetByRunIDsRequest struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []int64 `json:"run_id"`
|
||||
}
|
||||
|
||||
type DeleteMeta struct {
|
||||
MessageIDs []int64 `json:"message_ids"`
|
||||
RunIDs []int64 `json:"run_ids"`
|
||||
}
|
||||
|
||||
type BrokenMeta struct {
|
||||
ID int64 `json:"id"`
|
||||
Position *int32 `json:"position"`
|
||||
}
|
||||
55
backend/domain/conversation/message/entity/message_ext.go
Normal file
55
backend/domain/conversation/message/entity/message_ext.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package entity
|
||||
|
||||
type MessageExtKey string
|
||||
|
||||
const (
|
||||
MessageExtKeyInputTokens MessageExtKey = "input_tokens"
|
||||
MessageExtKeyOutputTokens MessageExtKey = "output_tokens"
|
||||
MessageExtKeyToken MessageExtKey = "token"
|
||||
MessageExtKeyPluginStatus MessageExtKey = "plugin_status"
|
||||
MessageExtKeyTimeCost MessageExtKey = "time_cost"
|
||||
MessageExtKeyWorkflowTokens MessageExtKey = "workflow_tokens"
|
||||
MessageExtKeyBotState MessageExtKey = "bot_state"
|
||||
MessageExtKeyPluginRequest MessageExtKey = "plugin_request"
|
||||
MessageExtKeyToolName MessageExtKey = "tool_name"
|
||||
MessageExtKeyPlugin MessageExtKey = "plugin"
|
||||
MessageExtKeyMockHitInfo MessageExtKey = "mock_hit_info"
|
||||
MessageExtKeyMessageTitle MessageExtKey = "message_title"
|
||||
MessageExtKeyStreamPluginRunning MessageExtKey = "stream_plugin_running"
|
||||
MessageExtKeyExecuteDisplayName MessageExtKey = "execute_display_name"
|
||||
MessageExtKeyTaskType MessageExtKey = "task_type"
|
||||
MessageExtKeyCallID MessageExtKey = "call_id"
|
||||
ExtKeyResumeInfo MessageExtKey = "resume_info"
|
||||
ExtKeyBreakPoint MessageExtKey = "break_point"
|
||||
ExtKeyToolCallsIDs MessageExtKey = "tool_calls_ids"
|
||||
ExtKeyRequiresAction MessageExtKey = "requires_action"
|
||||
)
|
||||
|
||||
type BotStateExt struct {
|
||||
BotID string `json:"bot_id"`
|
||||
AgentName string `json:"agent_name"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Awaiting string `json:"awaiting"`
|
||||
}
|
||||
|
||||
type UsageExt struct {
|
||||
TotalCount int64 `json:"total_count"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
}
|
||||
365
backend/domain/conversation/message/internal/dal/message.go
Normal file
365
backend/domain/conversation/message/internal/dal/message.go
Normal file
@@ -0,0 +1,365 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package dal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type MessageDAO struct {
|
||||
query *query.Query
|
||||
idgen idgen.IDGenerator
|
||||
}
|
||||
|
||||
func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
|
||||
return &MessageDAO{
|
||||
query: query.Use(db),
|
||||
idgen: idgen,
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
poData, err := dao.messageDO2PO(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
do := dao.query.Message.WithContext(ctx).Debug()
|
||||
cErr := do.Create(poData)
|
||||
if cErr != nil {
|
||||
return nil, cErr
|
||||
}
|
||||
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if messageType != nil {
|
||||
do = do.Where(m.MessageType.Eq(string(*messageType)))
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
}
|
||||
|
||||
if cursor > 0 {
|
||||
if direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(cursor))
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(cursor))
|
||||
}
|
||||
}
|
||||
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
messageList, err := do.Find()
|
||||
|
||||
var hasMore bool
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, hasMore, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(messageList) > limit {
|
||||
hasMore = true
|
||||
messageList = messageList[:limit]
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(messageList), hasMore, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
|
||||
if orderBy == "DESC" {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
do = do.Order(m.CreatedAt.Asc())
|
||||
}
|
||||
poList, err := do.Find()
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(poList), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
|
||||
m := dao.query.Message
|
||||
columns := dao.buildEditColumns(msg)
|
||||
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return do.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
|
||||
columns := make(map[string]interface{})
|
||||
table := dao.query.Message
|
||||
if msg.Content != "" {
|
||||
columns[table.Content.ColumnName().String()] = msg.Content
|
||||
}
|
||||
if msg.MessageType != "" {
|
||||
columns[table.MessageType.ColumnName().String()] = msg.MessageType
|
||||
}
|
||||
if msg.ContentType != "" {
|
||||
columns[table.ContentType.ColumnName().String()] = msg.ContentType
|
||||
}
|
||||
if len(msg.ReasoningContent) > 0 {
|
||||
columns[table.ReasoningContent.ColumnName().String()] = msg.ReasoningContent
|
||||
}
|
||||
|
||||
if msg.Position > 0 {
|
||||
columns[table.BrokenPosition.ColumnName().String()] = msg.Position
|
||||
}
|
||||
if msg.Status > 0 {
|
||||
columns[table.Status.ColumnName().String()] = msg.Status
|
||||
}
|
||||
|
||||
if len(msg.ModelContent) > 0 {
|
||||
columns[table.ModelContent.ColumnName().String()] = msg.ModelContent
|
||||
}
|
||||
|
||||
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
if msg.Ext != nil {
|
||||
ext, err := sonic.MarshalString(msg.Ext)
|
||||
if err == nil {
|
||||
columns[table.Ext.ColumnName().String()] = ext
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Where(m.ID.Eq(msgID))
|
||||
po, err := do.First()
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dao.messagePO2DO(po), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
|
||||
if len(msgIDs) == 0 && len(runIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateColumns := make(map[string]interface{})
|
||||
updateColumns["status"] = int32(entity.MessageStatusDeleted)
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx)
|
||||
|
||||
if len(runIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(runIDs...))
|
||||
}
|
||||
if len(msgIDs) > 0 {
|
||||
do = do.Where(m.ID.In(msgIDs...))
|
||||
}
|
||||
_, err := do.UpdateColumns(&updateColumns)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
|
||||
id, gErr := dao.idgen.GenID(ctx)
|
||||
if gErr != nil {
|
||||
return nil, gErr
|
||||
}
|
||||
msgPO := &model.Message{
|
||||
ID: id,
|
||||
ConversationID: msgDo.ConversationID,
|
||||
RunID: msgDo.RunID,
|
||||
AgentID: msgDo.AgentID,
|
||||
SectionID: msgDo.SectionID,
|
||||
UserID: msgDo.UserID,
|
||||
Role: string(msgDo.Role),
|
||||
ContentType: string(msgDo.ContentType),
|
||||
MessageType: string(msgDo.MessageType),
|
||||
DisplayContent: msgDo.DisplayContent,
|
||||
Content: msgDo.Content,
|
||||
BrokenPosition: msgDo.Position,
|
||||
Status: int32(entity.MessageStatusAvailable),
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
mc, err := dao.buildModelContent(msgDo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgPO.ModelContent = mc
|
||||
|
||||
ext, err := json.Marshal(msgDo.Ext)
|
||||
if err != nil {
|
||||
return nil, errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
||||
}
|
||||
msgPO.Ext = string(ext)
|
||||
|
||||
return msgPO, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) {
|
||||
modelContent := msgDO.ModelContent
|
||||
if modelContent != "" {
|
||||
return modelContent, nil
|
||||
}
|
||||
|
||||
modelContentObj := &schema.Message{
|
||||
Role: msgDO.Role,
|
||||
Name: msgDO.Name,
|
||||
}
|
||||
if msgDO.Content == "" && len(msgDO.MultiContent) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var multiContent []schema.ChatMessagePart
|
||||
for _, contentData := range msgDO.MultiContent {
|
||||
if contentData.Type == message.InputTypeText {
|
||||
continue
|
||||
}
|
||||
one := schema.ChatMessagePart{}
|
||||
switch contentData.Type {
|
||||
case message.InputTypeImage:
|
||||
one.Type = schema.ChatMessagePartTypeImageURL
|
||||
one.ImageURL = &schema.ChatMessageImageURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeFile:
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.FileURL = &schema.ChatMessageFileURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeVideo:
|
||||
one.Type = schema.ChatMessagePartTypeVideoURL
|
||||
one.VideoURL = &schema.ChatMessageVideoURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
case message.InputTypeAudio:
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.AudioURL = &schema.ChatMessageAudioURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
}
|
||||
}
|
||||
multiContent = append(multiContent, one)
|
||||
}
|
||||
if len(multiContent) > 0 {
|
||||
multiContent = append(multiContent, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: msgDO.Content,
|
||||
})
|
||||
} else {
|
||||
modelContentObj.Content = msgDO.Content
|
||||
}
|
||||
|
||||
modelContentObj.MultiContent = multiContent
|
||||
|
||||
mcObjByte, err := json.Marshal(modelContentObj)
|
||||
if err != nil {
|
||||
return "", errorx.WrapByCode(err, errno.ErrConversationJsonMarshal)
|
||||
}
|
||||
|
||||
return string(mcObjByte), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) batchMessagePO2DO(msgPOs []*model.Message) []*entity.Message {
|
||||
return slices.Transform(msgPOs, func(msgPO *model.Message) *entity.Message {
|
||||
msgDO := &entity.Message{
|
||||
ID: msgPO.ID,
|
||||
AgentID: msgPO.AgentID,
|
||||
ConversationID: msgPO.ConversationID,
|
||||
SectionID: msgPO.SectionID,
|
||||
UserID: msgPO.UserID,
|
||||
RunID: msgPO.RunID,
|
||||
Role: schema.RoleType(msgPO.Role),
|
||||
ContentType: message.ContentType(msgPO.ContentType),
|
||||
MessageType: message.MessageType(msgPO.MessageType),
|
||||
Position: msgPO.BrokenPosition,
|
||||
ModelContent: msgPO.ModelContent,
|
||||
Content: msgPO.Content,
|
||||
Status: message.MessageStatus(msgPO.Status),
|
||||
DisplayContent: msgPO.DisplayContent,
|
||||
CreatedAt: msgPO.CreatedAt,
|
||||
UpdatedAt: msgPO.UpdatedAt,
|
||||
ReasoningContent: msgPO.ReasoningContent,
|
||||
}
|
||||
|
||||
var ext map[string]string
|
||||
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
||||
if err == nil {
|
||||
msgDO.Ext = ext
|
||||
}
|
||||
|
||||
return msgDO
|
||||
})
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) messagePO2DO(msgPO *model.Message) *entity.Message {
|
||||
msgDO := &entity.Message{
|
||||
ID: msgPO.ID,
|
||||
AgentID: msgPO.AgentID,
|
||||
ConversationID: msgPO.ConversationID,
|
||||
SectionID: msgPO.SectionID,
|
||||
UserID: msgPO.UserID,
|
||||
RunID: msgPO.RunID,
|
||||
Role: schema.RoleType(msgPO.Role),
|
||||
ContentType: message.ContentType(msgPO.ContentType),
|
||||
MessageType: message.MessageType(msgPO.MessageType),
|
||||
ModelContent: msgPO.ModelContent,
|
||||
Content: msgPO.Content,
|
||||
DisplayContent: msgPO.DisplayContent,
|
||||
CreatedAt: msgPO.CreatedAt,
|
||||
UpdatedAt: msgPO.UpdatedAt,
|
||||
}
|
||||
|
||||
var ext map[string]string
|
||||
err := json.Unmarshal([]byte(msgPO.Ext), &ext)
|
||||
if err == nil {
|
||||
msgDO.Ext = ext
|
||||
}
|
||||
|
||||
return msgDO
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package model
|
||||
|
||||
const TableNameMessage = "message"
|
||||
|
||||
// Message 消息表
|
||||
type Message struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:主键ID" json:"id"` // 主键ID
|
||||
RunID int64 `gorm:"column:run_id;not null;comment:对应的run_id" json:"run_id"` // 对应的run_id
|
||||
ConversationID int64 `gorm:"column:conversation_id;not null;comment:conversation id" json:"conversation_id"` // conversation id
|
||||
UserID string `gorm:"column:user_id;not null;comment:user id" json:"user_id"` // user id
|
||||
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
|
||||
Role string `gorm:"column:role;not null;comment:角色: user、assistant、system" json:"role"` // 角色: user、assistant、system
|
||||
ContentType string `gorm:"column:content_type;not null;comment:内容类型 1 text" json:"content_type"` // 内容类型 1 text
|
||||
Content string `gorm:"column:content;comment:内容" json:"content"` // 内容
|
||||
MessageType string `gorm:"column:message_type;not null;comment:消息类型:" json:"message_type"` // 消息类型:
|
||||
DisplayContent string `gorm:"column:display_content;comment:展示内容" json:"display_content"` // 展示内容
|
||||
Ext string `gorm:"column:ext;comment:message 扩展字段" json:"ext"` // message 扩展字段
|
||||
SectionID int64 `gorm:"column:section_id;comment:段落id" json:"section_id"` // 段落id
|
||||
BrokenPosition int32 `gorm:"column:broken_position;default:-1;comment:打断位置" json:"broken_position"` // 打断位置
|
||||
Status int32 `gorm:"column:status;not null;comment:消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending" json:"status"` // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
|
||||
ModelContent string `gorm:"column:model_content;comment:模型输入内容" json:"model_content"` // 模型输入内容
|
||||
MetaInfo string `gorm:"column:meta_info;comment:引用、高亮等文本标记信息" json:"meta_info"` // 引用、高亮等文本标记信息
|
||||
ReasoningContent string `gorm:"column:reasoning_content;comment:思考内容" json:"reasoning_content"` // 思考内容
|
||||
CreatedAt int64 `gorm:"column:created_at;not null;autoCreateTime:milli;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt int64 `gorm:"column:updated_at;not null;autoUpdateTime:milli;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
}
|
||||
|
||||
// TableName Message's table name
|
||||
func (*Message) TableName() string {
|
||||
return TableNameMessage
|
||||
}
|
||||
103
backend/domain/conversation/message/internal/dal/query/gen.go
Normal file
103
backend/domain/conversation/message/internal/dal/query/gen.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"gorm.io/gen"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
)
|
||||
|
||||
var (
|
||||
Q = new(Query)
|
||||
Message *message
|
||||
)
|
||||
|
||||
func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
|
||||
*Q = *Use(db, opts...)
|
||||
Message = &Q.Message
|
||||
}
|
||||
|
||||
func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: newMessage(db, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
db *gorm.DB
|
||||
|
||||
Message message
|
||||
}
|
||||
|
||||
func (q *Query) Available() bool { return q.db != nil }
|
||||
|
||||
func (q *Query) clone(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: q.Message.clone(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) ReadDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
|
||||
}
|
||||
|
||||
func (q *Query) WriteDB() *Query {
|
||||
return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
|
||||
}
|
||||
|
||||
func (q *Query) ReplaceDB(db *gorm.DB) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
Message: q.Message.replaceDB(db),
|
||||
}
|
||||
}
|
||||
|
||||
type queryCtx struct {
|
||||
Message IMessageDo
|
||||
}
|
||||
|
||||
func (q *Query) WithContext(ctx context.Context) *queryCtx {
|
||||
return &queryCtx{
|
||||
Message: q.Message.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
|
||||
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
|
||||
}
|
||||
|
||||
func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
|
||||
tx := q.db.Begin(opts...)
|
||||
return &QueryTx{Query: q.clone(tx), Error: tx.Error}
|
||||
}
|
||||
|
||||
type QueryTx struct {
|
||||
*Query
|
||||
Error error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Commit() error {
|
||||
return q.db.Commit().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) Rollback() error {
|
||||
return q.db.Rollback().Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) SavePoint(name string) error {
|
||||
return q.db.SavePoint(name).Error
|
||||
}
|
||||
|
||||
func (q *QueryTx) RollbackTo(name string) error {
|
||||
return q.db.RollbackTo(name).Error
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
)
|
||||
|
||||
func newMessage(db *gorm.DB, opts ...gen.DOOption) message {
|
||||
_message := message{}
|
||||
|
||||
_message.messageDo.UseDB(db, opts...)
|
||||
_message.messageDo.UseModel(&model.Message{})
|
||||
|
||||
tableName := _message.messageDo.TableName()
|
||||
_message.ALL = field.NewAsterisk(tableName)
|
||||
_message.ID = field.NewInt64(tableName, "id")
|
||||
_message.RunID = field.NewInt64(tableName, "run_id")
|
||||
_message.ConversationID = field.NewInt64(tableName, "conversation_id")
|
||||
_message.UserID = field.NewString(tableName, "user_id")
|
||||
_message.AgentID = field.NewInt64(tableName, "agent_id")
|
||||
_message.Role = field.NewString(tableName, "role")
|
||||
_message.ContentType = field.NewString(tableName, "content_type")
|
||||
_message.Content = field.NewString(tableName, "content")
|
||||
_message.MessageType = field.NewString(tableName, "message_type")
|
||||
_message.DisplayContent = field.NewString(tableName, "display_content")
|
||||
_message.Ext = field.NewString(tableName, "ext")
|
||||
_message.SectionID = field.NewInt64(tableName, "section_id")
|
||||
_message.BrokenPosition = field.NewInt32(tableName, "broken_position")
|
||||
_message.Status = field.NewInt32(tableName, "status")
|
||||
_message.ModelContent = field.NewString(tableName, "model_content")
|
||||
_message.MetaInfo = field.NewString(tableName, "meta_info")
|
||||
_message.ReasoningContent = field.NewString(tableName, "reasoning_content")
|
||||
_message.CreatedAt = field.NewInt64(tableName, "created_at")
|
||||
_message.UpdatedAt = field.NewInt64(tableName, "updated_at")
|
||||
|
||||
_message.fillFieldMap()
|
||||
|
||||
return _message
|
||||
}
|
||||
|
||||
// message 消息表
|
||||
type message struct {
|
||||
messageDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int64 // 主键ID
|
||||
RunID field.Int64 // 对应的run_id
|
||||
ConversationID field.Int64 // conversation id
|
||||
UserID field.String // user id
|
||||
AgentID field.Int64 // agent_id
|
||||
Role field.String // 角色: user、assistant、system
|
||||
ContentType field.String // 内容类型 1 text
|
||||
Content field.String // 内容
|
||||
MessageType field.String // 消息类型:
|
||||
DisplayContent field.String // 展示内容
|
||||
Ext field.String // message 扩展字段
|
||||
SectionID field.Int64 // 段落id
|
||||
BrokenPosition field.Int32 // 打断位置
|
||||
Status field.Int32 // 消息状态 1 Available 2 Deleted 3 Replaced 4 Broken 5 Failed 6 Streaming 7 Pending
|
||||
ModelContent field.String // 模型输入内容
|
||||
MetaInfo field.String // 引用、高亮等文本标记信息
|
||||
ReasoningContent field.String // 思考内容
|
||||
CreatedAt field.Int64 // 创建时间
|
||||
UpdatedAt field.Int64 // 更新时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
|
||||
func (m message) Table(newTableName string) *message {
|
||||
m.messageDo.UseTable(newTableName)
|
||||
return m.updateTableName(newTableName)
|
||||
}
|
||||
|
||||
func (m message) As(alias string) *message {
|
||||
m.messageDo.DO = *(m.messageDo.As(alias).(*gen.DO))
|
||||
return m.updateTableName(alias)
|
||||
}
|
||||
|
||||
func (m *message) updateTableName(table string) *message {
|
||||
m.ALL = field.NewAsterisk(table)
|
||||
m.ID = field.NewInt64(table, "id")
|
||||
m.RunID = field.NewInt64(table, "run_id")
|
||||
m.ConversationID = field.NewInt64(table, "conversation_id")
|
||||
m.UserID = field.NewString(table, "user_id")
|
||||
m.AgentID = field.NewInt64(table, "agent_id")
|
||||
m.Role = field.NewString(table, "role")
|
||||
m.ContentType = field.NewString(table, "content_type")
|
||||
m.Content = field.NewString(table, "content")
|
||||
m.MessageType = field.NewString(table, "message_type")
|
||||
m.DisplayContent = field.NewString(table, "display_content")
|
||||
m.Ext = field.NewString(table, "ext")
|
||||
m.SectionID = field.NewInt64(table, "section_id")
|
||||
m.BrokenPosition = field.NewInt32(table, "broken_position")
|
||||
m.Status = field.NewInt32(table, "status")
|
||||
m.ModelContent = field.NewString(table, "model_content")
|
||||
m.MetaInfo = field.NewString(table, "meta_info")
|
||||
m.ReasoningContent = field.NewString(table, "reasoning_content")
|
||||
m.CreatedAt = field.NewInt64(table, "created_at")
|
||||
m.UpdatedAt = field.NewInt64(table, "updated_at")
|
||||
|
||||
m.fillFieldMap()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *message) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
_f, ok := m.fieldMap[fieldName]
|
||||
if !ok || _f == nil {
|
||||
return nil, false
|
||||
}
|
||||
_oe, ok := _f.(field.OrderExpr)
|
||||
return _oe, ok
|
||||
}
|
||||
|
||||
func (m *message) fillFieldMap() {
|
||||
m.fieldMap = make(map[string]field.Expr, 19)
|
||||
m.fieldMap["id"] = m.ID
|
||||
m.fieldMap["run_id"] = m.RunID
|
||||
m.fieldMap["conversation_id"] = m.ConversationID
|
||||
m.fieldMap["user_id"] = m.UserID
|
||||
m.fieldMap["agent_id"] = m.AgentID
|
||||
m.fieldMap["role"] = m.Role
|
||||
m.fieldMap["content_type"] = m.ContentType
|
||||
m.fieldMap["content"] = m.Content
|
||||
m.fieldMap["message_type"] = m.MessageType
|
||||
m.fieldMap["display_content"] = m.DisplayContent
|
||||
m.fieldMap["ext"] = m.Ext
|
||||
m.fieldMap["section_id"] = m.SectionID
|
||||
m.fieldMap["broken_position"] = m.BrokenPosition
|
||||
m.fieldMap["status"] = m.Status
|
||||
m.fieldMap["model_content"] = m.ModelContent
|
||||
m.fieldMap["meta_info"] = m.MetaInfo
|
||||
m.fieldMap["reasoning_content"] = m.ReasoningContent
|
||||
m.fieldMap["created_at"] = m.CreatedAt
|
||||
m.fieldMap["updated_at"] = m.UpdatedAt
|
||||
}
|
||||
|
||||
func (m message) clone(db *gorm.DB) message {
|
||||
m.messageDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m message) replaceDB(db *gorm.DB) message {
|
||||
m.messageDo.ReplaceDB(db)
|
||||
return m
|
||||
}
|
||||
|
||||
type messageDo struct{ gen.DO }
|
||||
|
||||
type IMessageDo interface {
|
||||
gen.SubQuery
|
||||
Debug() IMessageDo
|
||||
WithContext(ctx context.Context) IMessageDo
|
||||
WithResult(fc func(tx gen.Dao)) gen.ResultInfo
|
||||
ReplaceDB(db *gorm.DB)
|
||||
ReadDB() IMessageDo
|
||||
WriteDB() IMessageDo
|
||||
As(alias string) gen.Dao
|
||||
Session(config *gorm.Session) IMessageDo
|
||||
Columns(cols ...field.Expr) gen.Columns
|
||||
Clauses(conds ...clause.Expression) IMessageDo
|
||||
Not(conds ...gen.Condition) IMessageDo
|
||||
Or(conds ...gen.Condition) IMessageDo
|
||||
Select(conds ...field.Expr) IMessageDo
|
||||
Where(conds ...gen.Condition) IMessageDo
|
||||
Order(conds ...field.Expr) IMessageDo
|
||||
Distinct(cols ...field.Expr) IMessageDo
|
||||
Omit(cols ...field.Expr) IMessageDo
|
||||
Join(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo
|
||||
Group(cols ...field.Expr) IMessageDo
|
||||
Having(conds ...gen.Condition) IMessageDo
|
||||
Limit(limit int) IMessageDo
|
||||
Offset(offset int) IMessageDo
|
||||
Count() (count int64, err error)
|
||||
Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo
|
||||
Unscoped() IMessageDo
|
||||
Create(values ...*model.Message) error
|
||||
CreateInBatches(values []*model.Message, batchSize int) error
|
||||
Save(values ...*model.Message) error
|
||||
First() (*model.Message, error)
|
||||
Take() (*model.Message, error)
|
||||
Last() (*model.Message, error)
|
||||
Find() ([]*model.Message, error)
|
||||
FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error)
|
||||
FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error
|
||||
Pluck(column field.Expr, dest interface{}) error
|
||||
Delete(...*model.Message) (info gen.ResultInfo, err error)
|
||||
Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
Updates(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error)
|
||||
UpdateColumns(value interface{}) (info gen.ResultInfo, err error)
|
||||
UpdateFrom(q gen.SubQuery) gen.Dao
|
||||
Attrs(attrs ...field.AssignExpr) IMessageDo
|
||||
Assign(attrs ...field.AssignExpr) IMessageDo
|
||||
Joins(fields ...field.RelationField) IMessageDo
|
||||
Preload(fields ...field.RelationField) IMessageDo
|
||||
FirstOrInit() (*model.Message, error)
|
||||
FirstOrCreate() (*model.Message, error)
|
||||
FindByPage(offset int, limit int) (result []*model.Message, count int64, err error)
|
||||
ScanByPage(result interface{}, offset int, limit int) (count int64, err error)
|
||||
Scan(result interface{}) (err error)
|
||||
Returning(value interface{}, columns ...string) IMessageDo
|
||||
UnderlyingDB() *gorm.DB
|
||||
schema.Tabler
|
||||
}
|
||||
|
||||
func (m messageDo) Debug() IMessageDo {
|
||||
return m.withDO(m.DO.Debug())
|
||||
}
|
||||
|
||||
func (m messageDo) WithContext(ctx context.Context) IMessageDo {
|
||||
return m.withDO(m.DO.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (m messageDo) ReadDB() IMessageDo {
|
||||
return m.Clauses(dbresolver.Read)
|
||||
}
|
||||
|
||||
func (m messageDo) WriteDB() IMessageDo {
|
||||
return m.Clauses(dbresolver.Write)
|
||||
}
|
||||
|
||||
func (m messageDo) Session(config *gorm.Session) IMessageDo {
|
||||
return m.withDO(m.DO.Session(config))
|
||||
}
|
||||
|
||||
func (m messageDo) Clauses(conds ...clause.Expression) IMessageDo {
|
||||
return m.withDO(m.DO.Clauses(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Returning(value interface{}, columns ...string) IMessageDo {
|
||||
return m.withDO(m.DO.Returning(value, columns...))
|
||||
}
|
||||
|
||||
func (m messageDo) Not(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Not(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Or(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Or(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Select(conds ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Select(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Where(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Where(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Order(conds ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Order(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Distinct(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Distinct(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Omit(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Omit(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Join(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Join(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) LeftJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.LeftJoin(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) RightJoin(table schema.Tabler, on ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.RightJoin(table, on...))
|
||||
}
|
||||
|
||||
func (m messageDo) Group(cols ...field.Expr) IMessageDo {
|
||||
return m.withDO(m.DO.Group(cols...))
|
||||
}
|
||||
|
||||
func (m messageDo) Having(conds ...gen.Condition) IMessageDo {
|
||||
return m.withDO(m.DO.Having(conds...))
|
||||
}
|
||||
|
||||
func (m messageDo) Limit(limit int) IMessageDo {
|
||||
return m.withDO(m.DO.Limit(limit))
|
||||
}
|
||||
|
||||
func (m messageDo) Offset(offset int) IMessageDo {
|
||||
return m.withDO(m.DO.Offset(offset))
|
||||
}
|
||||
|
||||
func (m messageDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IMessageDo {
|
||||
return m.withDO(m.DO.Scopes(funcs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Unscoped() IMessageDo {
|
||||
return m.withDO(m.DO.Unscoped())
|
||||
}
|
||||
|
||||
func (m messageDo) Create(values ...*model.Message) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.DO.Create(values)
|
||||
}
|
||||
|
||||
func (m messageDo) CreateInBatches(values []*model.Message, batchSize int) error {
|
||||
return m.DO.CreateInBatches(values, batchSize)
|
||||
}
|
||||
|
||||
// Save : !!! underlying implementation is different with GORM
|
||||
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
|
||||
func (m messageDo) Save(values ...*model.Message) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.DO.Save(values)
|
||||
}
|
||||
|
||||
func (m messageDo) First() (*model.Message, error) {
|
||||
if result, err := m.DO.First(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Take() (*model.Message, error) {
|
||||
if result, err := m.DO.Take(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Last() (*model.Message, error) {
|
||||
if result, err := m.DO.Last(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) Find() ([]*model.Message, error) {
|
||||
result, err := m.DO.Find()
|
||||
return result.([]*model.Message), err
|
||||
}
|
||||
|
||||
func (m messageDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.Message, err error) {
|
||||
buf := make([]*model.Message, 0, batchSize)
|
||||
err = m.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
|
||||
defer func() { results = append(results, buf...) }()
|
||||
return fc(tx, batch)
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (m messageDo) FindInBatches(result *[]*model.Message, batchSize int, fc func(tx gen.Dao, batch int) error) error {
|
||||
return m.DO.FindInBatches(result, batchSize, fc)
|
||||
}
|
||||
|
||||
func (m messageDo) Attrs(attrs ...field.AssignExpr) IMessageDo {
|
||||
return m.withDO(m.DO.Attrs(attrs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Assign(attrs ...field.AssignExpr) IMessageDo {
|
||||
return m.withDO(m.DO.Assign(attrs...))
|
||||
}
|
||||
|
||||
func (m messageDo) Joins(fields ...field.RelationField) IMessageDo {
|
||||
for _, _f := range fields {
|
||||
m = *m.withDO(m.DO.Joins(_f))
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
func (m messageDo) Preload(fields ...field.RelationField) IMessageDo {
|
||||
for _, _f := range fields {
|
||||
m = *m.withDO(m.DO.Preload(_f))
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
func (m messageDo) FirstOrInit() (*model.Message, error) {
|
||||
if result, err := m.DO.FirstOrInit(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) FirstOrCreate() (*model.Message, error) {
|
||||
if result, err := m.DO.FirstOrCreate(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*model.Message), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m messageDo) FindByPage(offset int, limit int) (result []*model.Message, count int64, err error) {
|
||||
result, err = m.Offset(offset).Limit(limit).Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if size := len(result); 0 < limit && 0 < size && size < limit {
|
||||
count = int64(size + offset)
|
||||
return
|
||||
}
|
||||
|
||||
count, err = m.Offset(-1).Limit(-1).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (m messageDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
|
||||
count, err = m.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Offset(offset).Limit(limit).Scan(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (m messageDo) Scan(result interface{}) (err error) {
|
||||
return m.DO.Scan(result)
|
||||
}
|
||||
|
||||
func (m messageDo) Delete(models ...*model.Message) (result gen.ResultInfo, err error) {
|
||||
return m.DO.Delete(models)
|
||||
}
|
||||
|
||||
func (m *messageDo) withDO(do gen.Dao) *messageDo {
|
||||
m.DO = *do.(*gen.DO)
|
||||
return m
|
||||
}
|
||||
42
backend/domain/conversation/message/repository/repository.go
Normal file
42
backend/domain/conversation/message/repository/repository.go
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
|
||||
return dal.NewMessageDAO(db, idGen)
|
||||
}
|
||||
|
||||
type MessageRepo interface {
|
||||
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
List(ctx context.Context, conversationID int64, limit int, cursor int64,
|
||||
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
|
||||
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
|
||||
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
|
||||
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
|
||||
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
|
||||
}
|
||||
33
backend/domain/conversation/message/service/message.go
Normal file
33
backend/domain/conversation/message/service/message.go
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Message, error)
|
||||
Edit(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
Delete(ctx context.Context, req *entity.DeleteMeta) error
|
||||
Broken(ctx context.Context, req *entity.BrokenMeta) error
|
||||
}
|
||||
130
backend/domain/conversation/message/service/message_impl.go
Normal file
130
backend/domain/conversation/message/service/message_impl.go
Normal file
@@ -0,0 +1,130 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
type messageImpl struct {
|
||||
Components
|
||||
}
|
||||
|
||||
type Components struct {
|
||||
MessageRepo repository.MessageRepo
|
||||
}
|
||||
|
||||
func NewService(c *Components) Message {
|
||||
return &messageImpl{
|
||||
Components: *c,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
// create message
|
||||
msg, err := m.MessageRepo.Create(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
resp := &entity.ListResult{}
|
||||
|
||||
// get message with query
|
||||
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp.Direction = req.Direction
|
||||
resp.HasMore = hasMore
|
||||
|
||||
if len(messageList) > 0 {
|
||||
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
|
||||
resp.NextCursor = messageList[0].CreatedAt
|
||||
|
||||
var runIDs []int64
|
||||
for _, m := range messageList {
|
||||
runIDs = append(runIDs, m.RunID)
|
||||
}
|
||||
orderBy := "DESC"
|
||||
if req.OrderBy != nil {
|
||||
orderBy = *req.OrderBy
|
||||
}
|
||||
allMessageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, orderBy)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Messages = allMessageList
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
|
||||
messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return messageList, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) {
|
||||
_, err := m.MessageRepo.Edit(ctx, req.ID, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
|
||||
err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
|
||||
msg, err := m.MessageRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
|
||||
|
||||
_, err := m.MessageRepo.Edit(ctx, req.ID, &message.Message{
|
||||
Status: message.MessageStatusBroken,
|
||||
Position: ptr.From(req.Position),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
247
backend/domain/conversation/message/service/message_test.go
Normal file
247
backend/domain/conversation/message/service/message_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/repository"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
|
||||
)
|
||||
|
||||
// Test_NewListMessage tests the NewListMessage function
|
||||
func TestListMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
resp, err := NewService(components).List(ctx, &entity.ListMeta{
|
||||
ConversationID: 1,
|
||||
Limit: 1,
|
||||
UserID: "1",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, resp.Messages, 0)
|
||||
}
|
||||
|
||||
// Test_NewListMessage tests the NewListMessage function
|
||||
func TestCreateMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1)
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{})
|
||||
mockDB, err := mockDBGen.DB()
|
||||
|
||||
// redisCli := redis.New()
|
||||
// idGen, _ := idgen.New(redisCli)
|
||||
// mockDB, err := mysql.New()
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, idGen),
|
||||
}
|
||||
imageInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/image",
|
||||
Name: "test_img",
|
||||
}
|
||||
fileInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
{
|
||||
Type: message.InputTypeText,
|
||||
Text: "解析图片中的内容",
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeImage,
|
||||
FileData: []*message.FileData{
|
||||
imageInput,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeFile,
|
||||
FileData: []*message.FileData{
|
||||
fileInput,
|
||||
},
|
||||
},
|
||||
}
|
||||
service := NewService(components)
|
||||
insert := &entity.Message{
|
||||
ID: 7498710126354759680,
|
||||
ConversationID: 7496795464885338112,
|
||||
AgentID: 7366055842027922437,
|
||||
UserID: "6666666",
|
||||
RunID: 7498710102375923712,
|
||||
Content: "你是谁?",
|
||||
MultiContent: content,
|
||||
Role: schema.Assistant,
|
||||
MessageType: message.MessageTypeFunctionCall,
|
||||
SectionID: 7496795464897921024,
|
||||
ModelContent: "{\"role\":\"tool\",\"content\":\"tool call\"}",
|
||||
ContentType: message.ContentTypeMix,
|
||||
}
|
||||
resp, err := service.Create(ctx, insert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(7366055842027922437), resp.AgentID)
|
||||
assert.Equal(t, "你是谁?", resp.Content)
|
||||
}
|
||||
|
||||
func TestEditMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 123,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 124,
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
imageInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/image",
|
||||
Name: "test_img",
|
||||
}
|
||||
fileInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
{
|
||||
Type: message.InputTypeText,
|
||||
Text: "解析图片中的内容",
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeImage,
|
||||
FileData: []*message.FileData{
|
||||
imageInput,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: message.InputTypeFile,
|
||||
FileData: []*message.FileData{
|
||||
fileInput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := NewService(components).Edit(ctx, &entity.Message{
|
||||
ID: 2,
|
||||
Content: "test edit message",
|
||||
MultiContent: content,
|
||||
})
|
||||
_ = resp
|
||||
|
||||
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(124), msOne[0].RunID)
|
||||
}
|
||||
|
||||
func TestGetByRunIDs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 123,
|
||||
Content: "test content123",
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 3,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
)
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, resp, 2)
|
||||
}
|
||||
Reference in New Issue
Block a user