feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -43,6 +43,7 @@ type RunRecordMeta struct {
ChatRequest *string `json:"chat_message"`
CompletedAt int64 `json:"completed_at"`
FailedAt int64 `json:"failed_at"`
CreatorID int64 `json:"creator_id"`
}
type ChunkRunItem = RunRecordMeta
@@ -158,3 +159,18 @@ type ModelAnswerEvent struct {
Message *schema.Message
Err error
}
type ListRunRecordMeta struct {
ConversationID int64 `json:"conversation_id"`
AgentID int64 `json:"agent_id"`
SectionID int64 `json:"section_id"`
Limit int32 `json:"limit"`
OrderBy string `json:"order_by"` //desc asc
BeforeID int64 `json:"before_id"`
AfterID int64 `json:"after_id"`
}
type CancelRunMeta struct {
ConversationID int64 `json:"conversation_id"`
RunID int64 `json:"run_id"`
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
)
func getAgentHistoryRounds(agentInfo *singleagent.SingleAgent) int32 {
var conversationTurns int32 = entity.ConversationTurnsDefault
if agentInfo != nil && agentInfo.ModelInfo != nil && agentInfo.ModelInfo.ShortMemoryPolicy != nil && ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound) > 0 {
conversationTurns = ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound)
}
return conversationTurns
}
func getAgentInfo(ctx context.Context, agentID int64, isDraft bool) (*singleagent.SingleAgent, error) {
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
AgentID: agentID,
IsDraft: isDraft,
})
if err != nil {
return nil, err
}
return agentInfo, nil
}

View File

@@ -0,0 +1,215 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"bytes"
"context"
"errors"
"io"
"strconv"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) {
mh := &MesssageEventHanlder{
sw: art.SW,
messageEvent: art.MessageEvent,
}
resumeInfo := parseResumeInfo(ctx, art.GetHistory())
wfID, _ := strconv.ParseInt(art.GetAgentInfo().LayoutInfo.WorkflowId, 10, 64)
if wfID == 0 {
mh.handlerErr(ctx, errorx.New(errno.ErrAgentRunWorkflowNotFound))
return
}
var wfStreamer *schema.StreamReader[*crossworkflow.WorkflowMessage]
executeConfig := crossworkflow.ExecuteConfig{
ID: wfID,
ConnectorID: art.GetRunMeta().ConnectorID,
ConnectorUID: art.GetRunMeta().UserID,
AgentID: ptr.Of(art.GetRunMeta().AgentID),
Mode: crossworkflow.ExecuteModeRelease,
BizType: crossworkflow.BizTypeAgent,
SyncPattern: crossworkflow.SyncPatternStream,
From: crossworkflow.FromLatestVersion,
}
if resumeInfo != nil {
wfStreamer, err = crossworkflow.DefaultSVC().StreamResume(ctx, &crossworkflow.ResumeRequest{
ResumeData: concatWfInput(art),
EventID: resumeInfo.ChatflowInterrupt.InterruptEvent.ID,
ExecuteID: resumeInfo.ChatflowInterrupt.ExecuteID,
}, executeConfig)
} else {
executeConfig.ConversationID = &art.GetRunMeta().ConversationID
executeConfig.SectionID = &art.GetRunMeta().SectionID
executeConfig.InitRoundID = &art.RunRecord.ID
executeConfig.RoundID = &art.RunRecord.ID
executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0]
executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo()))
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{
"USER_INPUT": concatWfInput(art),
})
}
if err != nil {
return err
}
var wg sync.WaitGroup
wg.Add(1)
safego.Go(ctx, func() {
defer wg.Done()
art.pullWfStream(ctx, wfStreamer, mh)
})
wg.Wait()
return err
}
func concatWfInput(rtDependence *AgentRuntime) string {
var input string
for _, content := range rtDependence.RunMeta.Content {
if content.Type == message.InputTypeText {
input = content.Text + "," + input
} else {
for _, file := range content.FileData {
input += file.Url + ","
}
}
}
return input
}
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) {
fullAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
preAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
return
}
var preMsgIsFinish = false
var lastAnswerMsg *entity.ChunkMessageItem
for {
st, re := events.Recv()
if re != nil {
if errors.Is(re, io.EOF) {
if lastAnswerMsg != nil && usage != nil {
art.SetUsage(&agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
})
_ = mh.handlerWfUsage(ctx, lastAnswerMsg, usage)
}
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
logs.CtxErrorf(ctx, "handlerFinalAnswerFinish error: %v", finishErr)
return
}
return
}
logs.CtxErrorf(ctx, "pullWfStream Recv error: %v", re)
mh.handlerErr(ctx, re)
return
}
if st == nil {
continue
}
if st.StateMessage != nil {
if st.StateMessage.Status == crossworkflow.WorkflowFailed {
mh.handlerErr(ctx, st.StateMessage.LastError)
continue
}
if st.StateMessage.Usage != nil {
usage = &msgEntity.UsageExt{
InputTokens: st.StateMessage.Usage.InputTokens,
OutputTokens: st.StateMessage.Usage.OutputTokens,
TotalCount: st.StateMessage.Usage.InputTokens + st.StateMessage.Usage.OutputTokens,
}
}
if st.StateMessage.InterruptEvent != nil { // interrupt
mh.handlerWfInterruptMsg(ctx, st.StateMessage, art)
continue
}
}
if st.DataMessage == nil {
continue
}
switch st.DataMessage.Type {
case crossworkflow.Answer:
// input node & question node skip
if st.DataMessage != nil && (st.DataMessage.NodeType == crossworkflow.NodeTypeInputReceiver || st.DataMessage.NodeType == crossworkflow.NodeTypeQuestion) {
break
}
if preMsgIsFinish {
preAnswerMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
return
}
preMsgIsFinish = false
}
if st.DataMessage.Content != "" {
fullAnswerContent.WriteString(st.DataMessage.Content)
}
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
sendAnswerMsg.Content = st.DataMessage.Content
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, mh.sw)
if st.DataMessage.Last {
preMsgIsFinish = true
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
sendAnswerMsg.Content = fullAnswerContent.String()
fullAnswerContent.Reset()
hfErr := mh.handlerAnswer(ctx, sendAnswerMsg, usage, art, preAnswerMsg)
if hfErr != nil {
return
}
lastAnswerMsg = sendAnswerMsg
}
}
}
}

View File

@@ -19,6 +19,8 @@ package dal
import (
"context"
"encoding/json"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -27,6 +29,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/query"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
@@ -59,8 +62,12 @@ func (dao *RunRecordDAO) Create(ctx context.Context, runMeta *entity.AgentRunMet
return dao.buildPo2Do(createPO), nil
}
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*model.RunRecord, error) {
return dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error) {
po, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
if err != nil {
return nil, err
}
return dao.buildPo2Do(po), nil
}
func (dao *RunRecordDAO) UpdateByID(ctx context.Context, id int64, updateMeta *entity.UpdateMeta) error {
@@ -106,20 +113,40 @@ func (dao *RunRecordDAO) Delete(ctx context.Context, id []int64) error {
return err
}
func (dao *RunRecordDAO) List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error) {
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", conversationID, sectionID, limit)
func (dao *RunRecordDAO) List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", meta.ConversationID, meta.SectionID, meta.Limit)
m := dao.query.RunRecord
do := m.WithContext(ctx).Where(m.ConversationID.Eq(conversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
if sectionID > 0 {
do = do.Where(m.SectionID.Eq(sectionID))
do := m.WithContext(ctx).Where(m.ConversationID.Eq(meta.ConversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
if meta.BeforeID > 0 {
runRecord, err := m.Where(m.ID.Eq(meta.BeforeID)).First()
if err != nil {
return nil, err
}
do = do.Where(m.CreatedAt.Lt(runRecord.CreatedAt))
}
if limit > 0 {
do = do.Limit(int(limit))
if meta.AfterID > 0 {
runRecord, err := m.Where(m.ID.Eq(meta.AfterID)).First()
if err != nil {
return nil, err
}
do = do.Where(m.CreatedAt.Gt(runRecord.CreatedAt))
}
if meta.SectionID > 0 {
do = do.Where(m.SectionID.Eq(meta.SectionID))
}
if meta.Limit > 0 {
do = do.Limit(int(meta.Limit))
}
if strings.ToLower(meta.OrderBy) == "asc" {
do = do.Order(m.CreatedAt.Asc())
} else {
do = do.Order(m.CreatedAt.Desc())
}
runRecords, err := do.Order(m.CreatedAt.Desc()).Find()
return runRecords, err
runRecords, err := do.Find()
return slices.Transform(runRecords, func(item *model.RunRecord) *entity.RunRecordMeta {
return dao.buildPo2Do(item)
}), err
}
func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.AgentRunMeta) (*model.RunRecord, error) {
@@ -135,7 +162,10 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
}
timeNow := time.Now().UnixMilli()
creatorID, err := strconv.ParseInt(runMeta.UserID, 10, 64)
if err != nil {
return nil, err
}
return &model.RunRecord{
ID: runID,
ConversationID: runMeta.ConversationID,
@@ -145,6 +175,7 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
ChatRequest: string(reqOrigin),
UserID: runMeta.UserID,
CreatedAt: timeNow,
CreatorID: creatorID,
}, nil
}
@@ -161,7 +192,21 @@ func (dao *RunRecordDAO) buildPo2Do(po *model.RunRecord) *entity.RunRecordMeta {
CompletedAt: po.CompletedAt,
FailedAt: po.FailedAt,
Usage: po.Usage,
CreatorID: po.CreatorID,
}
return runMeta
}
func (dao *RunRecordDAO) Cancel(ctx context.Context, meta *entity.CancelRunMeta) (*entity.RunRecordMeta, error) {
m := dao.query.RunRecord
_, err := m.WithContext(ctx).Where(m.ID.Eq(meta.RunID)).UpdateColumns(map[string]interface{}{
"updated_at": time.Now().UnixMilli(),
"status": entity.RunEventCancelled,
})
if err != nil {
return nil, err
}
return dao.GetByID(ctx, meta.RunID)
}

View File

@@ -1,78 +0,0 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
)
type Event struct {
}
func NewEvent() *Event {
return &Event{}
}
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkMessageItem: chunkMsgItem,
}
}
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkRunItem: chunkRunItem,
}
}
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
Error: err,
}
}
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: entity.RunEventStreamDone,
}
}
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildRunEvent(runEvent, runItem)
sw.Send(resp, nil)
}
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildMessageEvent(runEvent, messageItem)
sw.Send(resp, nil)
}
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
resp := e.buildErrEvent(runEvent, err)
sw.Send(resp, nil)
}
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildStreamDoneEvent()
sw.Send(resp, nil)
}

View File

@@ -0,0 +1,512 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/cloudwego/eino/schema"
messageModel "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func buildSendMsg(_ context.Context, msg *msgEntity.Message, isFinish bool, rtDependence *AgentRuntime) *entity.ChunkMessageItem {
copyMap := make(map[string]string)
for k, v := range msg.Ext {
copyMap[k] = v
}
return &entity.ChunkMessageItem{
ID: msg.ID,
ConversationID: msg.ConversationID,
SectionID: msg.SectionID,
AgentID: msg.AgentID,
Content: msg.Content,
Role: entity.RoleTypeAssistant,
ContentType: msg.ContentType,
MessageType: msg.MessageType,
ReplyID: rtDependence.GetQuestionMsgID(),
Type: msg.MessageType,
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
RunID: rtDependence.GetRunRecord().ID,
Ext: copyMap,
IsFinish: isFinish,
ReasoningContent: ptr.Of(msg.ReasoningContent),
}
}
func buildKnowledge(_ context.Context, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
var recallDatas []msgEntity.RecallDataInfo
for _, kOne := range chunk.Knowledge {
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{
Slice: kOne.Content,
Meta: msgEntity.MetaInfo{
Dataset: msgEntity.DatasetInfo{
ID: kOne.MetaData["dataset_id"].(string),
Name: kOne.MetaData["dataset_name"].(string),
},
Document: msgEntity.DocumentInfo{
ID: kOne.MetaData["document_id"].(string),
Name: kOne.MetaData["document_name"].(string),
},
},
Score: kOne.Score(),
})
}
verboseData := &msgEntity.VerboseData{
Chunks: recallDatas,
OriReq: "",
StatusCode: 0,
}
data, err := json.Marshal(verboseData)
if err != nil {
return nil
}
knowledgeInfo := &msgEntity.VerboseInfo{
MessageType: string(entity.MessageSubTypeKnowledgeCall),
Data: string(data),
}
return knowledgeInfo
}
func buildBotStateExt(arm *entity.AgentRunMeta) *msgEntity.BotStateExt {
agentID := strconv.FormatInt(arm.AgentID, 10)
botStateExt := &msgEntity.BotStateExt{
AgentID: agentID,
AgentName: arm.Name,
Awaiting: agentID,
BotID: agentID,
}
return botStateExt
}
type irMsg struct {
Type string `json:"type,omitempty"`
ContentType string `json:"content_type"`
Content any `json:"content"` // either optionContent or string
ID string `json:"id,omitempty"`
}
func parseInterruptData(_ context.Context, interruptData *singleagent.InterruptInfo) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
switch interruptData.InterruptType {
case singleagent.InterruptEventType_OauthPlugin:
data := interruptData.AllToolInterruptData[interruptData.ToolCallID].ToolNeedOAuth.Message
return data, defaultContentType, nil
case singleagent.InterruptEventType_Question:
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
return processQuestionInterruptData(data)
case singleagent.InterruptEventType_InputNode:
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
return processInputNodeInterruptData(data)
case singleagent.InterruptEventType_WorkflowLLM:
toolInterruptEvent := interruptData.AllWfInterruptData[interruptData.ToolCallID].ToolInterruptEvent
data := toolInterruptEvent.InterruptData
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_InputNode {
return processInputNodeInterruptData(data)
}
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_Question {
return processQuestionInterruptData(data)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func processQuestionInterruptData(data string) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
var iData map[string][]*irMsg
err := json.Unmarshal([]byte(data), &iData)
if err != nil {
return "", defaultContentType, err
}
if len(iData["messages"]) == 0 {
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
}
interruptMsg := iData["messages"][0]
if interruptMsg.ContentType == "text" {
return interruptMsg.Content.(string), defaultContentType, nil
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
iMarshalData, err := json.Marshal(interruptMsg)
if err != nil {
return "", defaultContentType, err
}
return string(iMarshalData), message.ContentTypeCard, nil
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func processInputNodeInterruptData(data string) (string, message.ContentType, error) {
return data, message.ContentTypeCard, nil
}
func handlerUsage(meta *schema.ResponseMeta) *msgEntity.UsageExt {
if meta == nil || meta.Usage == nil {
return nil
}
return &msgEntity.UsageExt{
TotalCount: int64(meta.Usage.TotalTokens),
InputTokens: int64(meta.Usage.PromptTokens),
OutputTokens: int64(meta.Usage.CompletionTokens),
}
}
func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
arm := rtDependence.RunMeta
msgMeta := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.RunRecord.ID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
Role: schema.Assistant,
MessageType: message.MessageTypeAnswer,
ContentType: message.ContentTypeText,
Ext: arm.Ext,
}
if arm.Ext == nil {
msgMeta.Ext = map[string]string{}
}
botStateExt := buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err != nil {
return nil, err
}
if _, ok := msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)]; !ok {
msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
msgMeta.Ext = arm.Ext
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
}
func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message {
arm := rtDependence.GetRunMeta()
msg := &msgEntity.Message{
ConversationID: arm.ConversationID,
RunID: rtDependence.RunRecord.ID,
AgentID: arm.AgentID,
SectionID: arm.SectionID,
UserID: arm.UserID,
MessageType: messageType,
}
buildExt := map[string]string{}
timeCost := fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
switch messageType {
case message.MessageTypeQuestion:
msg.Role = schema.User
msg.ContentType = arm.ContentType
for _, content := range arm.Content {
if content.Type == message.InputTypeText {
msg.Content = content.Text
break
}
}
msg.MultiContent = arm.Content
buildExt = arm.Ext
msg.DisplayContent = arm.DisplayContent
case message.MessageTypeAnswer, message.MessageTypeToolAsAnswer:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
case message.MessageTypeToolResponse:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
msg.Content = chunk.ToolsMessage[0].Content
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.ToolsMessage[0]
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeKnowledge:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
knowledgeContent := buildKnowledge(ctx, chunk)
if knowledgeContent != nil {
knInfo, err := json.Marshal(knowledgeContent)
if err == nil {
msg.Content = string(knInfo)
}
}
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.Knowledge
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
case message.MessageTypeFunctionCall:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
if len(chunk.FuncCall.ToolCalls) > 0 {
toolCall := chunk.FuncCall.ToolCalls[0]
toolCalling, err := json.Marshal(toolCall)
if err == nil {
msg.Content = string(toolCalling)
}
buildExt[string(msgEntity.MessageExtKeyPlugin)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyToolName)] = toolCall.Function.Name
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
modelContent := chunk.FuncCall
mc, err := json.Marshal(modelContent)
if err == nil {
msg.ModelContent = string(mc)
}
}
case message.MessageTypeFlowUp:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
msg.Content = chunk.Suggest.Content
case message.MessageTypeVerbose:
msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText
d := &entity.Data{
FinishReason: 0,
FinData: "",
}
dByte, _ := json.Marshal(d)
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeGenerateFinish,
Data: string(dByte),
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
case message.MessageTypeInterrupt:
msg.Role = schema.Assistant
msg.MessageType = message.MessageTypeVerbose
msg.ContentType = message.ContentTypeText
afc := &entity.AnswerFinshContent{
MsgType: entity.MessageSubTypeInterrupt,
Data: "",
}
afcMarshal, _ := json.Marshal(afc)
msg.Content = string(afcMarshal)
// Add ext to save to context_message
interruptByte, err := json.Marshal(chunk.Interrupt)
if err == nil {
buildExt[string(msgEntity.ExtKeyResumeInfo)] = string(interruptByte)
}
buildExt[string(msgEntity.ExtKeyToolCallsIDs)] = chunk.Interrupt.ToolCallID
rc := &messageModel.RequiredAction{
Type: "submit_tool_outputs",
SubmitToolOutputs: &messageModel.SubmitToolOutputs{},
}
msg.RequiredAction = rc
rcExtByte, err := json.Marshal(rc)
if err == nil {
buildExt[string(msgEntity.ExtKeyRequiresAction)] = string(rcExtByte)
}
}
if messageType != message.MessageTypeQuestion {
botStateExt := buildBotStateExt(arm)
bseString, err := json.Marshal(botStateExt)
if err == nil {
buildExt[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
}
}
msg.Ext = buildExt
return msg
}
func handlerWfInterruptEvent(_ context.Context, interruptEventData *crossworkflow.InterruptEvent) (string, message.ContentType, error) {
defaultContentType := message.ContentTypeText
switch singleagent.InterruptEventType(interruptEventData.EventType) {
case singleagent.InterruptEventType_OauthPlugin:
case singleagent.InterruptEventType_Question:
data := interruptEventData.InterruptData
return processQuestionInterruptData(data)
case singleagent.InterruptEventType_InputNode:
data := interruptEventData.InterruptData
return processInputNodeInterruptData(data)
case singleagent.InterruptEventType_WorkflowLLM:
data := interruptEventData.ToolInterruptEvent.InterruptData
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_InputNode {
return processInputNodeInterruptData(data)
}
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_Question {
return processQuestionInterruptData(data)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
}
func historyPairs(historyMsg []*message.Message) []*message.Message {
fcMsgPairs := make(map[int64][]*message.Message)
for _, one := range historyMsg {
if one.MessageType != message.MessageTypeFunctionCall && one.MessageType != message.MessageTypeToolResponse {
continue
}
if _, ok := fcMsgPairs[one.RunID]; !ok {
fcMsgPairs[one.RunID] = []*message.Message{one}
} else {
fcMsgPairs[one.RunID] = append(fcMsgPairs[one.RunID], one)
}
}
var historyAfterPairs []*message.Message
for _, value := range historyMsg {
if value.MessageType == message.MessageTypeFunctionCall {
if len(fcMsgPairs[value.RunID])%2 == 0 {
historyAfterPairs = append(historyAfterPairs, value)
}
} else {
historyAfterPairs = append(historyAfterPairs, value)
}
}
return historyAfterPairs
}
func transMessageToSchemaMessage(ctx context.Context, msgs []*message.Message, imagexClient imagex.ImageX) []*schema.Message {
schemaMessage := make([]*schema.Message, 0, len(msgs))
for _, msgOne := range msgs {
if msgOne.ModelContent == "" {
continue
}
if msgOne.MessageType == message.MessageTypeVerbose || msgOne.MessageType == message.MessageTypeFlowUp {
continue
}
var sm *schema.Message
err := json.Unmarshal([]byte(msgOne.ModelContent), &sm)
if err != nil {
continue
}
if len(sm.ReasoningContent) > 0 {
sm.ReasoningContent = ""
}
schemaMessage = append(schemaMessage, parseMessageURI(ctx, sm, imagexClient))
}
return schemaMessage
}
func parseMessageURI(ctx context.Context, mcMsg *schema.Message, imagexClient imagex.ImageX) *schema.Message {
if mcMsg.MultiContent == nil {
return mcMsg
}
for k, one := range mcMsg.MultiContent {
switch one.Type {
case schema.ChatMessagePartTypeImageURL:
if one.ImageURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.ImageURL.URI)
if err == nil {
mcMsg.MultiContent[k].ImageURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeFileURL:
if one.FileURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.FileURL.URI)
if err == nil {
mcMsg.MultiContent[k].FileURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeAudioURL:
if one.AudioURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.AudioURL.URI)
if err == nil {
mcMsg.MultiContent[k].AudioURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeVideoURL:
if one.VideoURL.URI != "" {
url, err := imagexClient.GetResourceURL(ctx, one.VideoURL.URI)
if err == nil {
mcMsg.MultiContent[k].VideoURL.URL = url.URL
}
}
}
}
return mcMsg
}
func parseResumeInfo(_ context.Context, historyMsg []*message.Message) *crossagent.ResumeInfo {
var resumeInfo *crossagent.ResumeInfo
for i := len(historyMsg) - 1; i >= 0; i-- {
if historyMsg[i].MessageType == message.MessageTypeQuestion {
break
}
if historyMsg[i].MessageType == message.MessageTypeVerbose {
if historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)] != "" {
err := json.Unmarshal([]byte(historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)]), &resumeInfo)
if err != nil {
return nil
}
}
}
}
return resumeInfo
}
func buildSendRunRecord(_ context.Context, runRecord *entity.RunRecordMeta, runStatus entity.RunStatus) *entity.ChunkRunItem {
return &entity.ChunkRunItem{
ID: runRecord.ID,
ConversationID: runRecord.ConversationID,
AgentID: runRecord.AgentID,
SectionID: runRecord.SectionID,
Status: runStatus,
CreatedAt: runRecord.CreatedAt,
}
}

View File

@@ -0,0 +1,428 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/mohae/deepcopy"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/types/consts"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type Event struct {
}
func NewMessageEvent() *Event {
return &Event{}
}
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkMessageItem: chunkMsgItem,
}
}
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
ChunkRunItem: chunkRunItem,
}
}
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: runEvent,
Error: err,
}
}
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
return &entity.AgentRunResponse{
Event: entity.RunEventStreamDone,
}
}
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildRunEvent(runEvent, runItem)
sw.Send(resp, nil)
}
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildMessageEvent(runEvent, messageItem)
sw.Send(resp, nil)
}
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
resp := e.buildErrEvent(runEvent, err)
sw.Send(resp, nil)
}
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
resp := e.buildStreamDoneEvent()
sw.Send(resp, nil)
}
type MesssageEventHanlder struct {
messageEvent *Event
sw *schema.StreamWriter[*entity.AgentRunResponse]
}
func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
var errMsg string
var statusErr errorx.StatusError
if errors.As(err, &statusErr) {
errMsg = statusErr.Msg()
} else {
if strings.ToLower(os.Getenv(consts.RunMode)) != "debug" {
errMsg = "Internal Server Error"
} else {
errMsg = errorx.ErrorWithoutStack(err)
}
}
mh.messageEvent.SendErrEvent(entity.RunEventError, mh.sw, &entity.RunError{
Code: errno.ErrAgentRun,
Msg: errMsg,
})
}
func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
sendMsg := &entity.ChunkMessageItem{
ID: input.ID,
ConversationID: input.ConversationID,
SectionID: input.SectionID,
AgentID: input.AgentID,
Role: entity.RoleType(input.Role),
MessageType: message.MessageTypeAck,
ReplyID: input.ID,
Content: input.Content,
ContentType: message.ContentTypeText,
IsFinish: true,
}
mh.messageEvent.SendMsgEvent(entity.RunEventAck, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
var cmData *message.Message
var err error
if preToolResponseMsg != nil {
cm.ID = preToolResponseMsg.ID
cm.CreatedAt = preToolResponseMsg.CreatedAt
cm.UpdatedAt = preToolResponseMsg.UpdatedAt
if len(toolResponseMsgContent) > 0 {
cm.Content = toolResponseMsgContent + "\n" + cm.Content
}
}
cmData, err = crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil
}
msg.IsFinish = true
if msg.Ext == nil {
msg.Ext = map[string]string{}
}
if usage != nil {
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
rtDependence.Usage = &agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
}
}
if _, ok := msg.Ext[string(msgEntity.MessageExtKeyTimeCost)]; !ok {
msg.Ext[string(msgEntity.MessageExtKeyTimeCost)] = fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
}
buildModelContent := &schema.Message{
Role: schema.Assistant,
Content: msg.Content,
}
mc, err := json.Marshal(buildModelContent)
if err != nil {
return err
}
preAnswerMsg.Content = msg.Content
preAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
preAnswerMsg.Ext = msg.Ext
preAnswerMsg.ContentType = msg.ContentType
preAnswerMsg.ModelContent = string(mc)
preAnswerMsg.CreatedAt = 0
preAnswerMsg.UpdatedAt = 0
_, err = crossmessage.DefaultSVC().Create(ctx, preAnswerMsg)
if err != nil {
return err
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil {
return err
}
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
if msg.Ext == nil {
msg.Ext = map[string]string{}
}
if usage != nil {
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
}
_, err := crossmessage.DefaultSVC().Edit(ctx, &msgEntity.Message{
ID: msg.ID,
Ext: msg.Ext,
})
if err != nil {
return err
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
return nil
}
func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt)
if err != nil {
return err
}
preMsg, err := preCreateAnswer(ctx, rtDependence)
if err != nil {
return err
}
deltaAnswer := &entity.ChunkMessageItem{
ID: preMsg.ID,
ConversationID: preMsg.ConversationID,
SectionID: preMsg.SectionID,
RunID: preMsg.RunID,
AgentID: preMsg.AgentID,
Role: entity.RoleType(preMsg.Role),
Content: interruptData,
MessageType: preMsg.MessageType,
ContentType: cType,
ReplyID: preMsg.RunID,
Ext: preMsg.Ext,
IsFinish: false,
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
if len(reasoningContent) > 0 && firstAnswerMsg == nil {
finalAnswer.ReasoningContent = ptr.Of(reasoningContent)
}
usage := func() *msgEntity.UsageExt {
if rtDependence.GetUsage() != nil {
return &msgEntity.UsageExt{
TotalCount: rtDependence.GetUsage().LlmTotalTokens,
InputTokens: rtDependence.GetUsage().LlmPromptTokens,
OutputTokens: rtDependence.GetUsage().LlmCompletionTokens,
}
}
return nil
}
err = mh.handlerAnswer(ctx, finalAnswer, usage(), rtDependence, preMsg)
if err != nil {
return err
}
err = mh.handlerInterruptVerbose(ctx, chunk, rtDependence)
if err != nil {
return err
}
return nil
}
func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent)
if err != nil {
return
}
preMsg, err := preCreateAnswer(ctx, rtDependence)
if err != nil {
return
}
deltaAnswer := &entity.ChunkMessageItem{
ID: preMsg.ID,
ConversationID: preMsg.ConversationID,
SectionID: preMsg.SectionID,
RunID: preMsg.RunID,
AgentID: preMsg.AgentID,
Role: entity.RoleType(preMsg.Role),
Content: interruptData,
MessageType: preMsg.MessageType,
ContentType: cType,
ReplyID: preMsg.RunID,
Ext: preMsg.Ext,
IsFinish: false,
}
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
err = mh.handlerAnswer(ctx, finalAnswer, nil, rtDependence, preMsg)
if err != nil {
return
}
err = mh.handlerInterruptVerbose(ctx, &entity.AgentRespEvent{
EventType: message.MessageTypeInterrupt,
Interrupt: &singleagent.InterruptInfo{
InterruptType: singleagent.InterruptEventType(stateMsg.InterruptEvent.EventType),
InterruptID: strconv.FormatInt(stateMsg.InterruptEvent.ID, 10),
ChatflowInterrupt: stateMsg,
},
}, rtDependence)
if err != nil {
return
}
}
func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
if err != nil {
return nil, err
}
ackErr := mh.handlerAckMessage(ctx, cm)
if ackErr != nil {
return msgMeta, ackErr
}
return cm, nil
}

View File

@@ -0,0 +1,214 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"context"
"time"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
type AgentRuntime struct {
RunRecord *entity.RunRecordMeta
AgentInfo *singleagent.SingleAgent
QuestionMsgID int64
RunMeta *entity.AgentRunMeta
StartTime time.Time
Input *msgEntity.Message
HistoryMsg []*msgEntity.Message
Usage *agentrun.Usage
SW *schema.StreamWriter[*entity.AgentRunResponse]
RunProcess *RunProcess
RunRecordRepo repository.RunRecordRepo
ImagexClient imagex.ImageX
MessageEvent *Event
}
func (rd *AgentRuntime) SetRunRecord(runRecord *entity.RunRecordMeta) {
rd.RunRecord = runRecord
}
func (rd *AgentRuntime) GetRunRecord() *entity.RunRecordMeta {
return rd.RunRecord
}
func (rd *AgentRuntime) SetUsage(usage *agentrun.Usage) {
rd.Usage = usage
}
func (rd *AgentRuntime) GetUsage() *agentrun.Usage {
return rd.Usage
}
func (rd *AgentRuntime) SetRunMeta(arm *entity.AgentRunMeta) {
rd.RunMeta = arm
}
func (rd *AgentRuntime) GetRunMeta() *entity.AgentRunMeta {
return rd.RunMeta
}
func (rd *AgentRuntime) SetAgentInfo(agentInfo *singleagent.SingleAgent) {
rd.AgentInfo = agentInfo
}
func (rd *AgentRuntime) GetAgentInfo() *singleagent.SingleAgent {
return rd.AgentInfo
}
func (rd *AgentRuntime) SetQuestionMsgID(msgID int64) {
rd.QuestionMsgID = msgID
}
func (rd *AgentRuntime) GetQuestionMsgID() int64 {
return rd.QuestionMsgID
}
func (rd *AgentRuntime) SetStartTime(t time.Time) {
rd.StartTime = t
}
func (rd *AgentRuntime) GetStartTime() time.Time {
return rd.StartTime
}
func (rd *AgentRuntime) SetInput(input *msgEntity.Message) {
rd.Input = input
}
func (rd *AgentRuntime) GetInput() *msgEntity.Message {
return rd.Input
}
func (rd *AgentRuntime) SetHistoryMsg(histroyMsg []*msgEntity.Message) {
rd.HistoryMsg = histroyMsg
}
func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
return rd.HistoryMsg
}
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft)
if err != nil {
return
}
art.SetAgentInfo(agentInfo)
history, err := art.getHistory(ctx)
if err != nil {
return
}
runRecord, err := art.createRunRecord(ctx)
if err != nil {
return
}
art.SetRunRecord(runRecord)
art.SetHistoryMsg(history)
defer func() {
srRecord := buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted)
if err != nil {
srRecord.Error = &entity.RunError{
Code: errno.ErrConversationAgentRunError,
Msg: err.Error(),
}
art.RunProcess.StepToFailed(ctx, srRecord, art.SW)
return
}
art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage())
}()
mh := &MesssageEventHanlder{
messageEvent: art.MessageEvent,
sw: art.SW,
}
input, err := mh.HandlerInput(ctx, art)
if err != nil {
return
}
art.SetInput(input)
art.SetQuestionMsgID(input.ID)
if art.GetAgentInfo().BotMode == bot_common.BotMode_WorkflowMode {
err = art.ChatflowRun(ctx, art.ImagexClient)
} else {
err = art.AgentStreamExecute(ctx, art.ImagexClient)
}
return
}
func (art *AgentRuntime) getHistory(ctx context.Context) ([]*msgEntity.Message, error) {
conversationTurns := getAgentHistoryRounds(art.GetAgentInfo())
runRecordList, err := art.RunRecordRepo.List(ctx, &entity.ListRunRecordMeta{
ConversationID: art.GetRunMeta().ConversationID,
SectionID: art.GetRunMeta().SectionID,
Limit: conversationTurns,
})
if err != nil {
return nil, err
}
if len(runRecordList) == 0 {
return nil, nil
}
runIDS := concactRunID(runRecordList)
history, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, art.GetRunMeta().ConversationID, runIDS)
if err != nil {
return nil, err
}
return history, nil
}
func concactRunID(rr []*entity.RunRecordMeta) []int64 {
ids := make([]int64, 0, len(rr))
for _, c := range rr {
ids = append(ids, c.ID)
}
return ids
}
func (art *AgentRuntime) createRunRecord(ctx context.Context) (*entity.RunRecordMeta, error) {
runPoData, err := art.RunRecordRepo.Create(ctx, art.GetRunMeta())
if err != nil {
logs.CtxErrorf(ctx, "RunRecordRepo.Create error: %v", err)
return nil, err
}
srRecord := buildSendRunRecord(ctx, runPoData, entity.RunStatusCreated)
art.RunProcess.StepToCreate(ctx, srRecord, art.SW)
err = art.RunProcess.StepToInProgress(ctx, srRecord, art.SW)
if err != nil {
logs.CtxErrorf(ctx, "runProcess.StepToInProgress error: %v", err)
return nil, err
}
return runPoData, nil
}

View File

@@ -30,8 +30,8 @@ import (
)
type RunProcess struct {
event *Event
event *Event
SW *schema.StreamWriter[*entity.AgentRunResponse]
RunRecordRepo repository.RunRecordRepo
}
@@ -115,7 +115,6 @@ func (r *RunProcess) StepToFailed(ctx context.Context, srRecord *entity.ChunkRun
Code: srRecord.Error.Code,
Msg: srRecord.Error.Msg,
})
return
}
func (r *RunProcess) StepToDone(sw *schema.StreamWriter[*entity.AgentRunResponse]) {

View File

@@ -0,0 +1,450 @@
/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"bytes"
"context"
"errors"
"io"
"sync"
"github.com/cloudwego/eino/schema"
"github.com/mohae/deepcopy"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.ImageX) (err error) {
mainChan := make(chan *entity.AgentRespEvent, 100)
ar := &crossagent.AgentRuntime{
AgentVersion: art.GetRunMeta().Version,
SpaceID: art.GetRunMeta().SpaceID,
AgentID: art.GetRunMeta().AgentID,
IsDraft: art.GetRunMeta().IsDraft,
UserID: art.GetRunMeta().UserID,
ConnectorID: art.GetRunMeta().ConnectorID,
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex),
ResumeInfo: parseResumeInfo(ctx, art.GetHistory()),
}
streamer, err := crossagent.DefaultSVC().StreamExecute(ctx, ar)
if err != nil {
return errors.New(errorx.ErrorWithoutStack(err))
}
var wg sync.WaitGroup
wg.Add(2)
safego.Go(ctx, func() {
defer wg.Done()
art.pull(ctx, mainChan, streamer)
})
safego.Go(ctx, func() {
defer wg.Done()
art.push(ctx, mainChan)
})
wg.Wait()
return err
}
func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) {
mh := &MesssageEventHanlder{
sw: art.SW,
messageEvent: art.MessageEvent,
}
var err error
defer func() {
if err != nil {
logs.CtxErrorf(ctx, "run.push error: %v", err)
mh.handlerErr(ctx, err)
}
}()
reasoningContent := bytes.NewBuffer([]byte{})
var firstAnswerMsg *msgEntity.Message
var reasoningMsg *msgEntity.Message
isSendFinishAnswer := false
var preToolResponseMsg *msgEntity.Message
toolResponseMsgContent := bytes.NewBuffer([]byte{})
for {
chunk, ok := <-mainChan
if !ok || chunk == nil {
return
}
if chunk.Err != nil {
if errors.Is(chunk.Err, io.EOF) {
if !isSendFinishAnswer {
isSendFinishAnswer = true
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
reasoningContent.Reset()
}
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
err = finishErr
return
}
}
return
}
mh.handlerErr(ctx, chunk.Err)
return
}
switch chunk.EventType {
case message.MessageTypeFunctionCall:
if chunk.FuncCall != nil && chunk.FuncCall.ResponseMeta != nil {
if usage := handlerUsage(chunk.FuncCall.ResponseMeta); usage != nil {
art.SetUsage(&agentrun.Usage{
LlmPromptTokens: usage.InputTokens,
LlmCompletionTokens: usage.OutputTokens,
LlmTotalTokens: usage.TotalCount,
})
}
}
err = mh.handlerFunctionCall(ctx, chunk, art)
if err != nil {
return
}
if preToolResponseMsg == nil {
var cErr error
preToolResponseMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
}
case message.MessageTypeToolResponse:
err = mh.handlerTooResponse(ctx, chunk, art, preToolResponseMsg, toolResponseMsgContent.String())
if err != nil {
return
}
preToolResponseMsg = nil // reset
case message.MessageTypeKnowledge:
err = mh.handlerKnowledge(ctx, chunk, art)
if err != nil {
return
}
case message.MessageTypeToolMidAnswer:
fullMidAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
toolMidAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
var preMsgIsFinish = false
for {
streamMsg, receErr := chunk.ToolMidAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
break
}
err = receErr
return
}
if preMsgIsFinish {
toolMidAnswerMsg, cErr = preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
preMsgIsFinish = false
}
if streamMsg == nil {
continue
}
if firstAnswerMsg == nil && len(streamMsg.Content) > 0 {
if reasoningMsg != nil {
toolMidAnswerMsg = deepcopy.Copy(reasoningMsg).(*msgEntity.Message)
}
firstAnswerMsg = deepcopy.Copy(toolMidAnswerMsg).(*msgEntity.Message)
}
if streamMsg.Extra != nil {
if val, ok := streamMsg.Extra["workflow_node_name"]; ok && val != nil {
toolMidAnswerMsg.Ext["message_title"] = val.(string)
}
}
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
sendMidAnswerMsg.Content = streamMsg.Content
toolResponseMsgContent.WriteString(streamMsg.Content)
fullMidAnswerContent.WriteString(streamMsg.Content)
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMidAnswerMsg, art.SW)
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg.Extra["is_finish"] == true {
preMsgIsFinish = true
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
sendMidAnswerMsg.Content = fullMidAnswerContent.String()
fullMidAnswerContent.Reset()
hfErr := mh.handlerAnswer(ctx, sendMidAnswerMsg, usage, art, toolMidAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
}
}
case message.MessageTypeToolAsAnswer:
var usage *msgEntity.UsageExt
fullContent := bytes.NewBuffer([]byte{})
toolAsAnswerMsg, cErr := preCreateAnswer(ctx, art)
if cErr != nil {
err = cErr
return
}
if firstAnswerMsg == nil {
firstAnswerMsg = toolAsAnswerMsg
}
for {
streamMsg, receErr := chunk.ToolAsAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
answer := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
answer.Content = fullContent.String()
hfErr := mh.handlerAnswer(ctx, answer, usage, art, toolAsAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
sendMsg := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
fullContent.WriteString(streamMsg.Content)
sendMsg.Content = streamMsg.Content
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, art.SW)
}
case message.MessageTypeAnswer:
fullContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
var isToolCalls = false
var modelAnswerMsg *msgEntity.Message
for {
streamMsg, receErr := chunk.ModelAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
if isToolCalls {
break
}
if modelAnswerMsg == nil {
break
}
answer := buildSendMsg(ctx, modelAnswerMsg, false, art)
answer.Content = fullContent.String()
hfErr := mh.handlerAnswer(ctx, answer, usage, art, modelAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && len(streamMsg.ToolCalls) > 0 {
isToolCalls = true
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
continue
}
if len(streamMsg.ReasoningContent) > 0 {
if reasoningMsg == nil {
reasoningMsg, err = preCreateAnswer(ctx, art)
if err != nil {
return
}
}
sendReasoningMsg := buildSendMsg(ctx, reasoningMsg, false, art)
reasoningContent.WriteString(streamMsg.ReasoningContent)
sendReasoningMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendReasoningMsg, art.SW)
}
if len(streamMsg.Content) > 0 {
if modelAnswerMsg == nil {
modelAnswerMsg, err = preCreateAnswer(ctx, art)
if err != nil {
return
}
if firstAnswerMsg == nil {
if reasoningMsg != nil {
modelAnswerMsg.ID = reasoningMsg.ID
}
firstAnswerMsg = modelAnswerMsg
}
}
sendAnswerMsg := buildSendMsg(ctx, modelAnswerMsg, false, art)
fullContent.WriteString(streamMsg.Content)
sendAnswerMsg.Content = streamMsg.Content
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, art.SW)
}
}
case message.MessageTypeFlowUp:
if isSendFinishAnswer {
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
}
isSendFinishAnswer = true
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
if finishErr != nil {
err = finishErr
return
}
}
err = mh.handlerSuggest(ctx, chunk, art)
if err != nil {
return
}
case message.MessageTypeInterrupt:
err = mh.handlerInterrupt(ctx, chunk, art, firstAnswerMsg, reasoningContent.String())
if err != nil {
return
}
}
}
}
func (art *AgentRuntime) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, events *schema.StreamReader[*crossagent.AgentEvent]) {
defer func() {
close(mainChan)
}()
for {
rm, re := events.Recv()
if re != nil {
errChunk := &entity.AgentRespEvent{
Err: re,
}
mainChan <- errChunk
return
}
eventType, tErr := transformEventMap(rm.EventType)
if tErr != nil {
errChunk := &entity.AgentRespEvent{
Err: tErr,
}
mainChan <- errChunk
return
}
respChunk := &entity.AgentRespEvent{
EventType: eventType,
ModelAnswer: rm.ChatModelAnswer,
ToolsMessage: rm.ToolsMessage,
FuncCall: rm.FuncCall,
Knowledge: rm.Knowledge,
Suggest: rm.Suggest,
Interrupt: rm.Interrupt,
ToolMidAnswer: rm.ToolMidAnswer,
ToolAsAnswer: rm.ToolAsChatModelAnswer,
}
mainChan <- respChunk
}
}
func transformEventMap(eventType singleagent.EventType) (message.MessageType, error) {
var eType message.MessageType
switch eventType {
case singleagent.EventTypeOfFuncCall:
return message.MessageTypeFunctionCall, nil
case singleagent.EventTypeOfKnowledge:
return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer:
return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeToolAsAnswer, nil
case singleagent.EventTypeOfToolMidAnswer:
return message.MessageTypeToolMidAnswer, nil
case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil
case singleagent.EventTypeOfInterrupt:
return message.MessageTypeInterrupt, nil
}
return eType, errorx.New(errno.ErrReplyUnknowEventType)
}
func (art *AgentRuntime) saveReasoningContent(ctx context.Context, firstAnswerMsg *msgEntity.Message, reasoningContent string) {
_, err := crossmessage.DefaultSVC().Edit(ctx, &message.Message{
ID: firstAnswerMsg.ID,
ReasoningContent: reasoningContent,
})
if err != nil {
logs.CtxInfof(ctx, "save reasoning content failed, err: %v", err)
}
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
)
@@ -34,8 +33,9 @@ func NewRunRecordRepo(db *gorm.DB, idGen idgen.IDGenerator) RunRecordRepo {
type RunRecordRepo interface {
Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
GetByID(ctx context.Context, id int64) (*entity.RunRecord, error)
GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error)
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
Delete(ctx context.Context, id []int64) error
UpdateByID(ctx context.Context, id int64, update *entity.UpdateMeta) error
List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error)
List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
}

View File

@@ -26,6 +26,9 @@ import (
type Run interface {
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
Delete(ctx context.Context, runID []int64) error
Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
GetByID(ctx context.Context, runID int64) (*entity.RunRecordMeta, error)
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
}

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,18 @@
package agentrun
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
)
func TestAgentRun(t *testing.T) {
@@ -97,3 +108,158 @@ func TestAgentRun(t *testing.T) {
// assert.NoError(t, err)
}
func TestRunImpl_List(t *testing.T) {
ctx := context.Background()
mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.RunRecord{}).AddRows(
&model.RunRecord{
ID: 1,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix(),
},
&model.RunRecord{
ID: 2,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 1,
}, &model.RunRecord{
ID: 3,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 2,
}, &model.RunRecord{
ID: 4,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 3,
}, &model.RunRecord{
ID: 5,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 4,
},
&model.RunRecord{
ID: 6,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 5,
}, &model.RunRecord{
ID: 7,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 6,
}, &model.RunRecord{
ID: 8,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 7,
}, &model.RunRecord{
ID: 9,
ConversationID: 123,
AgentID: 456,
SectionID: 789,
UserID: "123456",
CreatedAt: time.Now().Unix() + 8,
},
)
mockDB, err := mockDBGen.DB()
assert.NoError(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockIDGen := mock.NewMockIDGenerator(ctrl)
runRecordRepo := repository.NewRunRecordRepo(mockDB, mockIDGen)
service := &runImpl{
Components: Components{
RunRecordRepo: runRecordRepo,
},
}
t.Run("list success", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
AgentID: 456,
SectionID: 789,
Limit: 10,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 9)
assert.Equal(t, int64(123), result[0].ConversationID)
assert.Equal(t, int64(456), result[0].AgentID)
})
t.Run("empty list", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 999, //
Limit: 10,
OrderBy: "desc",
}
// check result
result, err := service.List(ctx, meta)
assert.NoError(t, err)
assert.Empty(t, result)
})
t.Run("search with before id", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
SectionID: 789,
AgentID: 456,
BeforeID: 5,
Limit: 3,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 3)
assert.Equal(t, int64(4), result[0].ID)
})
t.Run("search with after id and limit", func(t *testing.T) {
meta := &entity.ListRunRecordMeta{
ConversationID: 123,
SectionID: 789,
AgentID: 456,
AfterID: 5,
Limit: 3,
OrderBy: "desc",
}
result, err := service.List(ctx, meta)
// check result
assert.NoError(t, err)
assert.Len(t, result, 3)
assert.Equal(t, int64(9), result[0].ID)
})
}