feat: Support for Chat Flow & Agent Support for binding a single chat flow (#765)
Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com> Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com> Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com> Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com> Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com> Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com> Co-authored-by: July <jiangxujin@bytedance.com> Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
@@ -43,6 +43,7 @@ type RunRecordMeta struct {
|
||||
ChatRequest *string `json:"chat_message"`
|
||||
CompletedAt int64 `json:"completed_at"`
|
||||
FailedAt int64 `json:"failed_at"`
|
||||
CreatorID int64 `json:"creator_id"`
|
||||
}
|
||||
|
||||
type ChunkRunItem = RunRecordMeta
|
||||
@@ -158,3 +159,18 @@ type ModelAnswerEvent struct {
|
||||
Message *schema.Message
|
||||
Err error
|
||||
}
|
||||
|
||||
type ListRunRecordMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
SectionID int64 `json:"section_id"`
|
||||
Limit int32 `json:"limit"`
|
||||
OrderBy string `json:"order_by"` //desc asc
|
||||
BeforeID int64 `json:"before_id"`
|
||||
AfterID int64 `json:"after_id"`
|
||||
}
|
||||
|
||||
type CancelRunMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID int64 `json:"run_id"`
|
||||
}
|
||||
|
||||
45
backend/domain/conversation/agentrun/internal/agent_info.go
Normal file
45
backend/domain/conversation/agentrun/internal/agent_info.go
Normal file
@@ -0,0 +1,45 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
)
|
||||
|
||||
func getAgentHistoryRounds(agentInfo *singleagent.SingleAgent) int32 {
|
||||
var conversationTurns int32 = entity.ConversationTurnsDefault
|
||||
if agentInfo != nil && agentInfo.ModelInfo != nil && agentInfo.ModelInfo.ShortMemoryPolicy != nil && ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound) > 0 {
|
||||
conversationTurns = ptr.From(agentInfo.ModelInfo.ShortMemoryPolicy.HistoryRound)
|
||||
}
|
||||
return conversationTurns
|
||||
}
|
||||
|
||||
func getAgentInfo(ctx context.Context, agentID int64, isDraft bool) (*singleagent.SingleAgent, error) {
|
||||
agentInfo, err := crossagent.DefaultSVC().ObtainAgentByIdentity(ctx, &singleagent.AgentIdentity{
|
||||
AgentID: agentID,
|
||||
IsDraft: isDraft,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return agentInfo, nil
|
||||
}
|
||||
215
backend/domain/conversation/agentrun/internal/chatflow_run.go
Normal file
215
backend/domain/conversation/agentrun/internal/chatflow_run.go
Normal file
@@ -0,0 +1,215 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (art *AgentRuntime) ChatflowRun(ctx context.Context, imagex imagex.ImageX) (err error) {
|
||||
|
||||
mh := &MesssageEventHanlder{
|
||||
sw: art.SW,
|
||||
messageEvent: art.MessageEvent,
|
||||
}
|
||||
resumeInfo := parseResumeInfo(ctx, art.GetHistory())
|
||||
wfID, _ := strconv.ParseInt(art.GetAgentInfo().LayoutInfo.WorkflowId, 10, 64)
|
||||
|
||||
if wfID == 0 {
|
||||
mh.handlerErr(ctx, errorx.New(errno.ErrAgentRunWorkflowNotFound))
|
||||
return
|
||||
}
|
||||
var wfStreamer *schema.StreamReader[*crossworkflow.WorkflowMessage]
|
||||
|
||||
executeConfig := crossworkflow.ExecuteConfig{
|
||||
ID: wfID,
|
||||
ConnectorID: art.GetRunMeta().ConnectorID,
|
||||
ConnectorUID: art.GetRunMeta().UserID,
|
||||
AgentID: ptr.Of(art.GetRunMeta().AgentID),
|
||||
Mode: crossworkflow.ExecuteModeRelease,
|
||||
BizType: crossworkflow.BizTypeAgent,
|
||||
SyncPattern: crossworkflow.SyncPatternStream,
|
||||
From: crossworkflow.FromLatestVersion,
|
||||
}
|
||||
|
||||
if resumeInfo != nil {
|
||||
wfStreamer, err = crossworkflow.DefaultSVC().StreamResume(ctx, &crossworkflow.ResumeRequest{
|
||||
ResumeData: concatWfInput(art),
|
||||
EventID: resumeInfo.ChatflowInterrupt.InterruptEvent.ID,
|
||||
ExecuteID: resumeInfo.ChatflowInterrupt.ExecuteID,
|
||||
}, executeConfig)
|
||||
} else {
|
||||
executeConfig.ConversationID = &art.GetRunMeta().ConversationID
|
||||
executeConfig.SectionID = &art.GetRunMeta().SectionID
|
||||
executeConfig.InitRoundID = &art.RunRecord.ID
|
||||
executeConfig.RoundID = &art.RunRecord.ID
|
||||
executeConfig.UserMessage = transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0]
|
||||
executeConfig.MaxHistoryRounds = ptr.Of(getAgentHistoryRounds(art.GetAgentInfo()))
|
||||
wfStreamer, err = crossworkflow.DefaultSVC().StreamExecute(ctx, executeConfig, map[string]any{
|
||||
"USER_INPUT": concatWfInput(art),
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
safego.Go(ctx, func() {
|
||||
defer wg.Done()
|
||||
art.pullWfStream(ctx, wfStreamer, mh)
|
||||
})
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func concatWfInput(rtDependence *AgentRuntime) string {
|
||||
var input string
|
||||
for _, content := range rtDependence.RunMeta.Content {
|
||||
if content.Type == message.InputTypeText {
|
||||
input = content.Text + "," + input
|
||||
} else {
|
||||
for _, file := range content.FileData {
|
||||
input += file.Url + ","
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) pullWfStream(ctx context.Context, events *schema.StreamReader[*crossworkflow.WorkflowMessage], mh *MesssageEventHanlder) {
|
||||
|
||||
fullAnswerContent := bytes.NewBuffer([]byte{})
|
||||
var usage *msgEntity.UsageExt
|
||||
|
||||
preAnswerMsg, cErr := preCreateAnswer(ctx, art)
|
||||
|
||||
if cErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var preMsgIsFinish = false
|
||||
var lastAnswerMsg *entity.ChunkMessageItem
|
||||
|
||||
for {
|
||||
st, re := events.Recv()
|
||||
if re != nil {
|
||||
if errors.Is(re, io.EOF) {
|
||||
|
||||
if lastAnswerMsg != nil && usage != nil {
|
||||
art.SetUsage(&agentrun.Usage{
|
||||
LlmPromptTokens: usage.InputTokens,
|
||||
LlmCompletionTokens: usage.OutputTokens,
|
||||
LlmTotalTokens: usage.TotalCount,
|
||||
})
|
||||
_ = mh.handlerWfUsage(ctx, lastAnswerMsg, usage)
|
||||
}
|
||||
|
||||
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
|
||||
if finishErr != nil {
|
||||
logs.CtxErrorf(ctx, "handlerFinalAnswerFinish error: %v", finishErr)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
logs.CtxErrorf(ctx, "pullWfStream Recv error: %v", re)
|
||||
mh.handlerErr(ctx, re)
|
||||
return
|
||||
}
|
||||
if st == nil {
|
||||
continue
|
||||
}
|
||||
if st.StateMessage != nil {
|
||||
if st.StateMessage.Status == crossworkflow.WorkflowFailed {
|
||||
mh.handlerErr(ctx, st.StateMessage.LastError)
|
||||
continue
|
||||
}
|
||||
if st.StateMessage.Usage != nil {
|
||||
usage = &msgEntity.UsageExt{
|
||||
InputTokens: st.StateMessage.Usage.InputTokens,
|
||||
OutputTokens: st.StateMessage.Usage.OutputTokens,
|
||||
TotalCount: st.StateMessage.Usage.InputTokens + st.StateMessage.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
if st.StateMessage.InterruptEvent != nil { // interrupt
|
||||
mh.handlerWfInterruptMsg(ctx, st.StateMessage, art)
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if st.DataMessage == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch st.DataMessage.Type {
|
||||
case crossworkflow.Answer:
|
||||
|
||||
// input node & question node skip
|
||||
if st.DataMessage != nil && (st.DataMessage.NodeType == crossworkflow.NodeTypeInputReceiver || st.DataMessage.NodeType == crossworkflow.NodeTypeQuestion) {
|
||||
break
|
||||
}
|
||||
|
||||
if preMsgIsFinish {
|
||||
preAnswerMsg, cErr = preCreateAnswer(ctx, art)
|
||||
if cErr != nil {
|
||||
return
|
||||
}
|
||||
preMsgIsFinish = false
|
||||
}
|
||||
if st.DataMessage.Content != "" {
|
||||
fullAnswerContent.WriteString(st.DataMessage.Content)
|
||||
}
|
||||
|
||||
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
|
||||
sendAnswerMsg.Content = st.DataMessage.Content
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, mh.sw)
|
||||
|
||||
if st.DataMessage.Last {
|
||||
preMsgIsFinish = true
|
||||
sendAnswerMsg := buildSendMsg(ctx, preAnswerMsg, false, art)
|
||||
sendAnswerMsg.Content = fullAnswerContent.String()
|
||||
fullAnswerContent.Reset()
|
||||
hfErr := mh.handlerAnswer(ctx, sendAnswerMsg, usage, art, preAnswerMsg)
|
||||
if hfErr != nil {
|
||||
return
|
||||
}
|
||||
lastAnswerMsg = sendAnswerMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,8 @@ package dal
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -27,6 +29,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
@@ -59,8 +62,12 @@ func (dao *RunRecordDAO) Create(ctx context.Context, runMeta *entity.AgentRunMet
|
||||
return dao.buildPo2Do(createPO), nil
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*model.RunRecord, error) {
|
||||
return dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
|
||||
func (dao *RunRecordDAO) GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error) {
|
||||
po, err := dao.query.RunRecord.WithContext(ctx).Where(dao.query.RunRecord.ID.Eq(id)).First()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.buildPo2Do(po), nil
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) UpdateByID(ctx context.Context, id int64, updateMeta *entity.UpdateMeta) error {
|
||||
@@ -106,20 +113,40 @@ func (dao *RunRecordDAO) Delete(ctx context.Context, id []int64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error) {
|
||||
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", conversationID, sectionID, limit)
|
||||
func (dao *RunRecordDAO) List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error) {
|
||||
logs.CtxInfof(ctx, "list run record req:%v, sectionID:%v, limit:%v", meta.ConversationID, meta.SectionID, meta.Limit)
|
||||
m := dao.query.RunRecord
|
||||
do := m.WithContext(ctx).Where(m.ConversationID.Eq(conversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
|
||||
|
||||
if sectionID > 0 {
|
||||
do = do.Where(m.SectionID.Eq(sectionID))
|
||||
do := m.WithContext(ctx).Where(m.ConversationID.Eq(meta.ConversationID)).Debug().Where(m.Status.NotIn(string(entity.RunStatusDeleted)))
|
||||
if meta.BeforeID > 0 {
|
||||
runRecord, err := m.Where(m.ID.Eq(meta.BeforeID)).First()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
do = do.Where(m.CreatedAt.Lt(runRecord.CreatedAt))
|
||||
}
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit))
|
||||
if meta.AfterID > 0 {
|
||||
runRecord, err := m.Where(m.ID.Eq(meta.AfterID)).First()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
do = do.Where(m.CreatedAt.Gt(runRecord.CreatedAt))
|
||||
}
|
||||
if meta.SectionID > 0 {
|
||||
do = do.Where(m.SectionID.Eq(meta.SectionID))
|
||||
}
|
||||
if meta.Limit > 0 {
|
||||
do = do.Limit(int(meta.Limit))
|
||||
}
|
||||
if strings.ToLower(meta.OrderBy) == "asc" {
|
||||
do = do.Order(m.CreatedAt.Asc())
|
||||
} else {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
}
|
||||
|
||||
runRecords, err := do.Order(m.CreatedAt.Desc()).Find()
|
||||
return runRecords, err
|
||||
runRecords, err := do.Find()
|
||||
return slices.Transform(runRecords, func(item *model.RunRecord) *entity.RunRecordMeta {
|
||||
return dao.buildPo2Do(item)
|
||||
}), err
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.AgentRunMeta) (*model.RunRecord, error) {
|
||||
@@ -135,7 +162,10 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
|
||||
}
|
||||
|
||||
timeNow := time.Now().UnixMilli()
|
||||
|
||||
creatorID, err := strconv.ParseInt(runMeta.UserID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model.RunRecord{
|
||||
ID: runID,
|
||||
ConversationID: runMeta.ConversationID,
|
||||
@@ -145,6 +175,7 @@ func (dao *RunRecordDAO) buildCreatePO(ctx context.Context, runMeta *entity.Agen
|
||||
ChatRequest: string(reqOrigin),
|
||||
UserID: runMeta.UserID,
|
||||
CreatedAt: timeNow,
|
||||
CreatorID: creatorID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -161,7 +192,21 @@ func (dao *RunRecordDAO) buildPo2Do(po *model.RunRecord) *entity.RunRecordMeta {
|
||||
CompletedAt: po.CompletedAt,
|
||||
FailedAt: po.FailedAt,
|
||||
Usage: po.Usage,
|
||||
CreatorID: po.CreatorID,
|
||||
}
|
||||
|
||||
return runMeta
|
||||
}
|
||||
|
||||
func (dao *RunRecordDAO) Cancel(ctx context.Context, meta *entity.CancelRunMeta) (*entity.RunRecordMeta, error) {
|
||||
|
||||
m := dao.query.RunRecord
|
||||
_, err := m.WithContext(ctx).Where(m.ID.Eq(meta.RunID)).UpdateColumns(map[string]interface{}{
|
||||
"updated_at": time.Now().UnixMilli(),
|
||||
"status": entity.RunEventCancelled,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.GetByID(ctx, meta.RunID)
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
}
|
||||
|
||||
func NewEvent() *Event {
|
||||
return &Event{}
|
||||
}
|
||||
|
||||
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkMessageItem: chunkMsgItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkRunItem: chunkRunItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
|
||||
|
||||
return &entity.AgentRunResponse{
|
||||
Event: entity.RunEventStreamDone,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildRunEvent(runEvent, runItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildMessageEvent(runEvent, messageItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
|
||||
resp := e.buildErrEvent(runEvent, err)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildStreamDoneEvent()
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
512
backend/domain/conversation/agentrun/internal/message_builder.go
Normal file
512
backend/domain/conversation/agentrun/internal/message_builder.go
Normal file
@@ -0,0 +1,512 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
messageModel "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func buildSendMsg(_ context.Context, msg *msgEntity.Message, isFinish bool, rtDependence *AgentRuntime) *entity.ChunkMessageItem {
|
||||
|
||||
copyMap := make(map[string]string)
|
||||
for k, v := range msg.Ext {
|
||||
copyMap[k] = v
|
||||
}
|
||||
|
||||
return &entity.ChunkMessageItem{
|
||||
ID: msg.ID,
|
||||
ConversationID: msg.ConversationID,
|
||||
SectionID: msg.SectionID,
|
||||
AgentID: msg.AgentID,
|
||||
Content: msg.Content,
|
||||
Role: entity.RoleTypeAssistant,
|
||||
ContentType: msg.ContentType,
|
||||
MessageType: msg.MessageType,
|
||||
ReplyID: rtDependence.GetQuestionMsgID(),
|
||||
Type: msg.MessageType,
|
||||
CreatedAt: msg.CreatedAt,
|
||||
UpdatedAt: msg.UpdatedAt,
|
||||
RunID: rtDependence.GetRunRecord().ID,
|
||||
Ext: copyMap,
|
||||
IsFinish: isFinish,
|
||||
ReasoningContent: ptr.Of(msg.ReasoningContent),
|
||||
}
|
||||
}
|
||||
|
||||
func buildKnowledge(_ context.Context, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
|
||||
var recallDatas []msgEntity.RecallDataInfo
|
||||
for _, kOne := range chunk.Knowledge {
|
||||
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{
|
||||
Slice: kOne.Content,
|
||||
Meta: msgEntity.MetaInfo{
|
||||
Dataset: msgEntity.DatasetInfo{
|
||||
ID: kOne.MetaData["dataset_id"].(string),
|
||||
Name: kOne.MetaData["dataset_name"].(string),
|
||||
},
|
||||
Document: msgEntity.DocumentInfo{
|
||||
ID: kOne.MetaData["document_id"].(string),
|
||||
Name: kOne.MetaData["document_name"].(string),
|
||||
},
|
||||
},
|
||||
Score: kOne.Score(),
|
||||
})
|
||||
}
|
||||
|
||||
verboseData := &msgEntity.VerboseData{
|
||||
Chunks: recallDatas,
|
||||
OriReq: "",
|
||||
StatusCode: 0,
|
||||
}
|
||||
data, err := json.Marshal(verboseData)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
knowledgeInfo := &msgEntity.VerboseInfo{
|
||||
MessageType: string(entity.MessageSubTypeKnowledgeCall),
|
||||
Data: string(data),
|
||||
}
|
||||
return knowledgeInfo
|
||||
}
|
||||
|
||||
func buildBotStateExt(arm *entity.AgentRunMeta) *msgEntity.BotStateExt {
|
||||
agentID := strconv.FormatInt(arm.AgentID, 10)
|
||||
botStateExt := &msgEntity.BotStateExt{
|
||||
AgentID: agentID,
|
||||
AgentName: arm.Name,
|
||||
Awaiting: agentID,
|
||||
BotID: agentID,
|
||||
}
|
||||
|
||||
return botStateExt
|
||||
}
|
||||
|
||||
type irMsg struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ContentType string `json:"content_type"`
|
||||
Content any `json:"content"` // either optionContent or string
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
func parseInterruptData(_ context.Context, interruptData *singleagent.InterruptInfo) (string, message.ContentType, error) {
|
||||
|
||||
defaultContentType := message.ContentTypeText
|
||||
switch interruptData.InterruptType {
|
||||
case singleagent.InterruptEventType_OauthPlugin:
|
||||
data := interruptData.AllToolInterruptData[interruptData.ToolCallID].ToolNeedOAuth.Message
|
||||
return data, defaultContentType, nil
|
||||
case singleagent.InterruptEventType_Question:
|
||||
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
|
||||
return processQuestionInterruptData(data)
|
||||
case singleagent.InterruptEventType_InputNode:
|
||||
data := interruptData.AllWfInterruptData[interruptData.ToolCallID].InterruptData
|
||||
return processInputNodeInterruptData(data)
|
||||
case singleagent.InterruptEventType_WorkflowLLM:
|
||||
toolInterruptEvent := interruptData.AllWfInterruptData[interruptData.ToolCallID].ToolInterruptEvent
|
||||
data := toolInterruptEvent.InterruptData
|
||||
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_InputNode {
|
||||
return processInputNodeInterruptData(data)
|
||||
}
|
||||
if singleagent.InterruptEventType(toolInterruptEvent.EventType) == singleagent.InterruptEventType_Question {
|
||||
return processQuestionInterruptData(data)
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
}
|
||||
|
||||
func processQuestionInterruptData(data string) (string, message.ContentType, error) {
|
||||
defaultContentType := message.ContentTypeText
|
||||
var iData map[string][]*irMsg
|
||||
err := json.Unmarshal([]byte(data), &iData)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
if len(iData["messages"]) == 0 {
|
||||
return "", defaultContentType, errorx.New(errno.ErrInterruptDataEmpty)
|
||||
}
|
||||
interruptMsg := iData["messages"][0]
|
||||
|
||||
if interruptMsg.ContentType == "text" {
|
||||
return interruptMsg.Content.(string), defaultContentType, nil
|
||||
} else if interruptMsg.ContentType == "option" || interruptMsg.ContentType == "form_schema" {
|
||||
iMarshalData, err := json.Marshal(interruptMsg)
|
||||
if err != nil {
|
||||
return "", defaultContentType, err
|
||||
}
|
||||
return string(iMarshalData), message.ContentTypeCard, nil
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
}
|
||||
|
||||
func processInputNodeInterruptData(data string) (string, message.ContentType, error) {
|
||||
return data, message.ContentTypeCard, nil
|
||||
}
|
||||
|
||||
func handlerUsage(meta *schema.ResponseMeta) *msgEntity.UsageExt {
|
||||
if meta == nil || meta.Usage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &msgEntity.UsageExt{
|
||||
TotalCount: int64(meta.Usage.TotalTokens),
|
||||
InputTokens: int64(meta.Usage.PromptTokens),
|
||||
OutputTokens: int64(meta.Usage.CompletionTokens),
|
||||
}
|
||||
}
|
||||
|
||||
func preCreateAnswer(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
|
||||
arm := rtDependence.RunMeta
|
||||
msgMeta := &msgEntity.Message{
|
||||
ConversationID: arm.ConversationID,
|
||||
RunID: rtDependence.RunRecord.ID,
|
||||
AgentID: arm.AgentID,
|
||||
SectionID: arm.SectionID,
|
||||
UserID: arm.UserID,
|
||||
Role: schema.Assistant,
|
||||
MessageType: message.MessageTypeAnswer,
|
||||
ContentType: message.ContentTypeText,
|
||||
Ext: arm.Ext,
|
||||
}
|
||||
|
||||
if arm.Ext == nil {
|
||||
msgMeta.Ext = map[string]string{}
|
||||
}
|
||||
|
||||
botStateExt := buildBotStateExt(arm)
|
||||
bseString, err := json.Marshal(botStateExt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, ok := msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)]; !ok {
|
||||
msgMeta.Ext[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
|
||||
}
|
||||
|
||||
msgMeta.Ext = arm.Ext
|
||||
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
|
||||
}
|
||||
|
||||
func buildAgentMessage2Create(ctx context.Context, chunk *entity.AgentRespEvent, messageType message.MessageType, rtDependence *AgentRuntime) *message.Message {
|
||||
arm := rtDependence.GetRunMeta()
|
||||
msg := &msgEntity.Message{
|
||||
ConversationID: arm.ConversationID,
|
||||
RunID: rtDependence.RunRecord.ID,
|
||||
AgentID: arm.AgentID,
|
||||
SectionID: arm.SectionID,
|
||||
UserID: arm.UserID,
|
||||
MessageType: messageType,
|
||||
}
|
||||
buildExt := map[string]string{}
|
||||
|
||||
timeCost := fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
|
||||
|
||||
switch messageType {
|
||||
case message.MessageTypeQuestion:
|
||||
msg.Role = schema.User
|
||||
msg.ContentType = arm.ContentType
|
||||
for _, content := range arm.Content {
|
||||
if content.Type == message.InputTypeText {
|
||||
msg.Content = content.Text
|
||||
break
|
||||
}
|
||||
}
|
||||
msg.MultiContent = arm.Content
|
||||
buildExt = arm.Ext
|
||||
|
||||
msg.DisplayContent = arm.DisplayContent
|
||||
case message.MessageTypeAnswer, message.MessageTypeToolAsAnswer:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
case message.MessageTypeToolResponse:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
msg.Content = chunk.ToolsMessage[0].Content
|
||||
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
modelContent := chunk.ToolsMessage[0]
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
|
||||
case message.MessageTypeKnowledge:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
knowledgeContent := buildKnowledge(ctx, chunk)
|
||||
if knowledgeContent != nil {
|
||||
knInfo, err := json.Marshal(knowledgeContent)
|
||||
if err == nil {
|
||||
msg.Content = string(knInfo)
|
||||
}
|
||||
}
|
||||
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
|
||||
modelContent := chunk.Knowledge
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
|
||||
case message.MessageTypeFunctionCall:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
if len(chunk.FuncCall.ToolCalls) > 0 {
|
||||
toolCall := chunk.FuncCall.ToolCalls[0]
|
||||
toolCalling, err := json.Marshal(toolCall)
|
||||
if err == nil {
|
||||
msg.Content = string(toolCalling)
|
||||
}
|
||||
buildExt[string(msgEntity.MessageExtKeyPlugin)] = toolCall.Function.Name
|
||||
buildExt[string(msgEntity.MessageExtKeyToolName)] = toolCall.Function.Name
|
||||
buildExt[string(msgEntity.MessageExtKeyTimeCost)] = timeCost
|
||||
|
||||
modelContent := chunk.FuncCall
|
||||
mc, err := json.Marshal(modelContent)
|
||||
if err == nil {
|
||||
msg.ModelContent = string(mc)
|
||||
}
|
||||
}
|
||||
case message.MessageTypeFlowUp:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
msg.Content = chunk.Suggest.Content
|
||||
|
||||
case message.MessageTypeVerbose:
|
||||
msg.Role = schema.Assistant
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
d := &entity.Data{
|
||||
FinishReason: 0,
|
||||
FinData: "",
|
||||
}
|
||||
dByte, _ := json.Marshal(d)
|
||||
afc := &entity.AnswerFinshContent{
|
||||
MsgType: entity.MessageSubTypeGenerateFinish,
|
||||
Data: string(dByte),
|
||||
}
|
||||
afcMarshal, _ := json.Marshal(afc)
|
||||
msg.Content = string(afcMarshal)
|
||||
case message.MessageTypeInterrupt:
|
||||
msg.Role = schema.Assistant
|
||||
msg.MessageType = message.MessageTypeVerbose
|
||||
msg.ContentType = message.ContentTypeText
|
||||
|
||||
afc := &entity.AnswerFinshContent{
|
||||
MsgType: entity.MessageSubTypeInterrupt,
|
||||
Data: "",
|
||||
}
|
||||
afcMarshal, _ := json.Marshal(afc)
|
||||
msg.Content = string(afcMarshal)
|
||||
|
||||
// Add ext to save to context_message
|
||||
interruptByte, err := json.Marshal(chunk.Interrupt)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.ExtKeyResumeInfo)] = string(interruptByte)
|
||||
}
|
||||
buildExt[string(msgEntity.ExtKeyToolCallsIDs)] = chunk.Interrupt.ToolCallID
|
||||
rc := &messageModel.RequiredAction{
|
||||
Type: "submit_tool_outputs",
|
||||
SubmitToolOutputs: &messageModel.SubmitToolOutputs{},
|
||||
}
|
||||
msg.RequiredAction = rc
|
||||
rcExtByte, err := json.Marshal(rc)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.ExtKeyRequiresAction)] = string(rcExtByte)
|
||||
}
|
||||
}
|
||||
|
||||
if messageType != message.MessageTypeQuestion {
|
||||
botStateExt := buildBotStateExt(arm)
|
||||
bseString, err := json.Marshal(botStateExt)
|
||||
if err == nil {
|
||||
buildExt[string(msgEntity.MessageExtKeyBotState)] = string(bseString)
|
||||
}
|
||||
}
|
||||
msg.Ext = buildExt
|
||||
return msg
|
||||
}
|
||||
|
||||
func handlerWfInterruptEvent(_ context.Context, interruptEventData *crossworkflow.InterruptEvent) (string, message.ContentType, error) {
|
||||
defaultContentType := message.ContentTypeText
|
||||
switch singleagent.InterruptEventType(interruptEventData.EventType) {
|
||||
case singleagent.InterruptEventType_OauthPlugin:
|
||||
|
||||
case singleagent.InterruptEventType_Question:
|
||||
data := interruptEventData.InterruptData
|
||||
return processQuestionInterruptData(data)
|
||||
case singleagent.InterruptEventType_InputNode:
|
||||
data := interruptEventData.InterruptData
|
||||
return processInputNodeInterruptData(data)
|
||||
case singleagent.InterruptEventType_WorkflowLLM:
|
||||
data := interruptEventData.ToolInterruptEvent.InterruptData
|
||||
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_InputNode {
|
||||
return processInputNodeInterruptData(data)
|
||||
}
|
||||
if singleagent.InterruptEventType(interruptEventData.EventType) == singleagent.InterruptEventType_Question {
|
||||
return processQuestionInterruptData(data)
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
}
|
||||
return "", defaultContentType, errorx.New(errno.ErrUnknowInterruptType)
|
||||
}
|
||||
|
||||
func historyPairs(historyMsg []*message.Message) []*message.Message {
|
||||
|
||||
fcMsgPairs := make(map[int64][]*message.Message)
|
||||
for _, one := range historyMsg {
|
||||
if one.MessageType != message.MessageTypeFunctionCall && one.MessageType != message.MessageTypeToolResponse {
|
||||
continue
|
||||
}
|
||||
if _, ok := fcMsgPairs[one.RunID]; !ok {
|
||||
fcMsgPairs[one.RunID] = []*message.Message{one}
|
||||
} else {
|
||||
fcMsgPairs[one.RunID] = append(fcMsgPairs[one.RunID], one)
|
||||
}
|
||||
}
|
||||
|
||||
var historyAfterPairs []*message.Message
|
||||
for _, value := range historyMsg {
|
||||
if value.MessageType == message.MessageTypeFunctionCall {
|
||||
if len(fcMsgPairs[value.RunID])%2 == 0 {
|
||||
historyAfterPairs = append(historyAfterPairs, value)
|
||||
}
|
||||
} else {
|
||||
historyAfterPairs = append(historyAfterPairs, value)
|
||||
}
|
||||
}
|
||||
return historyAfterPairs
|
||||
|
||||
}
|
||||
|
||||
func transMessageToSchemaMessage(ctx context.Context, msgs []*message.Message, imagexClient imagex.ImageX) []*schema.Message {
|
||||
schemaMessage := make([]*schema.Message, 0, len(msgs))
|
||||
|
||||
for _, msgOne := range msgs {
|
||||
if msgOne.ModelContent == "" {
|
||||
continue
|
||||
}
|
||||
if msgOne.MessageType == message.MessageTypeVerbose || msgOne.MessageType == message.MessageTypeFlowUp {
|
||||
continue
|
||||
}
|
||||
var sm *schema.Message
|
||||
err := json.Unmarshal([]byte(msgOne.ModelContent), &sm)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(sm.ReasoningContent) > 0 {
|
||||
sm.ReasoningContent = ""
|
||||
}
|
||||
schemaMessage = append(schemaMessage, parseMessageURI(ctx, sm, imagexClient))
|
||||
}
|
||||
|
||||
return schemaMessage
|
||||
}
|
||||
|
||||
func parseMessageURI(ctx context.Context, mcMsg *schema.Message, imagexClient imagex.ImageX) *schema.Message {
|
||||
if mcMsg.MultiContent == nil {
|
||||
return mcMsg
|
||||
}
|
||||
for k, one := range mcMsg.MultiContent {
|
||||
switch one.Type {
|
||||
case schema.ChatMessagePartTypeImageURL:
|
||||
|
||||
if one.ImageURL.URI != "" {
|
||||
url, err := imagexClient.GetResourceURL(ctx, one.ImageURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].ImageURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeFileURL:
|
||||
if one.FileURL.URI != "" {
|
||||
url, err := imagexClient.GetResourceURL(ctx, one.FileURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].FileURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeAudioURL:
|
||||
if one.AudioURL.URI != "" {
|
||||
url, err := imagexClient.GetResourceURL(ctx, one.AudioURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].AudioURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
case schema.ChatMessagePartTypeVideoURL:
|
||||
if one.VideoURL.URI != "" {
|
||||
url, err := imagexClient.GetResourceURL(ctx, one.VideoURL.URI)
|
||||
if err == nil {
|
||||
mcMsg.MultiContent[k].VideoURL.URL = url.URL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return mcMsg
|
||||
}
|
||||
|
||||
func parseResumeInfo(_ context.Context, historyMsg []*message.Message) *crossagent.ResumeInfo {
|
||||
|
||||
var resumeInfo *crossagent.ResumeInfo
|
||||
for i := len(historyMsg) - 1; i >= 0; i-- {
|
||||
if historyMsg[i].MessageType == message.MessageTypeQuestion {
|
||||
break
|
||||
}
|
||||
if historyMsg[i].MessageType == message.MessageTypeVerbose {
|
||||
if historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)] != "" {
|
||||
err := json.Unmarshal([]byte(historyMsg[i].Ext[string(msgEntity.ExtKeyResumeInfo)]), &resumeInfo)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return resumeInfo
|
||||
}
|
||||
|
||||
func buildSendRunRecord(_ context.Context, runRecord *entity.RunRecordMeta, runStatus entity.RunStatus) *entity.ChunkRunItem {
|
||||
return &entity.ChunkRunItem{
|
||||
ID: runRecord.ID,
|
||||
ConversationID: runRecord.ConversationID,
|
||||
AgentID: runRecord.AgentID,
|
||||
SectionID: runRecord.SectionID,
|
||||
Status: runStatus,
|
||||
CreatedAt: runRecord.CreatedAt,
|
||||
}
|
||||
}
|
||||
428
backend/domain/conversation/agentrun/internal/message_event.go
Normal file
428
backend/domain/conversation/agentrun/internal/message_event.go
Normal file
@@ -0,0 +1,428 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mohae/deepcopy"
|
||||
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossworkflow "github.com/coze-dev/coze-studio/backend/crossdomain/contract/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/types/consts"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
}
|
||||
|
||||
func NewMessageEvent() *Event {
|
||||
return &Event{}
|
||||
}
|
||||
|
||||
func (e *Event) buildMessageEvent(runEvent entity.RunEvent, chunkMsgItem *entity.ChunkMessageItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkMessageItem: chunkMsgItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildRunEvent(runEvent entity.RunEvent, chunkRunItem *entity.ChunkRunItem) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
ChunkRunItem: chunkRunItem,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildErrEvent(runEvent entity.RunEvent, err *entity.RunError) *entity.AgentRunResponse {
|
||||
return &entity.AgentRunResponse{
|
||||
Event: runEvent,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) buildStreamDoneEvent() *entity.AgentRunResponse {
|
||||
|
||||
return &entity.AgentRunResponse{
|
||||
Event: entity.RunEventStreamDone,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) SendRunEvent(runEvent entity.RunEvent, runItem *entity.ChunkRunItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildRunEvent(runEvent, runItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendMsgEvent(runEvent entity.RunEvent, messageItem *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildMessageEvent(runEvent, messageItem)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendErrEvent(runEvent entity.RunEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], err *entity.RunError) {
|
||||
resp := e.buildErrEvent(runEvent, err)
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
func (e *Event) SendStreamDoneEvent(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
resp := e.buildStreamDoneEvent()
|
||||
sw.Send(resp, nil)
|
||||
}
|
||||
|
||||
type MesssageEventHanlder struct {
|
||||
messageEvent *Event
|
||||
sw *schema.StreamWriter[*entity.AgentRunResponse]
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerErr(_ context.Context, err error) {
|
||||
|
||||
var errMsg string
|
||||
var statusErr errorx.StatusError
|
||||
if errors.As(err, &statusErr) {
|
||||
errMsg = statusErr.Msg()
|
||||
} else {
|
||||
if strings.ToLower(os.Getenv(consts.RunMode)) != "debug" {
|
||||
errMsg = "Internal Server Error"
|
||||
} else {
|
||||
errMsg = errorx.ErrorWithoutStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
mh.messageEvent.SendErrEvent(entity.RunEventError, mh.sw, &entity.RunError{
|
||||
Code: errno.ErrAgentRun,
|
||||
Msg: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerAckMessage(_ context.Context, input *msgEntity.Message) error {
|
||||
sendMsg := &entity.ChunkMessageItem{
|
||||
ID: input.ID,
|
||||
ConversationID: input.ConversationID,
|
||||
SectionID: input.SectionID,
|
||||
AgentID: input.AgentID,
|
||||
Role: entity.RoleType(input.Role),
|
||||
MessageType: message.MessageTypeAck,
|
||||
ReplyID: input.ID,
|
||||
Content: input.Content,
|
||||
ContentType: message.ContentTypeText,
|
||||
IsFinish: true,
|
||||
}
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventAck, sendMsg, mh.sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerFunctionCall(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFunctionCall, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
|
||||
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
|
||||
|
||||
var cmData *message.Message
|
||||
var err error
|
||||
|
||||
if preToolResponseMsg != nil {
|
||||
cm.ID = preToolResponseMsg.ID
|
||||
cm.CreatedAt = preToolResponseMsg.CreatedAt
|
||||
cm.UpdatedAt = preToolResponseMsg.UpdatedAt
|
||||
if len(toolResponseMsgContent) > 0 {
|
||||
cm.Content = toolResponseMsgContent + "\n" + cm.Content
|
||||
}
|
||||
}
|
||||
|
||||
cmData, err = crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerSuggest(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeFlowUp, rtDependence)
|
||||
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeKnowledge, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt, rtDependence *AgentRuntime, preAnswerMsg *msgEntity.Message) error {
|
||||
|
||||
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg.IsFinish = true
|
||||
|
||||
if msg.Ext == nil {
|
||||
msg.Ext = map[string]string{}
|
||||
}
|
||||
if usage != nil {
|
||||
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
|
||||
|
||||
rtDependence.Usage = &agentrun.Usage{
|
||||
LlmPromptTokens: usage.InputTokens,
|
||||
LlmCompletionTokens: usage.OutputTokens,
|
||||
LlmTotalTokens: usage.TotalCount,
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := msg.Ext[string(msgEntity.MessageExtKeyTimeCost)]; !ok {
|
||||
msg.Ext[string(msgEntity.MessageExtKeyTimeCost)] = fmt.Sprintf("%.1f", float64(time.Since(rtDependence.GetStartTime()).Milliseconds())/1000.00)
|
||||
}
|
||||
|
||||
buildModelContent := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
Content: msg.Content,
|
||||
}
|
||||
|
||||
mc, err := json.Marshal(buildModelContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
preAnswerMsg.Content = msg.Content
|
||||
preAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
|
||||
preAnswerMsg.Ext = msg.Ext
|
||||
preAnswerMsg.ContentType = msg.ContentType
|
||||
preAnswerMsg.ModelContent = string(mc)
|
||||
preAnswerMsg.CreatedAt = 0
|
||||
preAnswerMsg.UpdatedAt = 0
|
||||
|
||||
_, err = crossmessage.DefaultSVC().Create(ctx, preAnswerMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerFinalAnswerFinish(ctx context.Context, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, nil, message.MessageTypeVerbose, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerInterruptVerbose(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime) error {
|
||||
cm := buildAgentMessage2Create(ctx, chunk, message.MessageTypeInterrupt, rtDependence)
|
||||
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sendMsg := buildSendMsg(ctx, cmData, true, rtDependence)
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, sendMsg, mh.sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerWfUsage(ctx context.Context, msg *entity.ChunkMessageItem, usage *msgEntity.UsageExt) error {
|
||||
|
||||
if msg.Ext == nil {
|
||||
msg.Ext = map[string]string{}
|
||||
}
|
||||
if usage != nil {
|
||||
msg.Ext[string(msgEntity.MessageExtKeyToken)] = strconv.FormatInt(usage.TotalCount, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyInputTokens)] = strconv.FormatInt(usage.InputTokens, 10)
|
||||
msg.Ext[string(msgEntity.MessageExtKeyOutputTokens)] = strconv.FormatInt(usage.OutputTokens, 10)
|
||||
}
|
||||
|
||||
_, err := crossmessage.DefaultSVC().Edit(ctx, &msgEntity.Message{
|
||||
ID: msg.ID,
|
||||
Ext: msg.Ext,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageCompleted, msg, mh.sw)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, rtDependence *AgentRuntime, firstAnswerMsg *msgEntity.Message, reasoningContent string) error {
|
||||
interruptData, cType, err := parseInterruptData(ctx, chunk.Interrupt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
preMsg, err := preCreateAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deltaAnswer := &entity.ChunkMessageItem{
|
||||
ID: preMsg.ID,
|
||||
ConversationID: preMsg.ConversationID,
|
||||
SectionID: preMsg.SectionID,
|
||||
RunID: preMsg.RunID,
|
||||
AgentID: preMsg.AgentID,
|
||||
Role: entity.RoleType(preMsg.Role),
|
||||
Content: interruptData,
|
||||
MessageType: preMsg.MessageType,
|
||||
ContentType: cType,
|
||||
ReplyID: preMsg.RunID,
|
||||
Ext: preMsg.Ext,
|
||||
IsFinish: false,
|
||||
}
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
|
||||
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
|
||||
if len(reasoningContent) > 0 && firstAnswerMsg == nil {
|
||||
finalAnswer.ReasoningContent = ptr.Of(reasoningContent)
|
||||
}
|
||||
usage := func() *msgEntity.UsageExt {
|
||||
if rtDependence.GetUsage() != nil {
|
||||
return &msgEntity.UsageExt{
|
||||
TotalCount: rtDependence.GetUsage().LlmTotalTokens,
|
||||
InputTokens: rtDependence.GetUsage().LlmPromptTokens,
|
||||
OutputTokens: rtDependence.GetUsage().LlmCompletionTokens,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = mh.handlerAnswer(ctx, finalAnswer, usage(), rtDependence, preMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mh.handlerInterruptVerbose(ctx, chunk, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) handlerWfInterruptMsg(ctx context.Context, stateMsg *crossworkflow.StateMessage, rtDependence *AgentRuntime) {
|
||||
interruptData, cType, err := handlerWfInterruptEvent(ctx, stateMsg.InterruptEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
preMsg, err := preCreateAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
deltaAnswer := &entity.ChunkMessageItem{
|
||||
ID: preMsg.ID,
|
||||
ConversationID: preMsg.ConversationID,
|
||||
SectionID: preMsg.SectionID,
|
||||
RunID: preMsg.RunID,
|
||||
AgentID: preMsg.AgentID,
|
||||
Role: entity.RoleType(preMsg.Role),
|
||||
Content: interruptData,
|
||||
MessageType: preMsg.MessageType,
|
||||
ContentType: cType,
|
||||
ReplyID: preMsg.RunID,
|
||||
Ext: preMsg.Ext,
|
||||
IsFinish: false,
|
||||
}
|
||||
|
||||
mh.messageEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, mh.sw)
|
||||
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
|
||||
|
||||
err = mh.handlerAnswer(ctx, finalAnswer, nil, rtDependence, preMsg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = mh.handlerInterruptVerbose(ctx, &entity.AgentRespEvent{
|
||||
EventType: message.MessageTypeInterrupt,
|
||||
Interrupt: &singleagent.InterruptInfo{
|
||||
|
||||
InterruptType: singleagent.InterruptEventType(stateMsg.InterruptEvent.EventType),
|
||||
InterruptID: strconv.FormatInt(stateMsg.InterruptEvent.ID, 10),
|
||||
ChatflowInterrupt: stateMsg,
|
||||
},
|
||||
}, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (mh *MesssageEventHanlder) HandlerInput(ctx context.Context, rtDependence *AgentRuntime) (*msgEntity.Message, error) {
|
||||
msgMeta := buildAgentMessage2Create(ctx, nil, message.MessageTypeQuestion, rtDependence)
|
||||
|
||||
cm, err := crossmessage.DefaultSVC().Create(ctx, msgMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ackErr := mh.handlerAckMessage(ctx, cm)
|
||||
if ackErr != nil {
|
||||
return msgMeta, ackErr
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
214
backend/domain/conversation/agentrun/internal/run.go
Normal file
214
backend/domain/conversation/agentrun/internal/run.go
Normal file
@@ -0,0 +1,214 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_common"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
type AgentRuntime struct {
|
||||
RunRecord *entity.RunRecordMeta
|
||||
AgentInfo *singleagent.SingleAgent
|
||||
QuestionMsgID int64
|
||||
RunMeta *entity.AgentRunMeta
|
||||
StartTime time.Time
|
||||
Input *msgEntity.Message
|
||||
HistoryMsg []*msgEntity.Message
|
||||
Usage *agentrun.Usage
|
||||
SW *schema.StreamWriter[*entity.AgentRunResponse]
|
||||
|
||||
RunProcess *RunProcess
|
||||
RunRecordRepo repository.RunRecordRepo
|
||||
ImagexClient imagex.ImageX
|
||||
MessageEvent *Event
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) SetRunRecord(runRecord *entity.RunRecordMeta) {
|
||||
rd.RunRecord = runRecord
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) GetRunRecord() *entity.RunRecordMeta {
|
||||
return rd.RunRecord
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) SetUsage(usage *agentrun.Usage) {
|
||||
rd.Usage = usage
|
||||
}
|
||||
func (rd *AgentRuntime) GetUsage() *agentrun.Usage {
|
||||
return rd.Usage
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) SetRunMeta(arm *entity.AgentRunMeta) {
|
||||
rd.RunMeta = arm
|
||||
}
|
||||
func (rd *AgentRuntime) GetRunMeta() *entity.AgentRunMeta {
|
||||
return rd.RunMeta
|
||||
}
|
||||
func (rd *AgentRuntime) SetAgentInfo(agentInfo *singleagent.SingleAgent) {
|
||||
rd.AgentInfo = agentInfo
|
||||
}
|
||||
func (rd *AgentRuntime) GetAgentInfo() *singleagent.SingleAgent {
|
||||
return rd.AgentInfo
|
||||
}
|
||||
func (rd *AgentRuntime) SetQuestionMsgID(msgID int64) {
|
||||
rd.QuestionMsgID = msgID
|
||||
}
|
||||
func (rd *AgentRuntime) GetQuestionMsgID() int64 {
|
||||
return rd.QuestionMsgID
|
||||
}
|
||||
func (rd *AgentRuntime) SetStartTime(t time.Time) {
|
||||
rd.StartTime = t
|
||||
}
|
||||
func (rd *AgentRuntime) GetStartTime() time.Time {
|
||||
return rd.StartTime
|
||||
}
|
||||
func (rd *AgentRuntime) SetInput(input *msgEntity.Message) {
|
||||
rd.Input = input
|
||||
}
|
||||
func (rd *AgentRuntime) GetInput() *msgEntity.Message {
|
||||
return rd.Input
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) SetHistoryMsg(histroyMsg []*msgEntity.Message) {
|
||||
rd.HistoryMsg = histroyMsg
|
||||
}
|
||||
|
||||
func (rd *AgentRuntime) GetHistory() []*msgEntity.Message {
|
||||
return rd.HistoryMsg
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) Run(ctx context.Context) (err error) {
|
||||
|
||||
agentInfo, err := getAgentInfo(ctx, art.GetRunMeta().AgentID, art.GetRunMeta().IsDraft)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
art.SetAgentInfo(agentInfo)
|
||||
|
||||
history, err := art.getHistory(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
runRecord, err := art.createRunRecord(ctx)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
art.SetRunRecord(runRecord)
|
||||
art.SetHistoryMsg(history)
|
||||
|
||||
defer func() {
|
||||
srRecord := buildSendRunRecord(ctx, runRecord, entity.RunStatusCompleted)
|
||||
if err != nil {
|
||||
srRecord.Error = &entity.RunError{
|
||||
Code: errno.ErrConversationAgentRunError,
|
||||
Msg: err.Error(),
|
||||
}
|
||||
art.RunProcess.StepToFailed(ctx, srRecord, art.SW)
|
||||
return
|
||||
}
|
||||
art.RunProcess.StepToComplete(ctx, srRecord, art.SW, art.GetUsage())
|
||||
}()
|
||||
mh := &MesssageEventHanlder{
|
||||
messageEvent: art.MessageEvent,
|
||||
sw: art.SW,
|
||||
}
|
||||
input, err := mh.HandlerInput(ctx, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
art.SetInput(input)
|
||||
|
||||
art.SetQuestionMsgID(input.ID)
|
||||
|
||||
if art.GetAgentInfo().BotMode == bot_common.BotMode_WorkflowMode {
|
||||
err = art.ChatflowRun(ctx, art.ImagexClient)
|
||||
} else {
|
||||
err = art.AgentStreamExecute(ctx, art.ImagexClient)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) getHistory(ctx context.Context) ([]*msgEntity.Message, error) {
|
||||
|
||||
conversationTurns := getAgentHistoryRounds(art.GetAgentInfo())
|
||||
|
||||
runRecordList, err := art.RunRecordRepo.List(ctx, &entity.ListRunRecordMeta{
|
||||
ConversationID: art.GetRunMeta().ConversationID,
|
||||
SectionID: art.GetRunMeta().SectionID,
|
||||
Limit: conversationTurns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(runRecordList) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
runIDS := concactRunID(runRecordList)
|
||||
history, err := crossmessage.DefaultSVC().GetByRunIDs(ctx, art.GetRunMeta().ConversationID, runIDS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return history, nil
|
||||
}
|
||||
|
||||
func concactRunID(rr []*entity.RunRecordMeta) []int64 {
|
||||
ids := make([]int64, 0, len(rr))
|
||||
for _, c := range rr {
|
||||
ids = append(ids, c.ID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) createRunRecord(ctx context.Context) (*entity.RunRecordMeta, error) {
|
||||
runPoData, err := art.RunRecordRepo.Create(ctx, art.GetRunMeta())
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "RunRecordRepo.Create error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srRecord := buildSendRunRecord(ctx, runPoData, entity.RunStatusCreated)
|
||||
|
||||
art.RunProcess.StepToCreate(ctx, srRecord, art.SW)
|
||||
|
||||
err = art.RunProcess.StepToInProgress(ctx, srRecord, art.SW)
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "runProcess.StepToInProgress error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return runPoData, nil
|
||||
}
|
||||
@@ -30,8 +30,8 @@ import (
|
||||
)
|
||||
|
||||
type RunProcess struct {
|
||||
event *Event
|
||||
|
||||
event *Event
|
||||
SW *schema.StreamWriter[*entity.AgentRunResponse]
|
||||
RunRecordRepo repository.RunRecordRepo
|
||||
}
|
||||
|
||||
@@ -115,7 +115,6 @@ func (r *RunProcess) StepToFailed(ctx context.Context, srRecord *entity.ChunkRun
|
||||
Code: srRecord.Error.Code,
|
||||
Msg: srRecord.Error.Msg,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (r *RunProcess) StepToDone(sw *schema.StreamWriter[*entity.AgentRunResponse]) {
|
||||
450
backend/domain/conversation/agentrun/internal/singleagent_run.go
Normal file
450
backend/domain/conversation/agentrun/internal/singleagent_run.go
Normal file
@@ -0,0 +1,450 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mohae/deepcopy"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
crossagent "github.com/coze-dev/coze-studio/backend/crossdomain/contract/agent"
|
||||
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
msgEntity "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
func (art *AgentRuntime) AgentStreamExecute(ctx context.Context, imagex imagex.ImageX) (err error) {
|
||||
mainChan := make(chan *entity.AgentRespEvent, 100)
|
||||
|
||||
ar := &crossagent.AgentRuntime{
|
||||
AgentVersion: art.GetRunMeta().Version,
|
||||
SpaceID: art.GetRunMeta().SpaceID,
|
||||
AgentID: art.GetRunMeta().AgentID,
|
||||
IsDraft: art.GetRunMeta().IsDraft,
|
||||
UserID: art.GetRunMeta().UserID,
|
||||
ConnectorID: art.GetRunMeta().ConnectorID,
|
||||
PreRetrieveTools: art.GetRunMeta().PreRetrieveTools,
|
||||
Input: transMessageToSchemaMessage(ctx, []*msgEntity.Message{art.GetInput()}, imagex)[0],
|
||||
HistoryMsg: transMessageToSchemaMessage(ctx, historyPairs(art.GetHistory()), imagex),
|
||||
ResumeInfo: parseResumeInfo(ctx, art.GetHistory()),
|
||||
}
|
||||
|
||||
streamer, err := crossagent.DefaultSVC().StreamExecute(ctx, ar)
|
||||
if err != nil {
|
||||
return errors.New(errorx.ErrorWithoutStack(err))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
safego.Go(ctx, func() {
|
||||
defer wg.Done()
|
||||
art.pull(ctx, mainChan, streamer)
|
||||
})
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer wg.Done()
|
||||
art.push(ctx, mainChan)
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) push(ctx context.Context, mainChan chan *entity.AgentRespEvent) {
|
||||
|
||||
mh := &MesssageEventHanlder{
|
||||
sw: art.SW,
|
||||
messageEvent: art.MessageEvent,
|
||||
}
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
logs.CtxErrorf(ctx, "run.push error: %v", err)
|
||||
mh.handlerErr(ctx, err)
|
||||
}
|
||||
}()
|
||||
|
||||
reasoningContent := bytes.NewBuffer([]byte{})
|
||||
|
||||
var firstAnswerMsg *msgEntity.Message
|
||||
var reasoningMsg *msgEntity.Message
|
||||
isSendFinishAnswer := false
|
||||
var preToolResponseMsg *msgEntity.Message
|
||||
toolResponseMsgContent := bytes.NewBuffer([]byte{})
|
||||
for {
|
||||
chunk, ok := <-mainChan
|
||||
if !ok || chunk == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if chunk.Err != nil {
|
||||
if errors.Is(chunk.Err, io.EOF) {
|
||||
if !isSendFinishAnswer {
|
||||
isSendFinishAnswer = true
|
||||
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
|
||||
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
|
||||
reasoningContent.Reset()
|
||||
}
|
||||
|
||||
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
|
||||
if finishErr != nil {
|
||||
err = finishErr
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
mh.handlerErr(ctx, chunk.Err)
|
||||
return
|
||||
}
|
||||
|
||||
switch chunk.EventType {
|
||||
case message.MessageTypeFunctionCall:
|
||||
if chunk.FuncCall != nil && chunk.FuncCall.ResponseMeta != nil {
|
||||
if usage := handlerUsage(chunk.FuncCall.ResponseMeta); usage != nil {
|
||||
art.SetUsage(&agentrun.Usage{
|
||||
LlmPromptTokens: usage.InputTokens,
|
||||
LlmCompletionTokens: usage.OutputTokens,
|
||||
LlmTotalTokens: usage.TotalCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
err = mh.handlerFunctionCall(ctx, chunk, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if preToolResponseMsg == nil {
|
||||
var cErr error
|
||||
preToolResponseMsg, cErr = preCreateAnswer(ctx, art)
|
||||
if cErr != nil {
|
||||
err = cErr
|
||||
return
|
||||
}
|
||||
}
|
||||
case message.MessageTypeToolResponse:
|
||||
err = mh.handlerTooResponse(ctx, chunk, art, preToolResponseMsg, toolResponseMsgContent.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
preToolResponseMsg = nil // reset
|
||||
case message.MessageTypeKnowledge:
|
||||
err = mh.handlerKnowledge(ctx, chunk, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case message.MessageTypeToolMidAnswer:
|
||||
fullMidAnswerContent := bytes.NewBuffer([]byte{})
|
||||
var usage *msgEntity.UsageExt
|
||||
toolMidAnswerMsg, cErr := preCreateAnswer(ctx, art)
|
||||
|
||||
if cErr != nil {
|
||||
err = cErr
|
||||
return
|
||||
}
|
||||
|
||||
var preMsgIsFinish = false
|
||||
for {
|
||||
streamMsg, receErr := chunk.ToolMidAnswer.Recv()
|
||||
if receErr != nil {
|
||||
if errors.Is(receErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
err = receErr
|
||||
return
|
||||
}
|
||||
if preMsgIsFinish {
|
||||
toolMidAnswerMsg, cErr = preCreateAnswer(ctx, art)
|
||||
if cErr != nil {
|
||||
err = cErr
|
||||
return
|
||||
}
|
||||
preMsgIsFinish = false
|
||||
}
|
||||
if streamMsg == nil {
|
||||
continue
|
||||
}
|
||||
if firstAnswerMsg == nil && len(streamMsg.Content) > 0 {
|
||||
if reasoningMsg != nil {
|
||||
toolMidAnswerMsg = deepcopy.Copy(reasoningMsg).(*msgEntity.Message)
|
||||
}
|
||||
firstAnswerMsg = deepcopy.Copy(toolMidAnswerMsg).(*msgEntity.Message)
|
||||
}
|
||||
|
||||
if streamMsg.Extra != nil {
|
||||
if val, ok := streamMsg.Extra["workflow_node_name"]; ok && val != nil {
|
||||
toolMidAnswerMsg.Ext["message_title"] = val.(string)
|
||||
}
|
||||
}
|
||||
|
||||
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
|
||||
sendMidAnswerMsg.Content = streamMsg.Content
|
||||
toolResponseMsgContent.WriteString(streamMsg.Content)
|
||||
fullMidAnswerContent.WriteString(streamMsg.Content)
|
||||
|
||||
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMidAnswerMsg, art.SW)
|
||||
|
||||
if streamMsg != nil && streamMsg.ResponseMeta != nil {
|
||||
usage = handlerUsage(streamMsg.ResponseMeta)
|
||||
}
|
||||
|
||||
if streamMsg.Extra["is_finish"] == true {
|
||||
preMsgIsFinish = true
|
||||
sendMidAnswerMsg := buildSendMsg(ctx, toolMidAnswerMsg, false, art)
|
||||
sendMidAnswerMsg.Content = fullMidAnswerContent.String()
|
||||
fullMidAnswerContent.Reset()
|
||||
hfErr := mh.handlerAnswer(ctx, sendMidAnswerMsg, usage, art, toolMidAnswerMsg)
|
||||
if hfErr != nil {
|
||||
err = hfErr
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case message.MessageTypeToolAsAnswer:
|
||||
var usage *msgEntity.UsageExt
|
||||
fullContent := bytes.NewBuffer([]byte{})
|
||||
toolAsAnswerMsg, cErr := preCreateAnswer(ctx, art)
|
||||
if cErr != nil {
|
||||
err = cErr
|
||||
return
|
||||
}
|
||||
if firstAnswerMsg == nil {
|
||||
firstAnswerMsg = toolAsAnswerMsg
|
||||
}
|
||||
|
||||
for {
|
||||
streamMsg, receErr := chunk.ToolAsAnswer.Recv()
|
||||
if receErr != nil {
|
||||
if errors.Is(receErr, io.EOF) {
|
||||
|
||||
answer := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
|
||||
answer.Content = fullContent.String()
|
||||
hfErr := mh.handlerAnswer(ctx, answer, usage, art, toolAsAnswerMsg)
|
||||
if hfErr != nil {
|
||||
err = hfErr
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
err = receErr
|
||||
return
|
||||
}
|
||||
|
||||
if streamMsg != nil && streamMsg.ResponseMeta != nil {
|
||||
usage = handlerUsage(streamMsg.ResponseMeta)
|
||||
}
|
||||
sendMsg := buildSendMsg(ctx, toolAsAnswerMsg, false, art)
|
||||
fullContent.WriteString(streamMsg.Content)
|
||||
sendMsg.Content = streamMsg.Content
|
||||
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, art.SW)
|
||||
}
|
||||
|
||||
case message.MessageTypeAnswer:
|
||||
fullContent := bytes.NewBuffer([]byte{})
|
||||
var usage *msgEntity.UsageExt
|
||||
var isToolCalls = false
|
||||
var modelAnswerMsg *msgEntity.Message
|
||||
for {
|
||||
streamMsg, receErr := chunk.ModelAnswer.Recv()
|
||||
if receErr != nil {
|
||||
if errors.Is(receErr, io.EOF) {
|
||||
|
||||
if isToolCalls {
|
||||
break
|
||||
}
|
||||
if modelAnswerMsg == nil {
|
||||
break
|
||||
}
|
||||
answer := buildSendMsg(ctx, modelAnswerMsg, false, art)
|
||||
answer.Content = fullContent.String()
|
||||
hfErr := mh.handlerAnswer(ctx, answer, usage, art, modelAnswerMsg)
|
||||
if hfErr != nil {
|
||||
err = hfErr
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
err = receErr
|
||||
return
|
||||
}
|
||||
|
||||
if streamMsg != nil && len(streamMsg.ToolCalls) > 0 {
|
||||
isToolCalls = true
|
||||
}
|
||||
|
||||
if streamMsg != nil && streamMsg.ResponseMeta != nil {
|
||||
usage = handlerUsage(streamMsg.ResponseMeta)
|
||||
}
|
||||
|
||||
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(streamMsg.ReasoningContent) > 0 {
|
||||
if reasoningMsg == nil {
|
||||
reasoningMsg, err = preCreateAnswer(ctx, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sendReasoningMsg := buildSendMsg(ctx, reasoningMsg, false, art)
|
||||
reasoningContent.WriteString(streamMsg.ReasoningContent)
|
||||
sendReasoningMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
|
||||
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendReasoningMsg, art.SW)
|
||||
}
|
||||
if len(streamMsg.Content) > 0 {
|
||||
|
||||
if modelAnswerMsg == nil {
|
||||
modelAnswerMsg, err = preCreateAnswer(ctx, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if firstAnswerMsg == nil {
|
||||
if reasoningMsg != nil {
|
||||
modelAnswerMsg.ID = reasoningMsg.ID
|
||||
}
|
||||
firstAnswerMsg = modelAnswerMsg
|
||||
}
|
||||
}
|
||||
|
||||
sendAnswerMsg := buildSendMsg(ctx, modelAnswerMsg, false, art)
|
||||
fullContent.WriteString(streamMsg.Content)
|
||||
sendAnswerMsg.Content = streamMsg.Content
|
||||
art.MessageEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, art.SW)
|
||||
}
|
||||
}
|
||||
|
||||
case message.MessageTypeFlowUp:
|
||||
if isSendFinishAnswer {
|
||||
|
||||
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
|
||||
art.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
|
||||
}
|
||||
|
||||
isSendFinishAnswer = true
|
||||
finishErr := mh.handlerFinalAnswerFinish(ctx, art)
|
||||
if finishErr != nil {
|
||||
err = finishErr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = mh.handlerSuggest(ctx, chunk, art)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case message.MessageTypeInterrupt:
|
||||
err = mh.handlerInterrupt(ctx, chunk, art, firstAnswerMsg, reasoningContent.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) pull(_ context.Context, mainChan chan *entity.AgentRespEvent, events *schema.StreamReader[*crossagent.AgentEvent]) {
|
||||
defer func() {
|
||||
close(mainChan)
|
||||
}()
|
||||
|
||||
for {
|
||||
rm, re := events.Recv()
|
||||
if re != nil {
|
||||
errChunk := &entity.AgentRespEvent{
|
||||
Err: re,
|
||||
}
|
||||
mainChan <- errChunk
|
||||
return
|
||||
}
|
||||
|
||||
eventType, tErr := transformEventMap(rm.EventType)
|
||||
|
||||
if tErr != nil {
|
||||
errChunk := &entity.AgentRespEvent{
|
||||
Err: tErr,
|
||||
}
|
||||
mainChan <- errChunk
|
||||
return
|
||||
}
|
||||
|
||||
respChunk := &entity.AgentRespEvent{
|
||||
EventType: eventType,
|
||||
ModelAnswer: rm.ChatModelAnswer,
|
||||
ToolsMessage: rm.ToolsMessage,
|
||||
FuncCall: rm.FuncCall,
|
||||
Knowledge: rm.Knowledge,
|
||||
Suggest: rm.Suggest,
|
||||
Interrupt: rm.Interrupt,
|
||||
|
||||
ToolMidAnswer: rm.ToolMidAnswer,
|
||||
ToolAsAnswer: rm.ToolAsChatModelAnswer,
|
||||
}
|
||||
|
||||
mainChan <- respChunk
|
||||
}
|
||||
}
|
||||
|
||||
func transformEventMap(eventType singleagent.EventType) (message.MessageType, error) {
|
||||
var eType message.MessageType
|
||||
switch eventType {
|
||||
case singleagent.EventTypeOfFuncCall:
|
||||
return message.MessageTypeFunctionCall, nil
|
||||
case singleagent.EventTypeOfKnowledge:
|
||||
return message.MessageTypeKnowledge, nil
|
||||
case singleagent.EventTypeOfToolsMessage:
|
||||
return message.MessageTypeToolResponse, nil
|
||||
case singleagent.EventTypeOfChatModelAnswer:
|
||||
return message.MessageTypeAnswer, nil
|
||||
case singleagent.EventTypeOfToolsAsChatModelStream:
|
||||
return message.MessageTypeToolAsAnswer, nil
|
||||
case singleagent.EventTypeOfToolMidAnswer:
|
||||
return message.MessageTypeToolMidAnswer, nil
|
||||
case singleagent.EventTypeOfSuggest:
|
||||
return message.MessageTypeFlowUp, nil
|
||||
case singleagent.EventTypeOfInterrupt:
|
||||
return message.MessageTypeInterrupt, nil
|
||||
}
|
||||
return eType, errorx.New(errno.ErrReplyUnknowEventType)
|
||||
}
|
||||
|
||||
func (art *AgentRuntime) saveReasoningContent(ctx context.Context, firstAnswerMsg *msgEntity.Message, reasoningContent string) {
|
||||
_, err := crossmessage.DefaultSVC().Edit(ctx, &message.Message{
|
||||
ID: firstAnswerMsg.ID,
|
||||
ReasoningContent: reasoningContent,
|
||||
})
|
||||
if err != nil {
|
||||
logs.CtxInfof(ctx, "save reasoning content failed, err: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
)
|
||||
|
||||
@@ -34,8 +33,9 @@ func NewRunRecordRepo(db *gorm.DB, idGen idgen.IDGenerator) RunRecordRepo {
|
||||
|
||||
type RunRecordRepo interface {
|
||||
Create(ctx context.Context, runMeta *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.RunRecord, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.RunRecordMeta, error)
|
||||
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
|
||||
Delete(ctx context.Context, id []int64) error
|
||||
UpdateByID(ctx context.Context, id int64, update *entity.UpdateMeta) error
|
||||
List(ctx context.Context, conversationID int64, sectionID int64, limit int32) ([]*model.RunRecord, error)
|
||||
List(ctx context.Context, meta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,9 @@ import (
|
||||
|
||||
type Run interface {
|
||||
AgentRun(ctx context.Context, req *entity.AgentRunMeta) (*schema.StreamReader[*entity.AgentRunResponse], error)
|
||||
|
||||
Delete(ctx context.Context, runID []int64) error
|
||||
Create(ctx context.Context, runRecord *entity.AgentRunMeta) (*entity.RunRecordMeta, error)
|
||||
List(ctx context.Context, ListMeta *entity.ListRunRecordMeta) ([]*entity.RunRecordMeta, error)
|
||||
GetByID(ctx context.Context, runID int64) (*entity.RunRecordMeta, error)
|
||||
Cancel(ctx context.Context, req *entity.CancelRunMeta) (*entity.RunRecordMeta, error)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,18 @@
|
||||
package agentrun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/internal/dal/model"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/agentrun/repository"
|
||||
mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/orm"
|
||||
)
|
||||
|
||||
func TestAgentRun(t *testing.T) {
|
||||
@@ -97,3 +108,158 @@ func TestAgentRun(t *testing.T) {
|
||||
// assert.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestRunImpl_List(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.RunRecord{}).AddRows(
|
||||
&model.RunRecord{
|
||||
ID: 1,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
},
|
||||
&model.RunRecord{
|
||||
ID: 2,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 1,
|
||||
}, &model.RunRecord{
|
||||
ID: 3,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 2,
|
||||
}, &model.RunRecord{
|
||||
ID: 4,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 3,
|
||||
}, &model.RunRecord{
|
||||
ID: 5,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 4,
|
||||
},
|
||||
&model.RunRecord{
|
||||
ID: 6,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 5,
|
||||
}, &model.RunRecord{
|
||||
ID: 7,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 6,
|
||||
}, &model.RunRecord{
|
||||
ID: 8,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 7,
|
||||
}, &model.RunRecord{
|
||||
ID: 9,
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
UserID: "123456",
|
||||
CreatedAt: time.Now().Unix() + 8,
|
||||
},
|
||||
)
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockIDGen := mock.NewMockIDGenerator(ctrl)
|
||||
|
||||
runRecordRepo := repository.NewRunRecordRepo(mockDB, mockIDGen)
|
||||
|
||||
service := &runImpl{
|
||||
Components: Components{
|
||||
RunRecordRepo: runRecordRepo,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("list success", func(t *testing.T) {
|
||||
|
||||
meta := &entity.ListRunRecordMeta{
|
||||
ConversationID: 123,
|
||||
AgentID: 456,
|
||||
SectionID: 789,
|
||||
Limit: 10,
|
||||
OrderBy: "desc",
|
||||
}
|
||||
|
||||
result, err := service.List(ctx, meta)
|
||||
// check result
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 9)
|
||||
assert.Equal(t, int64(123), result[0].ConversationID)
|
||||
assert.Equal(t, int64(456), result[0].AgentID)
|
||||
})
|
||||
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
meta := &entity.ListRunRecordMeta{
|
||||
ConversationID: 999, //
|
||||
Limit: 10,
|
||||
OrderBy: "desc",
|
||||
}
|
||||
|
||||
// check result
|
||||
result, err := service.List(ctx, meta)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("search with before id", func(t *testing.T) {
|
||||
|
||||
meta := &entity.ListRunRecordMeta{
|
||||
ConversationID: 123,
|
||||
SectionID: 789,
|
||||
AgentID: 456,
|
||||
BeforeID: 5,
|
||||
Limit: 3,
|
||||
OrderBy: "desc",
|
||||
}
|
||||
|
||||
result, err := service.List(ctx, meta)
|
||||
|
||||
// check result
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, int64(4), result[0].ID)
|
||||
})
|
||||
t.Run("search with after id and limit", func(t *testing.T) {
|
||||
|
||||
meta := &entity.ListRunRecordMeta{
|
||||
ConversationID: 123,
|
||||
SectionID: 789,
|
||||
AgentID: 456,
|
||||
AfterID: 5,
|
||||
Limit: 3,
|
||||
OrderBy: "desc",
|
||||
}
|
||||
|
||||
result, err := service.List(ctx, meta)
|
||||
|
||||
// check result
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 3)
|
||||
assert.Equal(t, int64(9), result[0].ID)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
type Conversation = conversation.Conversation
|
||||
|
||||
type CreateMeta struct {
|
||||
Name string `json:"name"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ConnectorID int64 `json:"connector_id"`
|
||||
@@ -50,3 +51,8 @@ type ListMeta struct {
|
||||
Limit int `json:"limit"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
|
||||
type UpdateMeta struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
@@ -107,6 +107,20 @@ func (dao *ConversationDAO) Delete(ctx context.Context, id int64) (int64, error)
|
||||
return updateRes.RowsAffected, err
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
|
||||
updateColumn := make(map[string]interface{})
|
||||
updateColumn[dao.query.Conversation.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
if len(req.Name) > 0 {
|
||||
updateColumn[dao.query.Conversation.Name.ColumnName().String()] = req.Name
|
||||
}
|
||||
|
||||
_, err := dao.query.Conversation.WithContext(ctx).Where(dao.query.Conversation.ID.Eq(req.ID)).UpdateColumns(updateColumn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.GetByID(ctx, req.ID)
|
||||
}
|
||||
|
||||
func (dao *ConversationDAO) Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error) {
|
||||
po, err := dao.query.Conversation.WithContext(ctx).Debug().
|
||||
Where(dao.query.Conversation.CreatorID.Eq(userID)).
|
||||
@@ -133,13 +147,15 @@ func (dao *ConversationDAO) List(ctx context.Context, userID int64, agentID int6
|
||||
do = do.Where(dao.query.Conversation.CreatorID.Eq(userID)).
|
||||
Where(dao.query.Conversation.AgentID.Eq(agentID)).
|
||||
Where(dao.query.Conversation.Scene.Eq(scene)).
|
||||
Where(dao.query.Conversation.ConnectorID.Eq(connectorID))
|
||||
Where(dao.query.Conversation.ConnectorID.Eq(connectorID)).
|
||||
Where(dao.query.Conversation.Status.Eq(int32(conversation.ConversationStatusNormal)))
|
||||
|
||||
do = do.Offset((page - 1) * limit)
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
}
|
||||
do = do.Order(dao.query.Conversation.CreatedAt.Desc())
|
||||
|
||||
poList, err := do.Find()
|
||||
|
||||
@@ -173,6 +189,7 @@ func (dao *ConversationDAO) conversationDO2PO(ctx context.Context, conversation
|
||||
Ext: conversation.Ext,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
Name: conversation.Name,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,6 +205,7 @@ func (dao *ConversationDAO) conversationPO2DO(ctx context.Context, c *model.Conv
|
||||
Ext: c.Ext,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
Name: c.Name,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,6 +222,7 @@ func (dao *ConversationDAO) conversationBatchPO2DO(ctx context.Context, conversa
|
||||
Ext: c.Ext,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
Name: c.Name,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
@@ -9,6 +25,7 @@ const TableNameConversation = "conversation"
|
||||
// Conversation conversation info record
|
||||
type Conversation struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement:true;comment:id" json:"id"` // id
|
||||
Name string `gorm:"column:name;not null;comment:conversation name" json:"name"` // conversation name
|
||||
ConnectorID int64 `gorm:"column:connector_id;not null;comment:Publish Connector ID" json:"connector_id"` // Publish Connector ID
|
||||
AgentID int64 `gorm:"column:agent_id;not null;comment:agent_id" json:"agent_id"` // agent_id
|
||||
Scene int32 `gorm:"column:scene;not null;comment:conversation scene" json:"scene"` // conversation scene
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
@@ -28,6 +44,7 @@ func newConversation(db *gorm.DB, opts ...gen.DOOption) conversation {
|
||||
tableName := _conversation.conversationDo.TableName()
|
||||
_conversation.ALL = field.NewAsterisk(tableName)
|
||||
_conversation.ID = field.NewInt64(tableName, "id")
|
||||
_conversation.Name = field.NewString(tableName, "name")
|
||||
_conversation.ConnectorID = field.NewInt64(tableName, "connector_id")
|
||||
_conversation.AgentID = field.NewInt64(tableName, "agent_id")
|
||||
_conversation.Scene = field.NewInt32(tableName, "scene")
|
||||
@@ -49,6 +66,7 @@ type conversation struct {
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int64 // id
|
||||
Name field.String // conversation name
|
||||
ConnectorID field.Int64 // Publish Connector ID
|
||||
AgentID field.Int64 // agent_id
|
||||
Scene field.Int32 // conversation scene
|
||||
@@ -75,6 +93,7 @@ func (c conversation) As(alias string) *conversation {
|
||||
func (c *conversation) updateTableName(table string) *conversation {
|
||||
c.ALL = field.NewAsterisk(table)
|
||||
c.ID = field.NewInt64(table, "id")
|
||||
c.Name = field.NewString(table, "name")
|
||||
c.ConnectorID = field.NewInt64(table, "connector_id")
|
||||
c.AgentID = field.NewInt64(table, "agent_id")
|
||||
c.Scene = field.NewInt32(table, "scene")
|
||||
@@ -100,8 +119,9 @@ func (c *conversation) GetFieldByName(fieldName string) (field.OrderExpr, bool)
|
||||
}
|
||||
|
||||
func (c *conversation) fillFieldMap() {
|
||||
c.fieldMap = make(map[string]field.Expr, 10)
|
||||
c.fieldMap = make(map[string]field.Expr, 11)
|
||||
c.fieldMap["id"] = c.ID
|
||||
c.fieldMap["name"] = c.Name
|
||||
c.fieldMap["connector_id"] = c.ConnectorID
|
||||
c.fieldMap["agent_id"] = c.AgentID
|
||||
c.fieldMap["scene"] = c.Scene
|
||||
|
||||
@@ -35,6 +35,7 @@ type ConversationRepo interface {
|
||||
GetByID(ctx context.Context, id int64) (*entity.Conversation, error)
|
||||
UpdateSection(ctx context.Context, id int64) (int64, error)
|
||||
Get(ctx context.Context, userID int64, agentID int64, scene int32, connectorID int64) (*entity.Conversation, error)
|
||||
Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error)
|
||||
Delete(ctx context.Context, id int64) (int64, error)
|
||||
List(ctx context.Context, userID int64, agentID int64, connectorID int64, scene int32, limit int, page int) ([]*entity.Conversation, bool, error)
|
||||
}
|
||||
|
||||
@@ -29,4 +29,5 @@ type Conversation interface {
|
||||
GetCurrentConversation(ctx context.Context, req *entity.GetCurrent) (*entity.Conversation, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error)
|
||||
Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error)
|
||||
}
|
||||
|
||||
@@ -101,6 +101,11 @@ func (c *conversationImpl) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conversationImpl) Update(ctx context.Context, req *entity.UpdateMeta) (*entity.Conversation, error) {
|
||||
// get conversation
|
||||
return c.ConversationRepo.Update(ctx, req)
|
||||
}
|
||||
|
||||
func (c *conversationImpl) List(ctx context.Context, req *entity.ListMeta) ([]*entity.Conversation, bool, error) {
|
||||
conversationList, hasMore, err := c.ConversationRepo.List(ctx, req.UserID, req.AgentID, req.ConnectorID, int32(req.Scene), req.Limit, req.Page)
|
||||
|
||||
|
||||
@@ -16,19 +16,22 @@
|
||||
|
||||
package entity
|
||||
|
||||
import "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
import (
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
)
|
||||
|
||||
type Message = message.Message
|
||||
|
||||
type ListMeta struct {
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []*int64 `json:"run_id"`
|
||||
UserID string `json:"user_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
OrderBy *string `json:"order_by"`
|
||||
Limit int `json:"limit"`
|
||||
Cursor int64 `json:"cursor"` // message id
|
||||
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
|
||||
ConversationID int64 `json:"conversation_id"`
|
||||
RunID []*int64 `json:"run_id"`
|
||||
UserID string `json:"user_id"`
|
||||
AgentID int64 `json:"agent_id"`
|
||||
OrderBy *string `json:"order_by"`
|
||||
Limit int `json:"limit"`
|
||||
Cursor int64 `json:"cursor"` // message id
|
||||
Direction ScrollPageDirection `json:"direction"` // "prev" "Next"
|
||||
MessageType []*message.MessageType `json:"message_type"`
|
||||
}
|
||||
|
||||
type ListResult struct {
|
||||
@@ -45,8 +48,9 @@ type GetByRunIDsRequest struct {
|
||||
}
|
||||
|
||||
type DeleteMeta struct {
|
||||
MessageIDs []int64 `json:"message_ids"`
|
||||
RunIDs []int64 `json:"run_ids"`
|
||||
ConversationID *int64 `json:"conversation_id"`
|
||||
MessageIDs []int64 `json:"message_ids"`
|
||||
RunIDs []int64 `json:"run_ids"`
|
||||
}
|
||||
|
||||
type BrokenMeta struct {
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/internal/dal/query"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/idgen"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
@@ -71,27 +72,41 @@ func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) {
|
||||
func (dao *MessageDAO) List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(conversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
do := m.WithContext(ctx).Debug().Where(m.ConversationID.Eq(listMeta.ConversationID)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if messageType != nil {
|
||||
do = do.Where(m.MessageType.Eq(string(*messageType)))
|
||||
if len(listMeta.RunID) > 0 {
|
||||
do = do.Where(m.RunID.In(slices.Transform(listMeta.RunID, func(t *int64) int64 {
|
||||
return *t
|
||||
})...))
|
||||
}
|
||||
if len(listMeta.MessageType) > 0 {
|
||||
do = do.Where(m.MessageType.In(slices.Transform(listMeta.MessageType, func(t *message.MessageType) string {
|
||||
return string(*t)
|
||||
})...))
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
do = do.Limit(int(limit) + 1)
|
||||
if listMeta.Limit > 0 {
|
||||
do = do.Limit(int(listMeta.Limit) + 1)
|
||||
}
|
||||
|
||||
if cursor > 0 {
|
||||
if direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(cursor))
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(cursor))
|
||||
if listMeta.Cursor > 0 {
|
||||
msg, err := m.Where(m.ID.Eq(listMeta.Cursor)).First()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if listMeta.Direction == entity.ScrollPageDirectionPrev {
|
||||
do = do.Where(m.CreatedAt.Lt(msg.CreatedAt))
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
do = do.Where(m.CreatedAt.Gt(msg.CreatedAt))
|
||||
do = do.Order(m.CreatedAt.Asc())
|
||||
}
|
||||
} else {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
}
|
||||
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
messageList, err := do.Find()
|
||||
|
||||
var hasMore bool
|
||||
@@ -103,9 +118,9 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if len(messageList) > limit {
|
||||
if len(messageList) > int(listMeta.Limit) {
|
||||
hasMore = true
|
||||
messageList = messageList[:limit]
|
||||
messageList = messageList[:int(listMeta.Limit)]
|
||||
}
|
||||
|
||||
return dao.batchMessagePO2DO(messageList), hasMore, nil
|
||||
@@ -113,7 +128,8 @@ func (dao *MessageDAO) List(ctx context.Context, conversationID int64, limit int
|
||||
|
||||
func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error) {
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...))
|
||||
do := m.WithContext(ctx).Debug().Where(m.RunID.In(runIDs...)).Where(m.Status.Eq(int32(entity.MessageStatusAvailable)))
|
||||
|
||||
if orderBy == "DESC" {
|
||||
do = do.Order(m.CreatedAt.Desc())
|
||||
} else {
|
||||
@@ -133,19 +149,37 @@ func (dao *MessageDAO) GetByRunIDs(ctx context.Context, runIDs []int64, orderBy
|
||||
|
||||
func (dao *MessageDAO) Edit(ctx context.Context, msgID int64, msg *message.Message) (int64, error) {
|
||||
m := dao.query.Message
|
||||
columns := dao.buildEditColumns(msg)
|
||||
|
||||
originMsg, err := dao.GetByID(ctx, msgID)
|
||||
if originMsg == nil {
|
||||
return 0, errorx.New(errno.ErrRecordNotFound)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
columns := dao.buildEditColumns(msg, originMsg)
|
||||
do, err := m.WithContext(ctx).Where(m.ID.Eq(msgID)).UpdateColumns(columns)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if do.RowsAffected == 0 {
|
||||
return 0, errorx.New(errno.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
return do.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interface{} {
|
||||
func (dao *MessageDAO) buildEditColumns(msg *message.Message, originMsg *entity.Message) map[string]interface{} {
|
||||
columns := make(map[string]interface{})
|
||||
table := dao.query.Message
|
||||
if msg.Content != "" {
|
||||
msg.Role = originMsg.Role
|
||||
columns[table.Content.ColumnName().String()] = msg.Content
|
||||
modelContent, err := dao.buildModelContent(msg)
|
||||
if err == nil {
|
||||
columns[table.ModelContent.ColumnName().String()] = modelContent
|
||||
}
|
||||
}
|
||||
if msg.MessageType != "" {
|
||||
columns[table.MessageType.ColumnName().String()] = msg.MessageType
|
||||
@@ -170,6 +204,11 @@ func (dao *MessageDAO) buildEditColumns(msg *message.Message) map[string]interfa
|
||||
|
||||
columns[table.UpdatedAt.ColumnName().String()] = time.Now().UnixMilli()
|
||||
if msg.Ext != nil {
|
||||
if originMsg.Ext != nil {
|
||||
for k, v := range originMsg.Ext {
|
||||
msg.Ext[k] = v
|
||||
}
|
||||
}
|
||||
ext, err := sonic.MarshalString(msg.Ext)
|
||||
if err == nil {
|
||||
columns[table.Ext.ColumnName().String()] = ext
|
||||
@@ -192,8 +231,8 @@ func (dao *MessageDAO) GetByID(ctx context.Context, msgID int64) (*entity.Messag
|
||||
return dao.messagePO2DO(po), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error {
|
||||
if len(msgIDs) == 0 && len(runIDs) == 0 {
|
||||
func (dao *MessageDAO) Delete(ctx context.Context, delMeta *entity.DeleteMeta) error {
|
||||
if len(delMeta.MessageIDs) == 0 && len(delMeta.RunIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -202,11 +241,14 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
|
||||
m := dao.query.Message
|
||||
do := m.WithContext(ctx)
|
||||
|
||||
if len(runIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(runIDs...))
|
||||
if len(delMeta.RunIDs) > 0 {
|
||||
do = do.Where(m.RunID.In(delMeta.RunIDs...))
|
||||
}
|
||||
if len(msgIDs) > 0 {
|
||||
do = do.Where(m.ID.In(msgIDs...))
|
||||
if len(delMeta.MessageIDs) > 0 {
|
||||
do = do.Where(m.ID.In(delMeta.MessageIDs...))
|
||||
}
|
||||
if delMeta.ConversationID != nil && ptr.From(delMeta.ConversationID) > 0 {
|
||||
do = do.Where(m.ConversationID.Eq(*delMeta.ConversationID))
|
||||
}
|
||||
_, err := do.UpdateColumns(&updateColumns)
|
||||
return err
|
||||
@@ -284,6 +326,9 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
|
||||
var multiContent []schema.ChatMessagePart
|
||||
for _, contentData := range msgDO.MultiContent {
|
||||
if contentData.Type == message.InputTypeText {
|
||||
if len(msgDO.Content) == 0 && len(contentData.Text) > 0 {
|
||||
msgDO.Content = contentData.Text
|
||||
}
|
||||
continue
|
||||
}
|
||||
one := schema.ChatMessagePart{}
|
||||
|
||||
@@ -34,10 +34,9 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
|
||||
type MessageRepo interface {
|
||||
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
List(ctx context.Context, conversationID int64, limit int, cursor int64,
|
||||
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
|
||||
List(ctx context.Context, listMeta *entity.ListMeta) ([]*entity.Message, bool, error)
|
||||
GetByRunIDs(ctx context.Context, runIDs []int64, orderBy string) ([]*entity.Message, error)
|
||||
Edit(ctx context.Context, msgID int64, message *message.Message) (int64, error)
|
||||
GetByID(ctx context.Context, msgID int64) (*entity.Message, error)
|
||||
Delete(ctx context.Context, msgIDs []int64, runIDs []int64) error
|
||||
Delete(ctx context.Context, delMeta *entity.DeleteMeta) error
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
type Message interface {
|
||||
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
|
||||
|
||||
@@ -18,6 +18,7 @@ package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
|
||||
@@ -51,9 +52,9 @@ func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.
|
||||
|
||||
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
resp := &entity.ListResult{}
|
||||
|
||||
req.MessageType = []*message.MessageType{ptr.Of(message.MessageTypeQuestion)}
|
||||
// get message with query
|
||||
messageList, hasMore, err := m.MessageRepo.List(ctx, req.ConversationID, req.Limit, req.Cursor, req.Direction, ptr.Of(message.MessageTypeQuestion))
|
||||
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -62,8 +63,11 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
|
||||
resp.HasMore = hasMore
|
||||
|
||||
if len(messageList) > 0 {
|
||||
resp.PrevCursor = messageList[len(messageList)-1].CreatedAt
|
||||
resp.NextCursor = messageList[0].CreatedAt
|
||||
sort.Slice(messageList, func(i, j int) bool {
|
||||
return messageList[i].CreatedAt > messageList[j].CreatedAt
|
||||
})
|
||||
resp.PrevCursor = messageList[len(messageList)-1].ID
|
||||
resp.NextCursor = messageList[0].ID
|
||||
|
||||
var runIDs []int64
|
||||
for _, m := range messageList {
|
||||
@@ -82,6 +86,23 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) ListWithoutPair(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
resp := &entity.ListResult{}
|
||||
messageList, hasMore, err := m.MessageRepo.List(ctx, req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
resp.Direction = req.Direction
|
||||
resp.HasMore = hasMore
|
||||
resp.Messages = messageList
|
||||
if len(messageList) > 0 {
|
||||
resp.PrevCursor = messageList[0].ID
|
||||
resp.NextCursor = messageList[len(messageList)-1].ID
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
|
||||
return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
|
||||
}
|
||||
@@ -96,7 +117,7 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
|
||||
}
|
||||
|
||||
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
|
||||
return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
|
||||
return m.MessageRepo.Delete(ctx, req)
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
|
||||
|
||||
@@ -18,6 +18,7 @@ package message
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -145,20 +146,26 @@ func TestCreateMessage(t *testing.T) {
|
||||
func TestEditMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
extData := map[string]string{
|
||||
"test": "test",
|
||||
}
|
||||
ext, _ := json.Marshal(extData)
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Role: string(schema.User),
|
||||
RunID: 123,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Role: string(schema.User),
|
||||
RunID: 124,
|
||||
Ext: string(ext),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -177,7 +184,7 @@ func TestEditMessage(t *testing.T) {
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
_ = []*message.InputMetaData{
|
||||
{
|
||||
Type: message.InputTypeText,
|
||||
Text: "解析图片中的内容",
|
||||
@@ -197,56 +204,293 @@ func TestEditMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
resp, err := NewService(components).Edit(ctx, &entity.Message{
|
||||
ID: 2,
|
||||
Content: "test edit message",
|
||||
MultiContent: content,
|
||||
ID: 2,
|
||||
Content: "test edit message",
|
||||
Ext: map[string]string{"newext": "true"},
|
||||
|
||||
// MultiContent: content,
|
||||
})
|
||||
_ = resp
|
||||
|
||||
msOne, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
msg, err := NewService(components).GetByID(ctx, 2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int64(124), msOne[0].RunID)
|
||||
assert.Equal(t, int64(2), msg.ID)
|
||||
assert.Equal(t, "test edit message", msg.Content)
|
||||
var modelContent *schema.Message
|
||||
err = json.Unmarshal([]byte(msg.ModelContent), &modelContent)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test edit message", modelContent.Content)
|
||||
|
||||
assert.Equal(t, "true", msg.Ext["newext"])
|
||||
}
|
||||
|
||||
func TestGetByRunIDs(t *testing.T) {
|
||||
//func TestGetByRunIDs(t *testing.T) {
|
||||
// ctx := context.Background()
|
||||
//
|
||||
// mockDBGen := orm.NewMockDB()
|
||||
//
|
||||
// mockDBGen.AddTable(&model.Message{}).
|
||||
// AddRows(
|
||||
// &model.Message{
|
||||
// ID: 1,
|
||||
// ConversationID: 1,
|
||||
// UserID: "1",
|
||||
// RunID: 123,
|
||||
// Content: "test content123",
|
||||
// },
|
||||
// &model.Message{
|
||||
// ID: 2,
|
||||
// ConversationID: 1,
|
||||
// UserID: "1",
|
||||
// Content: "test content124",
|
||||
// RunID: 124,
|
||||
// },
|
||||
// &model.Message{
|
||||
// ID: 3,
|
||||
// ConversationID: 1,
|
||||
// UserID: "1",
|
||||
// Content: "test content124",
|
||||
// RunID: 124,
|
||||
// },
|
||||
// )
|
||||
// mockDB, err := mockDBGen.DB()
|
||||
// assert.NoError(t, err)
|
||||
// components := &Components{
|
||||
// MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
// }
|
||||
//
|
||||
// resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
//
|
||||
// assert.NoError(t, err)
|
||||
//
|
||||
// assert.Len(t, resp, 2)
|
||||
//}
|
||||
|
||||
func TestListWithoutPair(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("success_with_messages", func(t *testing.T) {
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 200,
|
||||
Content: "Hello",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1, // MessageStatusAvailable
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 201,
|
||||
Content: "World",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1, // MessageStatusAvailable
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
},
|
||||
)
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
RunID: 123,
|
||||
Content: "test content123",
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 3,
|
||||
ConversationID: 1,
|
||||
UserID: "1",
|
||||
Content: "test content124",
|
||||
RunID: 124,
|
||||
},
|
||||
)
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := NewService(components).GetByRunIDs(ctx, 1, []int64{124})
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
req := &entity.ListMeta{
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
Limit: 10,
|
||||
Direction: entity.ScrollPageDirectionNext,
|
||||
}
|
||||
|
||||
assert.Len(t, resp, 2)
|
||||
resp, err := NewService(components).ListWithoutPair(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
|
||||
assert.False(t, resp.HasMore)
|
||||
assert.Len(t, resp.Messages, 2)
|
||||
assert.Equal(t, "Hello", resp.Messages[0].Content)
|
||||
assert.Equal(t, "World", resp.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("empty_result", func(t *testing.T) {
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{})
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
req := &entity.ListMeta{
|
||||
ConversationID: 999,
|
||||
UserID: "user123",
|
||||
Limit: 10,
|
||||
Direction: entity.ScrollPageDirectionNext,
|
||||
}
|
||||
|
||||
resp, err := NewService(components).ListWithoutPair(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
|
||||
assert.False(t, resp.HasMore)
|
||||
assert.Len(t, resp.Messages, 0)
|
||||
})
|
||||
|
||||
t.Run("pagination_has_more", func(t *testing.T) {
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 200,
|
||||
Content: "Message 1",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli() - 3000,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 201,
|
||||
Content: "Message 2",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli() - 2000,
|
||||
},
|
||||
&model.Message{
|
||||
ID: 3,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 202,
|
||||
Content: "Message 3",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli() - 1000,
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
req := &entity.ListMeta{
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
Limit: 2,
|
||||
Direction: entity.ScrollPageDirectionNext,
|
||||
}
|
||||
|
||||
resp, err := NewService(components).ListWithoutPair(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, entity.ScrollPageDirectionNext, resp.Direction)
|
||||
assert.True(t, resp.HasMore)
|
||||
assert.Len(t, resp.Messages, 2)
|
||||
})
|
||||
|
||||
t.Run("direction_prev", func(t *testing.T) {
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 200,
|
||||
Content: "Test message",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
req := &entity.ListMeta{
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
Limit: 10,
|
||||
Direction: entity.ScrollPageDirectionPrev,
|
||||
}
|
||||
|
||||
resp, err := NewService(components).ListWithoutPair(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, entity.ScrollPageDirectionPrev, resp.Direction)
|
||||
assert.False(t, resp.HasMore)
|
||||
assert.Len(t, resp.Messages, 1)
|
||||
})
|
||||
|
||||
t.Run("with_message_type_filter", func(t *testing.T) {
|
||||
mockDBGen := orm.NewMockDB()
|
||||
|
||||
mockDBGen.AddTable(&model.Message{}).
|
||||
AddRows(
|
||||
&model.Message{
|
||||
ID: 1,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 200,
|
||||
Content: "Answer message",
|
||||
MessageType: string(message.MessageTypeAnswer),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
},
|
||||
&model.Message{
|
||||
ID: 2,
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
RunID: 201,
|
||||
Content: "Question message",
|
||||
MessageType: string(message.MessageTypeQuestion),
|
||||
Status: 1,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
},
|
||||
)
|
||||
|
||||
mockDB, err := mockDBGen.DB()
|
||||
assert.NoError(t, err)
|
||||
|
||||
components := &Components{
|
||||
MessageRepo: repository.NewMessageRepo(mockDB, nil),
|
||||
}
|
||||
|
||||
req := &entity.ListMeta{
|
||||
ConversationID: 100,
|
||||
UserID: "user123",
|
||||
Limit: 10,
|
||||
Direction: entity.ScrollPageDirectionNext,
|
||||
MessageType: []*message.MessageType{&[]message.MessageType{message.MessageTypeAnswer}[0]},
|
||||
}
|
||||
|
||||
resp, err := NewService(components).ListWithoutPair(ctx, req)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Len(t, resp.Messages, 1)
|
||||
assert.Equal(t, "Answer message", resp.Messages[0].Content)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user