feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -43,6 +43,7 @@ type RunRecordMeta struct {
ChatRequest *string `json:"chat_message"`
CompletedAt int64 `json:"completed_at"`
FailedAt int64 `json:"failed_at"`
CreatorID int64 `json:"creator_id"`
}
type ChunkRunItem = RunRecordMeta
@@ -158,3 +159,18 @@ type ModelAnswerEvent struct {
Message *schema.Message
Err error
}
type ListRunRecordMeta struct {
ConversationID int64 `json:"conversation_id"`
AgentID int64 `json:"agent_id"`
SectionID int64 `json:"section_id"`
Limit int32 `json:"limit"`
OrderBy string `json:"order_by"` //desc asc
BeforeID int64 `json:"before_id"`
AfterID int64 `json:"after_id"`
}
type CancelRunMeta struct {
ConversationID int64 `json:"conversation_id"`
RunID int64 `json:"run_id"`
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func getAgentHistoryRounds(agentInfo *singleagent.SingleAgent) int32 {
var conversationTurns int32 = entity.ConversationTurnsDefault
if agentInfo != nil && agentInfo.ModelInfo != nil && agentInfo.ModelInfo.ShortMemoryPolicy != nil && ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound) > 0 {
conversationTurns = ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound)
}
return conversationTurns
}
func getAgentInfo(ctx context.Context, agentID int64, isDraft bool) (*singleagent.SingleAgent, error) {
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
AgentID: agentID,
IsDraft: isDraft,
})
if err != nil {
return nil, err
}
return agentInfo, nil
}

View File

@@ -0,0 +1,215 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"bytes"
"context"
"errors"
"io"
"strconv"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) {
mh := &MesssageEventHanlder{
sw: art.SW,
messageEvent: art.MessageEvent,
}
resumeInfo := parseResumeInfo(ctx, art.GetHistory())
wfID, _ := strconv.ParseInt(art.GetAgentInfo().LayoutInfo.WorkflowId, 10, 64)
if wfID == 0 {
mh.handlerErr(ctx, errorx.New(errno.ErrAgentRunWorkflowNotFound))
return
}
var wfStreamer *schema.StreamReader[*crossworkflow.WorkflowMessage]
executeConfig := crossworkflow.ExecuteConfig{
ID: wfID,
ConnectorID: art.GetRunMeta().ConnectorID,
ConnectorUID: art.GetRunMeta().UserID,
AgentID: ptr.Of(art.GetRunMeta().AgentID),
Mode: crossworkflow.ExecuteModeRelease,
BizType: crossworkflow.BizTypeAgent,
SyncPattern: crossworkflow.SyncPatternStream,
From: crossworkflow.FromLatestVersion,
}
if resumeInfo != nil {
wfStreamer, err = crossworkflow.DefaultSVC().StreamResume(ctx, &crossworkflow.ResumeRequest{
ResumeData: concatWfInput(art),
EventID: resumeInfo.ChatflowInterrupt.InterruptEvent.ID,
ExecuteID: resumeInfo.ChatflowInterrupt.ExecuteID,
}, executeConfig)
} else {
executeConfig.ConversationID = &art.GetRunMeta().ConversationID
executeConfig.SectionID = &art.GetRunMeta().SectionID
executeConfig.InitRoundID = &art.RunRecord.ID
executeConfig.RoundID = &art.RunRecord.ID
executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0]
executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo()))
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{
"USER_INPUT": concatWfInput(art),
})
}
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(1)
safego.Go(ctx, func() {
defer wg.Done()
art.pullWfStream(ctx, wfStreamer, mh)
})
wg.Wait()
return err
}
func concatWfInput(rtDependence *AgentRuntime) string {
var input string
for _, content := range rtDependence.RunMeta.Content {
if content.Type == message.InputTypeText {
input = content.Text + "," + input
} else {
for _, file := range content.FileData {
input += file.Url + ","
}
}
}
return input
}
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) {
fullAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
preAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
return
}
var preMsgIsFinish = false
var lastAnswerMsg *entity.ChunkMessageItem
for {
st, re := events.Recv()
if re != nil {
if errors.Is(re, io.EOF) {
if lastAnswerMsg != nil && usage != nil {
art.SetUsage(&agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
})
_ = mh.handlerWfUsage(ctx, lastAnswerMsg, usage)
}
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
logs.CtxErrorf(ctx, "handlerFinalAnswerFinish error: %v", finishErr)
return
}
return
}
logs.CtxErrorf(ctx, "pullWfStream Recv error: %v", re)
mh.handlerErr(ctx, re)
return
}
if st == nil {
continue
}
if st.StateMessage != nil {
if st.StateMessage.Status == crossworkflow.WorkflowFailed {
mh.handlerErr(ctx, st.StateMessage.LastError)
continue
}
if st.StateMessage.Usage != nil {
usage = &msgEntity.UsageExt{
InputTokens: st.StateMessage.Usage.InputTokens,
OutputTokens: st.StateMessage.Usage.OutputTokens,
TotalCount: st.StateMessage.Usage.InputTokens + st.StateMessage.Usage.OutputTokens,
}
}
if st.StateMessage.InterruptEvent != nil { // interrupt
mh.handlerWfInterruptMsg(ctx, st.StateMessage, art)
continue
}
}
if st.DataMessage == nil {
continue
}
switch st.DataMessage.Type {
case crossworkflow.Answer:
// input node & question node skip
if st.DataMessage != nil && (st.DataMessage.NodeType == crossworkflow.NodeTypeInputReceiver || st.DataMessage.NodeType == crossworkflow.NodeTypeQuestion) {
break
}
if preMsgIsFinish {
preAnswerMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
return
}
preMsgIsFinish = false
}
if st.DataMessage.Content != "" {
fullAnswerContent.WriteString(st.DataMessage.Content)
}
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
sendAnswerMsg.Content = st.DataMessage.Content
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, mh.sw)
if st.DataMessage.Last {
preMsgIsFinish = true
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
sendAnswerMsg.Content = fullAnswerContent.String()
fullAnswerContent.Reset()
hfErr := mh.handlerAnswer(ctx, sendAnswerMsg, usage, art, preAnswerMsg)
if hfErr != nil {
return
}
lastAnswerMsg = sendAnswerMsg
}
}
}
}

View File

@@ -19,6 +19,8 @@ package dal
import (
"context"
"encoding/json"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -27,6 +29,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
@@ -59,8 +62,12 @@ func (dao *RunRecordDAO) Create(ctx context.Context, runMeta *entity.AgentRunMet
return dao.buildPo2Do(createPO), nil
}
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*model.RunRecord, error) {
return dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error) {
po, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
if err != nil {
return nil, err
}
return dao.buildPo2Do(po), nil
}
func (dao *RunRecordDAO) UpdateByID(ctx context.Context, id int64, updateMeta *entity.UpdateMeta) error {
@@ -106,20 +113,40 @@ func (dao *RunRecordDAO) Delete(ctx context.Context, id []int64) error {
return err
}
func (dao *RunRecordDAO) List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error) {
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", conversationID, sectionID, limit)
func (dao *RunRecordDAO) List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", meta.ConversationID, meta.SectionID, meta.Limit)
m := dao.query.RunRecord
do := m.WithContext(ctx).Where(m.ConversationID.Eq(conversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
if sectionID > 0 {
do = do.Where(m.SectionID.Eq(sectionID))
do := m.WithContext(ctx).Where(m.ConversationID.Eq(meta.ConversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
if meta.BeforeID > 0 {
runRecord, err := m.Where(m.ID.Eq(meta.BeforeID)).First()
if err != nil {
return nil, err
}
do = do.Where(m.CreatedAt.Lt(runRecord.CreatedAt))
}
if limit > 0 {
do = do.Limit(int(limit))
if meta.AfterID > 0 {
runRecord, err := m.Where(m.ID.Eq(meta.AfterID)).First()
if err != nil {
return nil, err
}
do = do.Where(m.CreatedAt.Gt(runRecord.CreatedAt))
}
if meta.SectionID > 0 {
do = do.Where(m.SectionID.Eq(meta.SectionID))
}
if meta.Limit > 0 {
do = do.Limit(int(meta.Limit))
}
if strings.ToLower(meta.OrderBy) == "asc" {
do = do.Order(m.CreatedAt.Asc())
} else {
do = do.Order(m.CreatedAt.Desc())
}
runRecords, err := do.Order(m.CreatedAt.Desc()).Find()
return runRecords, err
runRecords, err := do.Find()
return slices.Transform(runRecords, func(item *model.RunRecord) *entity.RunRecordMeta {
return dao.buildPo2Do(item)
}), err
}
func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.AgentRunMeta) (*model.RunRecord, error) {
@@ -135,7 +162,10 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
}
timeNow := time.Now().UnixMilli()
creatorID, err := strconv.ParseInt(runMeta.UserID, 10, 64)
if err != nil {
return nil, err
}
return &model.RunRecord{
ID: runID,
ConversationID: runMeta.ConversationID,
@@ -145,6 +175,7 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
ChatRequest: string(reqOrigin),
UserID: runMeta.UserID,
CreatedAt: timeNow,
CreatorID: creatorID,
}, nil
}
@@ -161,7 +192,21 @@ func (dao *RunRecordDAO) buildPo2Do(po *model.RunRecord) *entity.RunRecordMeta {
CompletedAt: po.CompletedAt,
FailedAt: po.FailedAt,
Usage: po.Usage,
CreatorID: po.CreatorID,
}
return runMeta
}
func (dao *RunRecordDAO) Cancel(ctx context.Context, meta *entity.CancelRunMeta) (*entity.RunRecordMeta, error) {
m := dao.query.RunRecord
_, err := m.WithContext(ctx).Where(m.ID.Eq(meta.RunID)).UpdateColumns(map[string]interface{}{
"updated_at": time.Now().UnixMilli(),
"status": entity.RunEventCancelled,
})
if err != nil {
return nil, err
}
return dao.GetByID(ctx, meta.RunID)
}

View File

@@ -1,78 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
)
type Event struct {
}
func NewEvent() *Event {
return &Event{}
}
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkMessageItem: chunkMsgItem,
}
}
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkRunItem: chunkRunItem,
}
}
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
Error: err,
}
}
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: entity.RunEventStreamDone,
}
}
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildRunEvent(runEvent, runItem)
sw.Send(resp, nil)
}
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildMessageEvent(runEvent, messageItem)
sw.Send(resp, nil)
}
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
resp := e.buildErrEvent(runEvent, err)
sw.Send(resp, nil)
}
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildStreamDoneEvent()
sw.Send(resp, nil)
}

View File

@@ -0,0 +1,512 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/cloudwego/eino/schema"
messageModel "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func buildSendMsg(_ context.Context, msg *msgEntity.Message, isFinish bool, rtDependence *AgentRuntime) *entity.ChunkMessageItem {
copyMap := make(map[string]string)
for k, v := range msg.Ext {
copyMap[k] = v
}
return &entity.ChunkMessageItem{
ID: msg.ID,
ConversationID: msg.ConversationID,
SectionID: msg.SectionID,
AgentID: msg.AgentID,
Content: msg.Content,
Role: entity.RoleTypeAssistant,
ContentType: msg.ContentType,
MessageType: msg.MessageType,
ReplyID: rtDependence.GetQuestionMsgID(),
Type: msg.MessageType,
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
RunID: rtDependence.GetRunRecord().ID,
Ext: copyMap,
IsFinish: isFinish,
ReasoningContent: ptr.Of(msg.ReasoningContent),
}
}
func buildKnowledge(_ context.Context, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
var recallDatas []msgEntity.RecallDataInfo
for _, kOne := range chunk.Knowledge {
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{
Slice: kOne.Content,
Meta: msgEntity.MetaInfo{
Dataset: msgEntity.DatasetInfo{
ID: kOne.MetaData["dataset_id"].(string),
Name: kOne.MetaData["dataset_name"].(string),
},
Document: msgEntity.DocumentInfo{
ID: kOne.MetaData["document_id"].(string),
Name: kOne.MetaData["document_name"].(string),
},
},
Score: kOne.Score(),
})
}
verboseData := &msgEntity.VerboseData{
Chunks: recallDatas,
OriReq: "",
StatusCode: 0,
}
data, err := json.Marshal(verboseData)
if err != nil {
return nil
}
knowledgeInfo := &msgEntity.VerboseInfo{
MessageType: string(entity.MessageSubTypeKnowledgeCall),
Data: string(data),
}
return knowledgeInfo
}
func buildBotStateExt(arm *entity.AgentRunMeta) *msgEntity.BotStateExt {
agentID := strconv.FormatInt(arm.AgentID, 10)
botStateExt := &msgEntity.BotStateExt{
AgentID: agentID,
AgentName: arm.Name,
Awaiting: agentID,
BotID: agentID,
}
return botStateExt
}
type irMsg struct {
Type string `json:"type,omitempty"`
ContentType string `json:"content_type"`
Content any `json:"content"` // either optionContent or string
ID string `json:"id,omitempty"`
}
func parseInterruptData(_ context.Context, interruptData *singleagent.InterruptInfo) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
switch interruptData.InterruptType {
case singleagent.InterruptEventType_OauthPlugin:
data := interruptData.AllToolInterruptData[interruptData.ToolCallID].ToolNeedOAuth.Message
return data, defaultContentType, nil
case singleagent.InterruptEventType_Question:
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
return processQuestionInterruptData(data)
case singleagent.InterruptEventType_InputNode:
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
return processInputNodeInterruptData(data)
case singleagent.InterruptEventType_WorkflowLLM:
toolInterruptEvent := interruptData.AllWfInterruptData[interruptData.ToolCallID].ToolInterruptEvent
data := toolInterruptEvent.InterruptData
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_InputNode {
return processInputNodeInterruptData(data)
}
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_Question {
return processQuestionInterruptData(data)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func processQuestionInterruptData(data string) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
var iData map[string][]*irMsg
err := json.Unmarshal([]byte(data), &iData)
if err != nil {
return "", defaultContentType, err
}
if len(iData["messages"]) == 0 {
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
}
interruptMsg := iData["messages"][0]
if interruptMsg.ContentType == "text" {
return interruptMsg.Content.(string), defaultContentType, nil
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
iMarshalData, err := json.Marshal(interruptMsg)
if err != nil {
return "", defaultContentType, err
}
return string(iMarshalData), message.ContentTypeCard, nil
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func processInputNodeInterruptData(data string) (string, message.ContentType, error) {
return data, message.ContentTypeCard, nil
}
func handlerUsage(meta *schema.ResponseMeta) *msgEntity.UsageExt {
if meta == nil || meta.Usage == nil {
return nil
}
return &msgEntity.UsageExt{
TotalCount: int64(meta.Usage.TotalTokens),
InputTokens: int64(meta.Usage.PromptTokens),
OutputTokens: int64(meta.Usage.CompletionTokens),
}
}
func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
arm := rtDependence.RunMeta
msgMeta := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.RunRecord.ID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
Role: schema.Assistant,
MessageType: message.MessageTypeAnswer,
ContentType: message.ContentTypeText,
Ext: arm.Ext,
}
if arm.Ext == nil {
msgMeta.Ext = map[string]string{}
}
botStateExt := buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err != nil {
return nil, err
}
if _, ok := msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)]; !ok {
msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
msgMeta.Ext = arm.Ext
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
}
func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message {
arm := rtDependence.GetRunMeta()
msg := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.RunRecord.ID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
MessageType: messageType,
}
buildExt := map[string]string{}
timeCost := fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
switch messageType {
case message.MessageTypeQuestion:
msg.Role = schema.User
msg.ContentType = arm.ContentType
for _, content := range arm.Content {
if content.Type == message.InputTypeText {
msg.Content = content.Text
break
}
}
msg.MultiContent = arm.Content
buildExt = arm.Ext
msg.DisplayContent = arm.DisplayContent
case message.MessageTypeAnswer, message.MessageTypeToolAsAnswer:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
case message.MessageTypeToolResponse:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
msg.Content = chunk.ToolsMessage[0].Content
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.ToolsMessage[0]
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeKnowledge:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
knowledgeContent := buildKnowledge(ctx, chunk)
if knowledgeContent != nil {
knInfo, err := json.Marshal(knowledgeContent)
if err == nil {
msg.Content = string(knInfo)
}
}
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.Knowledge
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeFunctionCall:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
if len(chunk.FuncCall.ToolCalls) > 0 {
toolCall := chunk.FuncCall.ToolCalls[0]
toolCalling, err := json.Marshal(toolCall)
if err == nil {
msg.Content = string(toolCalling)
}
buildExt[string(msgEntity.MessageExtKeyPlugin)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyToolName)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.FuncCall
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
}
case message.MessageTypeFlowUp:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
msg.Content = chunk.Suggest.Content
case message.MessageTypeVerbose:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
d := &entity.Data{
FinishReason: 0,
FinData: "",
}
dByte, _ := json.Marshal(d)
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeGenerateFinish,
Data: string(dByte),
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
case message.MessageTypeInterrupt:
msg.Role = schema.Assistant
msg.MessageType = message.MessageTypeVerbose
msg.ContentType = message.ContentTypeText
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeInterrupt,
Data: "",
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
// Add ext to save to context_message
interruptByte, err := json.Marshal(chunk.Interrupt)
if err == nil {
buildExt[string(msgEntity.ExtKeyResumeInfo)] = string(interruptByte)
}
buildExt[string(msgEntity.ExtKeyToolCallsIDs)] = chunk.Interrupt.ToolCallID
rc := &messageModel.RequiredAction{
Type: "submit_tool_outputs",
SubmitToolOutputs: &messageModel.SubmitToolOutputs{},
}
msg.RequiredAction = rc
rcExtByte, err := json.Marshal(rc)
if err == nil {
buildExt[string(msgEntity.ExtKeyRequiresAction)] = string(rcExtByte)
}
}
if messageType != message.MessageTypeQuestion {
botStateExt := buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err == nil {
buildExt[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
}
msg.Ext = buildExt
return msg
}
func handlerWfInterruptEvent(_ context.Context, interruptEventData *crossworkflow.InterruptEvent) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
switch singleagent.InterruptEventType(interruptEventData.EventType) {
case singleagent.InterruptEventType_OauthPlugin:
case singleagent.InterruptEventType_Question:
data := interruptEventData.InterruptData
return processQuestionInterruptData(data)
case singleagent.InterruptEventType_InputNode:
data := interruptEventData.InterruptData
return processInputNodeInterruptData(data)
case singleagent.InterruptEventType_WorkflowLLM:
data := interruptEventData.ToolInterruptEvent.InterruptData
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_InputNode {
return processInputNodeInterruptData(data)
}
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_Question {
return processQuestionInterruptData(data)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func historyPairs(historyMsg []*message.Message) []*message.Message {
fcMsgPairs := make(map[int64][]*message.Message)
for _, one := range historyMsg {
if one.MessageType != message.MessageTypeFunctionCall && one.MessageType != message.MessageTypeToolResponse {
continue
}
if _, ok := fcMsgPairs[one.RunID]; !ok {
fcMsgPairs[one.RunID] = []*message.Message{one}
} else {
fcMsgPairs[one.RunID] = append(fcMsgPairs[one.RunID], one)
}
}
var historyAfterPairs []*message.Message
for _, value := range historyMsg {
if value.MessageType == message.MessageTypeFunctionCall {
if len(fcMsgPairs[value.RunID])%2 == 0 {
historyAfterPairs = append(historyAfterPairs, value)
}
} else {
historyAfterPairs = append(historyAfterPairs, value)
}
}
return historyAfterPairs
}
func transMessageToSchemaMessage(ctx context.Context, msgs []*message.Message, imagexClient imagex.ImageX) []*schema.Message {
schemaMessage := make([]*schema.Message, 0, len(msgs))
for _, msgOne := range msgs {
if msgOne.ModelContent == "" {
continue
}
if msgOne.MessageType == message.MessageTypeVerbose || msgOne.MessageType == message.MessageTypeFlowUp {
continue
}
var sm *schema.Message
err := json.Unmarshal([]byte(msgOne.ModelContent), &sm)
if err != nil {
continue
}
if len(sm.ReasoningContent) > 0 {
sm.ReasoningContent = ""
}
schemaMessage = append(schemaMessage, parseMessageURI(ctx, sm, imagexClient))
}
return schemaMessage
}
func parseMessageURI(ctx context.Context, mcMsg *schema.Message, imagexClient imagex.ImageX) *schema.Message {
if mcMsg.MultiContent == nil {
return mcMsg
}
for k, one := range mcMsg.MultiContent {
switch one.Type {
case schema.ChatMessagePartTypeImageURL:
if one.ImageURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.ImageURL.URI)
if err == nil {
mcMsg.MultiContent[k].ImageURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeFileURL:
if one.FileURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.FileURL.URI)
if err == nil {
mcMsg.MultiContent[k].FileURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeAudioURL:
if one.AudioURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.AudioURL.URI)
if err == nil {
mcMsg.MultiContent[k].AudioURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeVideoURL:
if one.VideoURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.VideoURL.URI)
if err == nil {
mcMsg.MultiContent[k].VideoURL.URL = url.URL
}
}
}
}
return mcMsg
}
func parseResumeInfo(_ context.Context, historyMsg []*message.Message) *crossagent.ResumeInfo {
var resumeInfo *crossagent.ResumeInfo
for i := len(historyMsg) - 1; i >= 0; i-- {
if historyMsg[i].MessageType == message.MessageTypeQuestion {
break
}
if historyMsg[i].MessageType == message.MessageTypeVerbose {
if historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)] != "" {
err := json.Unmarshal([]byte(historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)]), &resumeInfo)
if err != nil {
return nil
}
}
}
}
return resumeInfo
}
func buildSendRunRecord(_ context.Context, runRecord *entity.RunRecordMeta, runStatus entity.RunStatus) *entity.ChunkRunItem {
return &entity.ChunkRunItem{
ID: runRecord.ID,
ConversationID: runRecord.ConversationID,
AgentID: runRecord.AgentID,
SectionID: runRecord.SectionID,
Status: runStatus,
CreatedAt: runRecord.CreatedAt,
}
}

View File

@@ -0,0 +1,428 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/mohae/deepcopy"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Event struct {
}
func NewMessageEvent() *Event {
return &Event{}
}
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkMessageItem: chunkMsgItem,
}
}
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkRunItem: chunkRunItem,
}
}
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
Error: err,
}
}
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: entity.RunEventStreamDone,
}
}
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildRunEvent(runEvent, runItem)
sw.Send(resp, nil)
}
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildMessageEvent(runEvent, messageItem)
sw.Send(resp, nil)
}
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
resp := e.buildErrEvent(runEvent, err)
sw.Send(resp, nil)
}
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildStreamDoneEvent()
sw.Send(resp, nil)
}
type MesssageEventHanlder struct {
messageEvent *Event
sw *schema.StreamWriter[*entity.AgentRunResponse]
}
func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
var errMsg string
var statusErr errorx.StatusError
if errors.As(err, &statusErr) {
errMsg = statusErr.Msg()
} else {
if strings.ToLower(os.Getenv(consts.RunMode)) != "debug" {
errMsg = "Internal Server Error"
} else {
errMsg = errorx.ErrorWithoutStack(err)
}
}
mh.messageEvent.SendErrEvent(entity.RunEventError, mh.sw, &entity.RunError{
Code: errno.ErrAgentRun,
Msg: errMsg,
})
}
func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
sendMsg := &entity.ChunkMessageItem{
ID: input.ID,
ConversationID: input.ConversationID,
SectionID: input.SectionID,
AgentID: input.AgentID,
Role: entity.RoleType(input.Role),
MessageType: message.MessageTypeAck,
ReplyID: input.ID,
Content: input.Content,
ContentType: message.ContentTypeText,
IsFinish: true,
}
mh.messageEvent.SendMsgEvent(entity.RunEventAck, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
var cmData *message.Message
var err error
if preToolResponseMsg != nil {
cm.ID = preToolResponseMsg.ID
cm.CreatedAt = preToolResponseMsg.CreatedAt
cm.UpdatedAt = preToolResponseMsg.UpdatedAt
if len(toolResponseMsgContent) > 0 {
cm.Content = toolResponseMsgContent + "\n" + cm.Content
}
}
cmData, err = crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil
}
msg.IsFinish = true
if msg.Ext == nil {
msg.Ext = map[string]string{}
}
if usage != nil {
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
rtDependence.Usage = &agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
}
}
if _, ok := msg.Ext[string(msgEntity.MessageExtKeyTimeCost)]; !ok {
msg.Ext[string(msgEntity.MessageExtKeyTimeCost)] = fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
}
buildModelContent := &schema.Message{
Role: schema.Assistant,
Content: msg.Content,
}
mc, err := json.Marshal(buildModelContent)
if err != nil {
return err
}
preAnswerMsg.Content = msg.Content
preAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
preAnswerMsg.Ext = msg.Ext
preAnswerMsg.ContentType = msg.ContentType
preAnswerMsg.ModelContent = string(mc)
preAnswerMsg.CreatedAt = 0
preAnswerMsg.UpdatedAt = 0
_, err = crossmessage.DefaultSVC().Create(ctx, preAnswerMsg)
if err != nil {
return err
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
if msg.Ext == nil {
msg.Ext = map[string]string{}
}
if usage != nil {
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
}
_, err := crossmessage.DefaultSVC().Edit(ctx, &msgEntity.Message{
ID: msg.ID,
Ext: msg.Ext,
})
if err != nil {
return err
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt)
if err != nil {
return err
}
preMsg, err := preCreateAnswer(ctx, rtDependence)
if err != nil {
return err
}
deltaAnswer := &entity.ChunkMessageItem{
ID: preMsg.ID,
ConversationID: preMsg.ConversationID,
SectionID: preMsg.SectionID,
RunID: preMsg.RunID,
AgentID: preMsg.AgentID,
Role: entity.RoleType(preMsg.Role),
Content: interruptData,
MessageType: preMsg.MessageType,
ContentType: cType,
ReplyID: preMsg.RunID,
Ext: preMsg.Ext,
IsFinish: false,
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
if len(reasoningContent) > 0 && firstAnswerMsg == nil {
finalAnswer.ReasoningContent = ptr.Of(reasoningContent)
}
usage := func() *msgEntity.UsageExt {
if rtDependence.GetUsage() != nil {
return &msgEntity.UsageExt{
TotalCount: rtDependence.GetUsage().LlmTotalTokens,
InputTokens: rtDependence.GetUsage().LlmPromptTokens,
OutputTokens: rtDependence.GetUsage().LlmCompletionTokens,
}
}
return nil
}
err = mh.handlerAnswer(ctx, finalAnswer, usage(), rtDependence, preMsg)
if err != nil {
return err
}
err = mh.handlerInterruptVerbose(ctx, chunk, rtDependence)
if err != nil {
return err
}
return nil
}
func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent)
if err != nil {
return
}
preMsg, err := preCreateAnswer(ctx, rtDependence)
if err != nil {
return
}
deltaAnswer := &entity.ChunkMessageItem{
ID: preMsg.ID,
ConversationID: preMsg.ConversationID,
SectionID: preMsg.SectionID,
RunID: preMsg.RunID,
AgentID: preMsg.AgentID,
Role: entity.RoleType(preMsg.Role),
Content: interruptData,
MessageType: preMsg.MessageType,
ContentType: cType,
ReplyID: preMsg.RunID,
Ext: preMsg.Ext,
IsFinish: false,
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
err = mh.handlerAnswer(ctx, finalAnswer, nil, rtDependence, preMsg)
if err != nil {
return
}
err = mh.handlerInterruptVerbose(ctx, &entity.AgentRespEvent{
EventType: message.MessageTypeInterrupt,
Interrupt: &singleagent.InterruptInfo{
InterruptType: singleagent.InterruptEventType(stateMsg.InterruptEvent.EventType),
InterruptID: strconv.FormatInt(stateMsg.InterruptEvent.ID, 10),
ChatflowInterrupt: stateMsg,
},
}, rtDependence)
if err != nil {
return
}
}
func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
if err != nil {
return nil, err
}
ackErr := mh.handlerAckMessage(ctx, cm)
if ackErr != nil {
return msgMeta, ackErr
}
return cm, nil
}

View File

@@ -0,0 +1,214 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"time"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type AgentRuntime struct {
RunRecord *entity.RunRecordMeta
AgentInfo *singleagent.SingleAgent
QuestionMsgID int64
RunMeta *entity.AgentRunMeta
StartTime time.Time
Input *msgEntity.Message
HistoryMsg []*msgEntity.Message
Usage *agentrun.Usage
SW *schema.StreamWriter[*entity.AgentRunResponse]
RunProcess *RunProcess
RunRecordRepo repository.RunRecordRepo
ImagexClient imagex.ImageX
MessageEvent *Event
}
func (rd *AgentRuntime) SetRunRecord(runRecord *entity.RunRecordMeta) {
rd.RunRecord = runRecord
}
func (rd *AgentRuntime) GetRunRecord() *entity.RunRecordMeta {
return rd.RunRecord
}
func (rd *AgentRuntime) SetUsage(usage *agentrun.Usage) {
rd.Usage = usage
}
func (rd *AgentRuntime) GetUsage() *agentrun.Usage {
return rd.Usage
}
func (rd *AgentRuntime) SetRunMeta(arm *entity.AgentRunMeta) {
rd.RunMeta = arm
}
func (rd *AgentRuntime) GetRunMeta() *entity.AgentRunMeta {
return rd.RunMeta
}
func (rd *AgentRuntime) SetAgentInfo(agentInfo *singleagent.SingleAgent) {
rd.AgentInfo = agentInfo
}
func (rd *AgentRuntime) GetAgentInfo() *singleagent.SingleAgent {
return rd.AgentInfo
}
func (rd *AgentRuntime) SetQuestionMsgID(msgID int64) {
rd.QuestionMsgID = msgID
}
func (rd *AgentRuntime) GetQuestionMsgID() int64 {
return rd.QuestionMsgID
}
func (rd *AgentRuntime) SetStartTime(t time.Time) {
rd.StartTime = t
}
func (rd *AgentRuntime) GetStartTime() time.Time {
return rd.StartTime
}
func (rd *AgentRuntime) SetInput(input *msgEntity.Message) {
rd.Input = input
}
func (rd *AgentRuntime) GetInput() *msgEntity.Message {
return rd.Input
}
func (rd *AgentRuntime) SetHistoryMsg(histroyMsg []*msgEntity.Message) {
rd.HistoryMsg = histroyMsg
}
func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
return rd.HistoryMsg
}
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft)
if err != nil {
return
}
art.SetAgentInfo(agentInfo)
history, err := art.getHistory(ctx)
if err != nil {
return
}
runRecord, err := art.createRunRecord(ctx)
if err != nil {
return
}
art.SetRunRecord(runRecord)
art.SetHistoryMsg(history)
defer func() {
srRecord := buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted)
if err != nil {
srRecord.Error = &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: err.Error(),
}
art.RunProcess.StepToFailed(ctx, srRecord, art.SW)
return
}
art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage())
}()
mh := &MesssageEventHanlder{
messageEvent: art.MessageEvent,
sw: art.SW,
}
input, err := mh.HandlerInput(ctx, art)
if err != nil {
return
}
art.SetInput(input)
art.SetQuestionMsgID(input.ID)
if art.GetAgentInfo().BotMode == bot_common.BotMode_WorkflowMode {
err = art.ChatflowRun(ctx, art.ImagexClient)
} else {
err = art.AgentStreamExecute(ctx, art.ImagexClient)
}
return
}
func (art *AgentRuntime) getHistory(ctx context.Context) ([]*msgEntity.Message, error) {
conversationTurns := getAgentHistoryRounds(art.GetAgentInfo())
runRecordList, err := art.RunRecordRepo.List(ctx, &entity.ListRunRecordMeta{
ConversationID: art.GetRunMeta().ConversationID,
SectionID: art.GetRunMeta().SectionID,
Limit: conversationTurns,
})
if err != nil {
return nil, err
}
if len(runRecordList) == 0 {
return nil, nil
}
runIDS := concactRunID(runRecordList)
history, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, art.GetRunMeta().ConversationID, runIDS)
if err != nil {
return nil, err
}
return history, nil
}
func concactRunID(rr []*entity.RunRecordMeta) []int64 {
ids := make([]int64, 0, len(rr))
for _, c := range rr {
ids = append(ids, c.ID)
}
return ids
}
func (art *AgentRuntime) createRunRecord(ctx context.Context) (*entity.RunRecordMeta, error) {
runPoData, err := art.RunRecordRepo.Create(ctx, art.GetRunMeta())
if err != nil {
logs.CtxErrorf(ctx, "RunRecordRepo.Create error: %v", err)
return nil, err
}
srRecord := buildSendRunRecord(ctx, runPoData, entity.RunStatusCreated)
art.RunProcess.StepToCreate(ctx, srRecord, art.SW)
err = art.RunProcess.StepToInProgress(ctx, srRecord, art.SW)
if err != nil {
logs.CtxErrorf(ctx, "runProcess.StepToInProgress error: %v", err)
return nil, err
}
return runPoData, nil
}

View File

@@ -30,8 +30,8 @@ import (
)
type RunProcess struct {
event *Event
event *Event
SW *schema.StreamWriter[*entity.AgentRunResponse]
RunRecordRepo repository.RunRecordRepo
}
@@ -115,7 +115,6 @@ func (r *RunProcess) StepToFailed(ctx context.Context, srRecord *entity.ChunkRun
Code: srRecord.Error.Code,
Msg: srRecord.Error.Msg,
})
return
}
func (r *RunProcess) StepToDone(sw *schema.StreamWriter[*entity.AgentRunResponse]) {

View File

@@ -0,0 +1,450 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"bytes"
"context"
"errors"
"io"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/mohae/deepcopy"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.ImageX) (err error) {
mainChan := make(chan *entity.AgentRespEvent, 100)
ar := &crossagent.AgentRuntime{
AgentVersion: art.GetRunMeta().Version,
SpaceID: art.GetRunMeta().SpaceID,
AgentID: art.GetRunMeta().AgentID,
IsDraft: art.GetRunMeta().IsDraft,
UserID: art.GetRunMeta().UserID,
ConnectorID: art.GetRunMeta().ConnectorID,
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex),
ResumeInfo: parseResumeInfo(ctx, art.GetHistory()),
}
streamer, err := crossagent.DefaultSVC().StreamExecute(ctx, ar)
if err != nil {
return errors.New(errorx.ErrorWithoutStack(err))
}
var wg sync.WaitGroup
wg.Add(2)
safego.Go(ctx, func() {
defer wg.Done()
art.pull(ctx, mainChan, streamer)
})
safego.Go(ctx, func() {
defer wg.Done()
art.push(ctx, mainChan)
})
wg.Wait()
return err
}
func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) {
mh := &MesssageEventHanlder{
sw: art.SW,
messageEvent: art.MessageEvent,
}
var err error
defer func() {
if err != nil {
logs.CtxErrorf(ctx, "run.push error: %v", err)
mh.handlerErr(ctx, err)
}
}()
reasoningContent := bytes.NewBuffer([]byte{})
var firstAnswerMsg *msgEntity.Message
var reasoningMsg *msgEntity.Message
isSendFinishAnswer := false
var preToolResponseMsg *msgEntity.Message
toolResponseMsgContent := bytes.NewBuffer([]byte{})
for {
chunk, ok := <-mainChan
if !ok || chunk == nil {
return
}
if chunk.Err != nil {
if errors.Is(chunk.Err, io.EOF) {
if !isSendFinishAnswer {
isSendFinishAnswer = true
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
reasoningContent.Reset()
}
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
err = finishErr
return
}
}
return
}
mh.handlerErr(ctx, chunk.Err)
return
}
switch chunk.EventType {
case message.MessageTypeFunctionCall:
if chunk.FuncCall != nil && chunk.FuncCall.ResponseMeta != nil {
if usage := handlerUsage(chunk.FuncCall.ResponseMeta); usage != nil {
art.SetUsage(&agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
})
}
}
err = mh.handlerFunctionCall(ctx, chunk, art)
if err != nil {
return
}
if preToolResponseMsg == nil {
var cErr error
preToolResponseMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
}
case message.MessageTypeToolResponse:
err = mh.handlerTooResponse(ctx, chunk, art, preToolResponseMsg, toolResponseMsgContent.String())
if err != nil {
return
}
preToolResponseMsg = nil // reset
case message.MessageTypeKnowledge:
err = mh.handlerKnowledge(ctx, chunk, art)
if err != nil {
return
}
case message.MessageTypeToolMidAnswer:
fullMidAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
toolMidAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
var preMsgIsFinish = false
for {
streamMsg, receErr := chunk.ToolMidAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
break
}
err = receErr
return
}
if preMsgIsFinish {
toolMidAnswerMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
preMsgIsFinish = false
}
if streamMsg == nil {
continue
}
if firstAnswerMsg == nil && len(streamMsg.Content) > 0 {
if reasoningMsg != nil {
toolMidAnswerMsg = deepcopy.Copy(reasoningMsg).(*msgEntity.Message)
}
firstAnswerMsg = deepcopy.Copy(toolMidAnswerMsg).(*msgEntity.Message)
}
if streamMsg.Extra != nil {
if val, ok := streamMsg.Extra["workflow_node_name"]; ok && val != nil {
toolMidAnswerMsg.Ext["message_title"] = val.(string)
}
}
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
sendMidAnswerMsg.Content = streamMsg.Content
toolResponseMsgContent.WriteString(streamMsg.Content)
fullMidAnswerContent.WriteString(streamMsg.Content)
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMidAnswerMsg, art.SW)
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg.Extra["is_finish"] == true {
preMsgIsFinish = true
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
sendMidAnswerMsg.Content = fullMidAnswerContent.String()
fullMidAnswerContent.Reset()
hfErr := mh.handlerAnswer(ctx, sendMidAnswerMsg, usage, art, toolMidAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
}
}
case message.MessageTypeToolAsAnswer:
var usage *msgEntity.UsageExt
fullContent := bytes.NewBuffer([]byte{})
toolAsAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
if firstAnswerMsg == nil {
firstAnswerMsg = toolAsAnswerMsg
}
for {
streamMsg, receErr := chunk.ToolAsAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
answer := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
answer.Content = fullContent.String()
hfErr := mh.handlerAnswer(ctx, answer, usage, art, toolAsAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
sendMsg := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
fullContent.WriteString(streamMsg.Content)
sendMsg.Content = streamMsg.Content
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, art.SW)
}
case message.MessageTypeAnswer:
fullContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
var isToolCalls = false
var modelAnswerMsg *msgEntity.Message
for {
streamMsg, receErr := chunk.ModelAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
if isToolCalls {
break
}
if modelAnswerMsg == nil {
break
}
answer := buildSendMsg(ctx, modelAnswerMsg, false, art)
answer.Content = fullContent.String()
hfErr := mh.handlerAnswer(ctx, answer, usage, art, modelAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && len(streamMsg.ToolCalls) > 0 {
isToolCalls = true
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
continue
}
if len(streamMsg.ReasoningContent) > 0 {
if reasoningMsg == nil {
reasoningMsg, err = preCreateAnswer(ctx, art)
if err != nil {
return
}
}
sendReasoningMsg := buildSendMsg(ctx, reasoningMsg, false, art)
reasoningContent.WriteString(streamMsg.ReasoningContent)
sendReasoningMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendReasoningMsg, art.SW)
}
if len(streamMsg.Content) > 0 {
if modelAnswerMsg == nil {
modelAnswerMsg, err = preCreateAnswer(ctx, art)
if err != nil {
return
}
if firstAnswerMsg == nil {
if reasoningMsg != nil {
modelAnswerMsg.ID = reasoningMsg.ID
}
firstAnswerMsg = modelAnswerMsg
}
}
sendAnswerMsg := buildSendMsg(ctx, modelAnswerMsg, false, art)
fullContent.WriteString(streamMsg.Content)
sendAnswerMsg.Content = streamMsg.Content
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, art.SW)
}
}
case message.MessageTypeFlowUp:
if isSendFinishAnswer {
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
}
isSendFinishAnswer = true
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
err = finishErr
return
}
}
err = mh.handlerSuggest(ctx, chunk, art)
if err != nil {
return
}
case message.MessageTypeInterrupt:
err = mh.handlerInterrupt(ctx, chunk, art, firstAnswerMsg, reasoningContent.String())
if err != nil {
return
}
}
}
}
func (art *AgentRuntime) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, events *schema.StreamReader[*crossagent.AgentEvent]) {
defer func() {
close(mainChan)
}()
for {
rm, re := events.Recv()
if re != nil {
errChunk := &entity.AgentRespEvent{
Err: re,
}
mainChan <- errChunk
return
}
eventType, tErr := transformEventMap(rm.EventType)
if tErr != nil {
errChunk := &entity.AgentRespEvent{
Err: tErr,
}
mainChan <- errChunk
return
}
respChunk := &entity.AgentRespEvent{
EventType: eventType,
ModelAnswer: rm.ChatModelAnswer,
ToolsMessage: rm.ToolsMessage,
FuncCall: rm.FuncCall,
Knowledge: rm.Knowledge,
Suggest: rm.Suggest,
Interrupt: rm.Interrupt,
ToolMidAnswer: rm.ToolMidAnswer,
ToolAsAnswer: rm.ToolAsChatModelAnswer,
}
mainChan <- respChunk
}
}
func transformEventMap(eventType singleagent.EventType) (message.MessageType, error) {
var eType message.MessageType
switch eventType {
case singleagent.EventTypeOfFuncCall:
return message.MessageTypeFunctionCall, nil
case singleagent.EventTypeOfKnowledge:
return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer:
return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeToolAsAnswer, nil
case singleagent.EventTypeOfToolMidAnswer:
return message.MessageTypeToolMidAnswer, nil
case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil
case singleagent.EventTypeOfInterrupt:
return message.MessageTypeInterrupt, nil
}
return eType, errorx.New(errno.ErrReplyUnknowEventType)
}
func (art *AgentRuntime) saveReasoningContent(ctx context.Context, firstAnswerMsg *msgEntity.Message, reasoningContent string) {
_, err := crossmessage.DefaultSVC().Edit(ctx, &message.Message{
ID: firstAnswerMsg.ID,
ReasoningContent: reasoningContent,
})
if err != nil {
logs.CtxInfof(ctx, "save reasoning content failed, err: %v", err)
}
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
@@ -34,8 +33,9 @@ func NewRunRecordRepo(db *gorm.DB, idGen idgen.IDGenerator) RunRecordRepo {
type RunRecordRepo interface {
Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
GetByID(ctx context.Context, id int64) (*entity.RunRecord, error)
GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error)
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
Delete(ctx context.Context, id []int64) error
UpdateByID(ctx context.Context, id int64, update *entity.UpdateMeta) error
List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error)
List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
}

View File

@@ -26,6 +26,9 @@ import (
type Run interface {
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
Delete(ctx context.Context, runID []int64) error
Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
GetByID(ctx context.Context, runID int64) (*entity.RunRecordMeta, error)
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
}

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,18 @@
package agentrun
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
)
func TestAgentRun(t *testing.T) {
@@ -97,3 +108,158 @@ func TestAgentRun(t *testing.T) {
// assert.NoError(t, err)
}
func TestRunImpl_List(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.RunRecord{}).AddRows(
&model.RunRecord{
ID: 1,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix(),
},
&model.RunRecord{
ID: 2,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 1,
}, &model.RunRecord{
ID: 3,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 2,
}, &model.RunRecord{
ID: 4,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 3,
}, &model.RunRecord{
ID: 5,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 4,
},
&model.RunRecord{
ID: 6,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 5,
}, &model.RunRecord{
ID: 7,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 6,
}, &model.RunRecord{
ID: 8,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 7,
}, &model.RunRecord{
ID: 9,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 8,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockIDGen := mock.NewMockIDGenerator(ctrl)
runRecordRepo := repository.NewRunRecordRepo(mockDB, mockIDGen)
service := &runImpl{
Components: Components{
RunRecordRepo: runRecordRepo,
},
}
t.Run("list success", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
AgentID: 456,
SectionID: 789,
Limit: 10,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 9)
assert.Equal(t, int64(123), result[0].ConversationID)
assert.Equal(t, int64(456), result[0].AgentID)
})
t.Run("empty list", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 999, //
Limit: 10,
OrderBy: "desc",
}
// check result
result, err := service.List(ctx, meta)
assert.NoError(t, err)
assert.Empty(t, result)
})
t.Run("search with before id", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
SectionID: 789,
AgentID: 456,
BeforeID: 5,
Limit: 3,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 3)
assert.Equal(t, int64(4), result[0].ID)
})
t.Run("search with after id and limit", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
SectionID: 789,
AgentID: 456,
AfterID: 5,
Limit: 3,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 3)
assert.Equal(t, int64(9), result[0].ID)
})
}

View File

@@ -24,6 +24,7 @@ import (
type Conversation = conversation.Conversation
type CreateMeta struct {
Name string `json:"name"`
AgentID int64 `json:"agent_id"`
UserID int64 `json:"user_id"`
ConnectorID int64 `json:"connector_id"`
@@ -50,3 +51,8 @@ type ListMeta struct {
Limit int `json:"limit"`
Page int `json:"page"`
}
type UpdateMeta struct {
ID int64 `json:"id"`
Name string `json:"name"`
}

View File

@@ -107,6 +107,20 @@ func (dao *ConversationDAO) Delete(ctx context.Context, id int64) (int64, error)
return updateRes.RowsAffected, err
}
func (dao *ConversationDAO) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
updateColumn := make(map[string]interface{})
updateColumn[dao.query.Conversation.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
if len(req.Name) > 0 {
updateColumn[dao.query.Conversation.Name.ColumnName().String()] = req.Name
}
_, err := dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(req.ID)).UpdateColumns(updateColumn)
if err != nil {
return nil, err
}
return dao.GetByID(ctx, req.ID)
}
func (dao *ConversationDAO) Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error) {
po, err := dao.query.Conversation.WithContext(ctx).Debug().
Where(dao.query.Conversation.CreatorID.Eq(userID)).
@@ -133,13 +147,15 @@ func (dao *ConversationDAO) List(ctx context.Context, userID int64, agentID int6
do = do.Where(dao.query.Conversation.CreatorID.Eq(userID)).
Where(dao.query.Conversation.AgentID.Eq(agentID)).
Where(dao.query.Conversation.Scene.Eq(scene)).
Where(dao.query.Conversation.ConnectorID.Eq(connectorID))
Where(dao.query.Conversation.ConnectorID.Eq(connectorID)).
Where(dao.query.Conversation.Status.Eq(int32(conversation.ConversationStatusNormal)))
do = do.Offset((page - 1) * limit)
if limit > 0 {
do = do.Limit(int(limit) + 1)
}
do = do.Order(dao.query.Conversation.CreatedAt.Desc())
poList, err := do.Find()
@@ -173,6 +189,7 @@ func (dao *ConversationDAO) conversationDO2PO(ctx context.Context, conversation
Ext: conversation.Ext,
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
Name: conversation.Name,
}
}
@@ -188,6 +205,7 @@ func (dao *ConversationDAO) conversationPO2DO(ctx context.Context, c *model.Conv
Ext: c.Ext,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
Name: c.Name,
}
}
@@ -204,6 +222,7 @@ func (dao *ConversationDAO) conversationBatchPO2DO(ctx context.Context, conversa
Ext: c.Ext,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
Name: c.Name,
}
})
}

View File

@@ -1,3 +1,19 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
@@ -9,6 +25,7 @@ const TableNameConversation = "conversation"
// Conversation conversation info record
type Conversation struct {
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:id" json:"id"` // id
Name string `gorm:"column:name;not null;comment:conversation name" json:"name"` // conversation name
ConnectorID int64 `gorm:"column:connector_id;not null;comment:Publish Connector ID" json:"connector_id"` // Publish Connector ID
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
Scene int32 `gorm:"column:scene;not null;comment:conversation scene" json:"scene"` // conversation scene

View File

@@ -1,3 +1,19 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
@@ -28,6 +44,7 @@ func newConversation(db *gorm.DB, opts ...gen.DOOption) conversation {
tableName := _conversation.conversationDo.TableName()
_conversation.ALL = field.NewAsterisk(tableName)
_conversation.ID = field.NewInt64(tableName, "id")
_conversation.Name = field.NewString(tableName, "name")
_conversation.ConnectorID = field.NewInt64(tableName, "connector_id")
_conversation.AgentID = field.NewInt64(tableName, "agent_id")
_conversation.Scene = field.NewInt32(tableName, "scene")
@@ -49,6 +66,7 @@ type conversation struct {
ALL field.Asterisk
ID field.Int64 // id
Name field.String // conversation name
ConnectorID field.Int64 // Publish Connector ID
AgentID field.Int64 // agent_id
Scene field.Int32 // conversation scene
@@ -75,6 +93,7 @@ func (c conversation) As(alias string) *conversation {
func (c *conversation) updateTableName(table string) *conversation {
c.ALL = field.NewAsterisk(table)
c.ID = field.NewInt64(table, "id")
c.Name = field.NewString(table, "name")
c.ConnectorID = field.NewInt64(table, "connector_id")
c.AgentID = field.NewInt64(table, "agent_id")
c.Scene = field.NewInt32(table, "scene")
@@ -100,8 +119,9 @@ func (c *conversation) GetFieldByName(fieldName string) (field.OrderExpr, bool)
}
func (c *conversation) fillFieldMap() {
c.fieldMap = make(map[string]field.Expr, 10)
c.fieldMap = make(map[string]field.Expr, 11)
c.fieldMap["id"] = c.ID
c.fieldMap["name"] = c.Name
c.fieldMap["connector_id"] = c.ConnectorID
c.fieldMap["agent_id"] = c.AgentID
c.fieldMap["scene"] = c.Scene

View File

@@ -35,6 +35,7 @@ type ConversationRepo interface {
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
UpdateSection(ctx context.Context, id int64) (int64, error)
Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error)
Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error)
Delete(ctx context.Context, id int64) (int64, error)
List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error)
}

View File

@@ -29,4 +29,5 @@ type Conversation interface {
GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error)
Delete(ctx context.Context, id int64) error
List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error)
Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error)
}

View File

@@ -101,6 +101,11 @@ func (c *conversationImpl) Delete(ctx context.Context, id int64) error {
return nil
}
func (c *conversationImpl) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
// get conversation
return c.ConversationRepo.Update(ctx, req)
}
func (c *conversationImpl) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
conversationList, hasMore, err := c.ConversationRepo.List(ctx, req.UserID, req.AgentID, req.ConnectorID, int32(req.Scene), req.Limit, req.Page)

View File

@@ -16,19 +16,22 @@
package entity
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
import (
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
)
type Message = message.Message
type ListMeta struct {
ConversationID int64 `json:"conversation_id"`
RunID []*int64 `json:"run_id"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
OrderBy *string `json:"order_by"`
Limit int `json:"limit"`
Cursor int64 `json:"cursor"` // message id
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
ConversationID int64 `json:"conversation_id"`
RunID []*int64 `json:"run_id"`
UserID string `json:"user_id"`
AgentID int64 `json:"agent_id"`
OrderBy *string `json:"order_by"`
Limit int `json:"limit"`
Cursor int64 `json:"cursor"` // message id
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
MessageType []*message.MessageType `json:"message_type"`
}
type ListResult struct {
@@ -45,8 +48,9 @@ type GetByRunIDsRequest struct {
}
type DeleteMeta struct {
MessageIDs []int64 `json:"message_ids"`
RunIDs []int64 `json:"run_ids"`
ConversationID *int64 `json:"conversation_id"`
MessageIDs []int64 `json:"message_ids"`
RunIDs []int64 `json:"run_ids"`
}
type BrokenMeta struct {

View File

@@ -31,6 +31,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
@@ -71,27 +72,41 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
return dao.messagePO2DO(poData), nil
}
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
if messageType != nil {
do = do.Where(m.MessageType.Eq(string(*messageType)))
if len(listMeta.RunID) > 0 {
do = do.Where(m.RunID.In(slices.Transform(listMeta.RunID, func(t *int64) int64 {
return *t
})...))
}
if len(listMeta.MessageType) > 0 {
do = do.Where(m.MessageType.In(slices.Transform(listMeta.MessageType, func(t *message.MessageType) string {
return string(*t)
})...))
}
if limit > 0 {
do = do.Limit(int(limit) + 1)
if listMeta.Limit > 0 {
do = do.Limit(int(listMeta.Limit) + 1)
}
if cursor > 0 {
if direction == entity.ScrollPageDirectionPrev {
do = do.Where(m.CreatedAt.Lt(cursor))
} else {
do = do.Where(m.CreatedAt.Gt(cursor))
if listMeta.Cursor > 0 {
msg, err := m.Where(m.ID.Eq(listMeta.Cursor)).First()
if err != nil {
return nil, false, err
}
if listMeta.Direction == entity.ScrollPageDirectionPrev {
do = do.Where(m.CreatedAt.Lt(msg.CreatedAt))
do = do.Order(m.CreatedAt.Desc())
} else {
do = do.Where(m.CreatedAt.Gt(msg.CreatedAt))
do = do.Order(m.CreatedAt.Asc())
}
} else {
do = do.Order(m.CreatedAt.Desc())
}
do = do.Order(m.CreatedAt.Desc())
messageList, err := do.Find()
var hasMore bool
@@ -103,9 +118,9 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
return nil, false, err
}
if len(messageList) > limit {
if len(messageList) > int(listMeta.Limit) {
hasMore = true
messageList = messageList[:limit]
messageList = messageList[:int(listMeta.Limit)]
}
return dao.batchMessagePO2DO(messageList), hasMore, nil
@@ -113,7 +128,8 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
m := dao.query.Message
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
if orderBy == "DESC" {
do = do.Order(m.CreatedAt.Desc())
} else {
@@ -133,19 +149,37 @@ func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
m := dao.query.Message
columns := dao.buildEditColumns(msg)
originMsg, err := dao.GetByID(ctx, msgID)
if originMsg == nil {
return 0, errorx.New(errno.ErrRecordNotFound)
}
if err != nil {
return 0, err
}
columns := dao.buildEditColumns(msg, originMsg)
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
if err != nil {
return 0, err
}
if do.RowsAffected == 0 {
return 0, errorx.New(errno.ErrRecordNotFound)
}
return do.RowsAffected, nil
}
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
func (dao *MessageDAO) buildEditColumns(msg *message.Message, originMsg *entity.Message) map[string]interface{} {
columns := make(map[string]interface{})
table := dao.query.Message
if msg.Content != "" {
msg.Role = originMsg.Role
columns[table.Content.ColumnName().String()] = msg.Content
modelContent, err := dao.buildModelContent(msg)
if err == nil {
columns[table.ModelContent.ColumnName().String()] = modelContent
}
}
if msg.MessageType != "" {
columns[table.MessageType.ColumnName().String()] = msg.MessageType
@@ -170,6 +204,11 @@ func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interfa
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
if msg.Ext != nil {
if originMsg.Ext != nil {
for k, v := range originMsg.Ext {
msg.Ext[k] = v
}
}
ext, err := sonic.MarshalString(msg.Ext)
if err == nil {
columns[table.Ext.ColumnName().String()] = ext
@@ -192,8 +231,8 @@ func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Messag
return dao.messagePO2DO(po), nil
}
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
if len(msgIDs) == 0 && len(runIDs) == 0 {
func (dao *MessageDAO) Delete(ctx context.Context, delMeta *entity.DeleteMeta) error {
if len(delMeta.MessageIDs) == 0 && len(delMeta.RunIDs) == 0 {
return nil
}
@@ -202,11 +241,14 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
m := dao.query.Message
do := m.WithContext(ctx)
if len(runIDs) > 0 {
do = do.Where(m.RunID.In(runIDs...))
if len(delMeta.RunIDs) > 0 {
do = do.Where(m.RunID.In(delMeta.RunIDs...))
}
if len(msgIDs) > 0 {
do = do.Where(m.ID.In(msgIDs...))
if len(delMeta.MessageIDs) > 0 {
do = do.Where(m.ID.In(delMeta.MessageIDs...))
}
if delMeta.ConversationID != nil && ptr.From(delMeta.ConversationID) > 0 {
do = do.Where(m.ConversationID.Eq(*delMeta.ConversationID))
}
_, err := do.UpdateColumns(&updateColumns)
return err
@@ -284,6 +326,9 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
var multiContent []schema.ChatMessagePart
for _, contentData := range msgDO.MultiContent {
if contentData.Type == message.InputTypeText {
if len(msgDO.Content) == 0 && len(contentData.Text) > 0 {
msgDO.Content = contentData.Text
}
continue
}
one := schema.ChatMessagePart{}

View File

@@ -34,10 +34,9 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
type MessageRepo interface {
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
List(ctx context.Context, conversationID int64, limit int, cursor int64,
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error)
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
Delete(ctx context.Context, delMeta *entity.DeleteMeta) error
}

View File

@@ -24,6 +24,7 @@ import (
type Message interface {
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"sort"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
@@ -51,9 +52,9 @@ func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
req.MessageType = []*message.MessageType{ptr.Of(message.MessageTypeQuestion)}
// get message with query
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
@@ -62,8 +63,11 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
resp.HasMore = hasMore
if len(messageList) > 0 {
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
resp.NextCursor = messageList[0].CreatedAt
sort.Slice(messageList, func(i, j int) bool {
return messageList[i].CreatedAt > messageList[j].CreatedAt
})
resp.PrevCursor = messageList[len(messageList)-1].ID
resp.NextCursor = messageList[0].ID
var runIDs []int64
for _, m := range messageList {
@@ -82,6 +86,23 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
return resp, nil
}
func (m *messageImpl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
resp := &entity.ListResult{}
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
if err != nil {
return resp, err
}
resp.Direction = req.Direction
resp.HasMore = hasMore
resp.Messages = messageList
if len(messageList) > 0 {
resp.PrevCursor = messageList[0].ID
resp.NextCursor = messageList[len(messageList)-1].ID
}
return resp, nil
}
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
}
@@ -96,7 +117,7 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
}
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
return m.MessageRepo.Delete(ctx, req)
}
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {

View File

@@ -18,6 +18,7 @@ package message
import (
"context"
"encoding/json"
"testing"
"time"
@@ -145,20 +146,26 @@ func TestCreateMessage(t *testing.T) {
func TestEditMessage(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
extData := map[string]string{
"test": "test",
}
ext, _ := json.Marshal(extData)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 123,
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Role: string(schema.User),
RunID: 124,
Ext: string(ext),
},
)
@@ -177,7 +184,7 @@ func TestEditMessage(t *testing.T) {
Url: "https://xxxxx.xxxx/file",
Name: "test_file",
}
content := []*message.InputMetaData{
_ = []*message.InputMetaData{
{
Type: message.InputTypeText,
Text: "解析图片中的内容",
@@ -197,56 +204,293 @@ func TestEditMessage(t *testing.T) {
}
resp, err := NewService(components).Edit(ctx, &entity.Message{
ID: 2,
Content: "test edit message",
MultiContent: content,
ID: 2,
Content: "test edit message",
Ext: map[string]string{"newext": "true"},
// MultiContent: content,
})
_ = resp
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
msg, err := NewService(components).GetByID(ctx, 2)
assert.NoError(t, err)
assert.Equal(t, int64(124), msOne[0].RunID)
assert.Equal(t, int64(2), msg.ID)
assert.Equal(t, "test edit message", msg.Content)
var modelContent *schema.Message
err = json.Unmarshal([]byte(msg.ModelContent), &modelContent)
assert.NoError(t, err)
assert.Equal(t, "test edit message", modelContent.Content)
assert.Equal(t, "true", msg.Ext["newext"])
}
func TestGetByRunIDs(t *testing.T) {
//func TestGetByRunIDs(t *testing.T) {
// ctx := context.Background()
//
// mockDBGen := orm.NewMockDB()
//
// mockDBGen.AddTable(&model.Message{}).
// AddRows(
// &model.Message{
// ID: 1,
// ConversationID: 1,
// UserID: "1",
// RunID: 123,
// Content: "test content123",
// },
// &model.Message{
// ID: 2,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// &model.Message{
// ID: 3,
// ConversationID: 1,
// UserID: "1",
// Content: "test content124",
// RunID: 124,
// },
// )
// mockDB, err := mockDBGen.DB()
// assert.NoError(t, err)
// components := &Components{
// MessageRepo: repository.NewMessageRepo(mockDB, nil),
// }
//
// resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
//
// assert.NoError(t, err)
//
// assert.Len(t, resp, 2)
//}
func TestListWithoutPair(t *testing.T) {
ctx := context.Background()
t.Run("success_with_messages", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Hello",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "World",
MessageType: string(message.MessageTypeAnswer),
Status: 1, // MessageStatusAvailable
CreatedAt: time.Now().UnixMilli(),
},
)
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 1,
UserID: "1",
RunID: 123,
Content: "test content123",
},
&model.Message{
ID: 2,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
&model.Message{
ID: 3,
ConversationID: 1,
UserID: "1",
Content: "test content124",
RunID: 124,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
assert.NoError(t, err)
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
assert.Len(t, resp, 2)
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
assert.Equal(t, "Hello", resp.Messages[0].Content)
assert.Equal(t, "World", resp.Messages[1].Content)
})
t.Run("empty_result", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 999,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 0)
})
t.Run("pagination_has_more", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Message 1",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 3000,
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Message 2",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 2000,
},
&model.Message{
ID: 3,
ConversationID: 100,
UserID: "user123",
RunID: 202,
Content: "Message 3",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli() - 1000,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 2,
Direction: entity.ScrollPageDirectionNext,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
assert.True(t, resp.HasMore)
assert.Len(t, resp.Messages, 2)
})
t.Run("direction_prev", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Test message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionPrev,
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, entity.ScrollPageDirectionPrev, resp.Direction)
assert.False(t, resp.HasMore)
assert.Len(t, resp.Messages, 1)
})
t.Run("with_message_type_filter", func(t *testing.T) {
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}).
AddRows(
&model.Message{
ID: 1,
ConversationID: 100,
UserID: "user123",
RunID: 200,
Content: "Answer message",
MessageType: string(message.MessageTypeAnswer),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
&model.Message{
ID: 2,
ConversationID: 100,
UserID: "user123",
RunID: 201,
Content: "Question message",
MessageType: string(message.MessageTypeQuestion),
Status: 1,
CreatedAt: time.Now().UnixMilli(),
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
components := &Components{
MessageRepo: repository.NewMessageRepo(mockDB, nil),
}
req := &entity.ListMeta{
ConversationID: 100,
UserID: "user123",
Limit: 10,
Direction: entity.ScrollPageDirectionNext,
MessageType: []*message.MessageType{&[]message.MessageType{message.MessageTypeAnswer}[0]},
}
resp, err := NewService(components).ListWithoutPair(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Len(t, resp.Messages, 1)
assert.Equal(t, "Answer message", resp.Messages[0].Content)
})
}