From 249c23c64b0c8bc96da2a8a475bf0275f45aad5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BF=8A=E6=96=87?= Date: Mon, 21 Jul 2025 11:23:48 +0000 Subject: [PATCH] Fix/content reasoning * fix: concat * fix: concat * fix: concat sep * fix: concat content & url * fix: opimized * fix: test * fix: test * fix: multi content parse with uri * fix: multi content combine * fix: remove unused func * fix: multi content with content & reasoning content * fix: multi content with content & reasoning content See merge request: !886 --- .../api/model/crossdomain/message/message.go | 1 + backend/application/application.go | 2 +- backend/application/conversation/agent_run.go | 7 +- backend/application/conversation/message.go | 44 +++++++++- .../contract/crossmessage/cross_message.go | 1 + backend/crossdomain/impl/message/message.go | 4 + .../impl/singleagent/single_agent.go | 59 ++++++++++++-- .../internal/agentflow/agent_flow_runner.go | 80 ++++++++++++++++--- .../agentflow/callback_reply_chunk.go | 2 +- .../agentrun/service/agent_run_impl.go | 44 +++++----- .../message/internal/dal/message.go | 66 ++++++++++----- .../message/repository/repository.go | 1 + .../conversation/message/service/message.go | 1 + .../message/service/message_impl.go | 32 +++----- .../message/service/message_test.go | 9 ++- backend/infra/impl/storage/minio/minio.go | 2 +- 16 files changed, 263 insertions(+), 92 deletions(-) diff --git a/backend/api/model/crossdomain/message/message.go b/backend/api/model/crossdomain/message/message.go index c7db3dd4..9b7aedd1 100644 --- a/backend/api/model/crossdomain/message/message.go +++ b/backend/api/model/crossdomain/message/message.go @@ -56,6 +56,7 @@ const ( type FileData struct { Url string `json:"url"` + URI string `json:"uri"` Name string `json:"name"` } diff --git a/backend/application/application.go b/backend/application/application.go index a0af9d36..5de77542 100644 --- a/backend/application/application.go +++ b/backend/application/application.go @@ -137,7 +137,7 @@ func Init(ctx context.Context) (err error) { crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC)) crossmessage.SetDefaultSVC(messageImpl.InitDomainService(complexServices.conversationSVC.MessageDomainSVC)) crossagentrun.SetDefaultSVC(agentrunImpl.InitDomainService(complexServices.conversationSVC.AgentRunDomainSVC)) - crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC)) + crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC, infra.ImageXClient)) crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC)) crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra)) crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC)) diff --git a/backend/application/conversation/agent_run.go b/backend/application/conversation/agent_run.go index 5887b7ae..df50f9bd 100644 --- a/backend/application/conversation/agent_run.go +++ b/backend/application/conversation/agent_run.go @@ -418,11 +418,12 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context, resourceUrl, err := c.getUrlByUri(ctx, item.Image.Key) if err != nil { + logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err) continue } - if err != nil { - logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err) + if resourceUrl == "" { + logs.CtxErrorf(ctx, "failed to unescape resource url, uri is %v", item.Image.Key) continue } @@ -434,6 +435,7 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context, FileData: []*crossDomainMessage.FileData{ { Url: resourceUrl, + URI: item.Image.Key, }, }, }) @@ -451,6 +453,7 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context, FileData: []*crossDomainMessage.FileData{ { Url: resourceUrl, + URI: item.File.FileKey, }, }, }) diff --git a/backend/application/conversation/message.go b/backend/application/conversation/message.go index 8c9c94a6..53dcdecd 100644 --- a/backend/application/conversation/message.go +++ b/backend/application/conversation/message.go @@ -18,10 +18,12 @@ package conversation import ( "context" + "encoding/json" "strconv" "github.com/coze-dev/coze-studio/backend/api/model/conversation/common" "github.com/coze-dev/coze-studio/backend/api/model/conversation/message" + "github.com/coze-dev/coze-studio/backend/api/model/conversation/run" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message" "github.com/coze-dev/coze-studio/backend/application/base/ctxutil" singleAgentEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" @@ -217,7 +219,7 @@ func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Co } if dm.ContentType == model.ContentTypeMix && dm.DisplayContent != "" { - cm.Content = dm.DisplayContent + cm.Content = c.buildParseMessageURI(ctx, dm.DisplayContent) } if dm.MessageType != model.MessageTypeQuestion { @@ -227,6 +229,46 @@ func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Co return cm } +func (c *ConversationApplicationService) buildParseMessageURI(ctx context.Context, msgContent string) string { + + if msgContent == "" { + return msgContent + } + + var mc *run.MixContentModel + err := json.Unmarshal([]byte(msgContent), &mc) + if err != nil { + return msgContent + } + for k, item := range mc.ItemList { + switch item.Type { + case run.ContentTypeImage: + + url, pErr := c.appContext.ImageX.GetResourceURL(ctx, item.Image.Key) + if pErr == nil { + mc.ItemList[k].Image.ImageThumb.URL = url.URL + mc.ItemList[k].Image.ImageOri.URL = url.URL + } + + case run.ContentTypeFile, run.ContentTypeAudio, run.ContentTypeVideo: + url, pErr := c.appContext.ImageX.GetResourceURL(ctx, item.File.FileKey) + if pErr == nil { + mc.ItemList[k].File.FileURL = url.URL + + } + + default: + + } + } + jsonMsg, err := json.Marshal(mc) + if err != nil { + return msgContent + } + + return string(jsonMsg) +} + func buildDExt2ApiExt(extra map[string]string) *message.ExtraInfo { return &message.ExtraInfo{ InputTokens: extra["input_tokens"], diff --git a/backend/crossdomain/contract/crossmessage/cross_message.go b/backend/crossdomain/contract/crossmessage/cross_message.go index 36f6a39d..746cc533 100644 --- a/backend/crossdomain/contract/crossmessage/cross_message.go +++ b/backend/crossdomain/contract/crossmessage/cross_message.go @@ -24,6 +24,7 @@ import ( type Message interface { GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error) + PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error) Create(ctx context.Context, msg *message.Message) (*message.Message, error) Edit(ctx context.Context, msg *message.Message) (*message.Message, error) } diff --git a/backend/crossdomain/impl/message/message.go b/backend/crossdomain/impl/message/message.go index b65b663f..a71519b4 100644 --- a/backend/crossdomain/impl/message/message.go +++ b/backend/crossdomain/impl/message/message.go @@ -49,3 +49,7 @@ func (c *impl) Create(ctx context.Context, msg *model.Message) (*model.Message, func (c *impl) Edit(ctx context.Context, msg *model.Message) (*model.Message, error) { return c.DomainSVC.Edit(ctx, msg) } + +func (c *impl) PreCreate(ctx context.Context, msg *model.Message) (*model.Message, error) { + return c.DomainSVC.PreCreate(ctx, msg) +} diff --git a/backend/crossdomain/impl/singleagent/single_agent.go b/backend/crossdomain/impl/singleagent/single_agent.go index f9ca8e1a..f233f2e6 100644 --- a/backend/crossdomain/impl/singleagent/single_agent.go +++ b/backend/crossdomain/impl/singleagent/single_agent.go @@ -28,6 +28,7 @@ import ( "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent" singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service" "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity" + "github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/logs" @@ -37,11 +38,13 @@ var defaultSVC crossagent.SingleAgent type impl struct { DomainSVC singleagent.SingleAgent + ImagexSVC imagex.ImageX } -func InitDomainService(c singleagent.SingleAgent) crossagent.SingleAgent { +func InitDomainService(c singleagent.SingleAgent, imagexClient imagex.ImageX) crossagent.SingleAgent { defaultSVC = &impl{ DomainSVC: c, + ImagexSVC: imagexClient, } return defaultSVC @@ -64,14 +67,18 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg input *message.Message, agentRuntime *model.AgentRuntime, ) *model.ExecuteRequest { identity := c.buildIdentity(input, agentRuntime) - inputBuild := c.buildSchemaMessage([]*message.Message{input}) - history := c.buildSchemaMessage(historyMsg) + inputBuild := c.buildSchemaMessage(ctx, []*message.Message{input}) + var inputSM *schema.Message + if len(inputBuild) > 0 { + inputSM = inputBuild[0] + } + history := c.buildSchemaMessage(ctx, historyMsg) resumeInfo := c.checkResumeInfo(ctx, historyMsg) return &model.ExecuteRequest{ Identity: identity, - Input: inputBuild[0], + Input: inputSM, History: history, UserID: input.UserID, PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever { @@ -143,7 +150,7 @@ func (c *impl) checkResumeInfo(_ context.Context, historyMsg []*message.Message) return resumeInfo } -func (c *impl) buildSchemaMessage(msgs []*message.Message) []*schema.Message { +func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message) []*schema.Message { schemaMessage := make([]*schema.Message, 0, len(msgs)) for _, msgOne := range msgs { @@ -158,12 +165,52 @@ func (c *impl) buildSchemaMessage(msgs []*message.Message) []*schema.Message { if err != nil { continue } - schemaMessage = append(schemaMessage, sm) + schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm)) } return schemaMessage } +func (c *impl) parseMessageURI(ctx context.Context, mcMsg *schema.Message) *schema.Message { + if mcMsg.MultiContent == nil { + return mcMsg + } + for k, one := range mcMsg.MultiContent { + switch one.Type { + case schema.ChatMessagePartTypeImageURL: + + if one.ImageURL.URI != "" { + url, err := c.ImagexSVC.GetResourceURL(ctx, one.ImageURL.URI) + if err == nil { + mcMsg.MultiContent[k].ImageURL.URL = url.URL + } + } + case schema.ChatMessagePartTypeFileURL: + if one.FileURL.URI != "" { + url, err := c.ImagexSVC.GetResourceURL(ctx, one.FileURL.URI) + if err == nil { + mcMsg.MultiContent[k].FileURL.URL = url.URL + } + } + case schema.ChatMessagePartTypeAudioURL: + if one.AudioURL.URI != "" { + url, err := c.ImagexSVC.GetResourceURL(ctx, one.AudioURL.URI) + if err == nil { + mcMsg.MultiContent[k].AudioURL.URL = url.URL + } + } + case schema.ChatMessagePartTypeVideoURL: + if one.VideoURL.URI != "" { + url, err := c.ImagexSVC.GetResourceURL(ctx, one.VideoURL.URI) + if err == nil { + mcMsg.MultiContent[k].VideoURL.URL = url.URL + } + } + } + } + return mcMsg +} + func (c *impl) buildIdentity(input *message.Message, agentRuntime *model.AgentRuntime) *model.AgentIdentity { return &model.AgentIdentity{ AgentID: input.AgentID, diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go index e0a3cd89..6bbfc022 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go @@ -32,6 +32,7 @@ import ( "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" + "github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/logs" ) @@ -119,48 +120,91 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest { req.Input = r.preHandlerInput(req.Input) req.History = r.preHandlerHistory(req.History) + logs.CtxInfof(ctx, "[AgentRunner] PreHandlerReq, req: %v", conv.DebugJsonToStr(req)) return req } func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message { var multiContent []schema.ChatMessagePart + + if len(input.MultiContent) == 0 { + return input + } + + unSupportMultiPart := make([]schema.ChatMessagePart, 0, len(input.MultiContent)) + for _, v := range input.MultiContent { switch v.Type { case schema.ChatMessagePartTypeImageURL: - if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage) { - input.Content = input.Content + ": " + v.ImageURL.URL + if !r.isSupportImage() { + unSupportMultiPart = append(unSupportMultiPart, v) } else { multiContent = append(multiContent, v) } case schema.ChatMessagePartTypeFileURL: - if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile) { - input.Content = input.Content + ": " + v.FileURL.URL + if !r.isSupportFile() { + unSupportMultiPart = append(unSupportMultiPart, v) } else { multiContent = append(multiContent, v) } case schema.ChatMessagePartTypeAudioURL: - if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio) { - input.Content = input.Content + ": " + v.FileURL.URL + if !r.isSupportAudio() { + unSupportMultiPart = append(unSupportMultiPart, v) } else { multiContent = append(multiContent, v) } case schema.ChatMessagePartTypeVideoURL: - if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo) { - input.Content = input.Content + ": " + v.FileURL.URL + if !r.isSupportVideo() { + unSupportMultiPart = append(unSupportMultiPart, v) } else { multiContent = append(multiContent, v) } case schema.ChatMessagePartTypeText: - break - default: multiContent = append(multiContent, v) } } + + for _, v := range input.MultiContent { + if v.Type != schema.ChatMessagePartTypeText { + continue + } + + if r.isSupportMultiContent() { + if len(multiContent) > 0 { + v.Text = concatContentString(v.Text, unSupportMultiPart) + multiContent = append(multiContent, v) + } else { + input.Content = concatContentString(v.Text, unSupportMultiPart) + } + } else { + input.Content = concatContentString(v.Text, unSupportMultiPart) + } + + } input.MultiContent = multiContent return input } +func concatContentString(textContent string, unSupportTypeURL []schema.ChatMessagePart) string { + if len(unSupportTypeURL) == 0 { + return textContent + } + for _, v := range unSupportTypeURL { + switch v.Type { + case schema.ChatMessagePartTypeImageURL: + textContent += " this is a image:" + v.ImageURL.URL + case schema.ChatMessagePartTypeFileURL: + textContent += " this is a file:" + v.FileURL.URL + case schema.ChatMessagePartTypeAudioURL: + textContent += " this is a audio:" + v.AudioURL.URL + case schema.ChatMessagePartTypeVideoURL: + textContent += " this is a video:" + v.VideoURL.URL + default: + } + } + return textContent +} func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message { var hm []*schema.Message @@ -172,3 +216,19 @@ func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Mes } return hm } + +func (r *AgentRunner) isSupportMultiContent() bool { + return len(r.modelInfo.Meta.Capability.InputModal) > 1 +} +func (r *AgentRunner) isSupportImage() bool { + return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage) +} +func (r *AgentRunner) isSupportFile() bool { + return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile) +} +func (r *AgentRunner) isSupportAudio() bool { + return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio) +} +func (r *AgentRunner) isSupportVideo() bool { + return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo) +} diff --git a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go index 9537e0f9..fd2a62a3 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go +++ b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go @@ -94,7 +94,7 @@ func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInf r.sw.Send(interruptEvent, nil) } else { - logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%w", + logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%v", info.Component, info.Name, err) var customErr errorx.StatusError errMsg := "Internal server error" diff --git a/backend/domain/conversation/agentrun/service/agent_run_impl.go b/backend/domain/conversation/agentrun/service/agent_run_impl.go index 2ffab99e..9cd62810 100644 --- a/backend/domain/conversation/agentrun/service/agent_run_impl.go +++ b/backend/domain/conversation/agentrun/service/agent_run_impl.go @@ -473,6 +473,10 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent } }() + reasoningContent := bytes.NewBuffer([]byte{}) + var createPreMsg = true + var preFinalAnswerMsg *msgEntity.Message + for { chunk, ok := <-mainChan if !ok || chunk == nil { @@ -505,11 +509,7 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent } case message.MessageTypeAnswer: fullContent := bytes.NewBuffer([]byte{}) - reasoningContent := bytes.NewBuffer([]byte{}) - - var preMsg *msgEntity.Message var usage *msgEntity.UsageExt - var createPreMsg = true var isToolCalls = false for { @@ -518,14 +518,15 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent if receErr != nil { if errors.Is(receErr, io.EOF) { - if isToolCalls && reasoningContent.String() == "" { + if isToolCalls { break } - finalAnswer := c.buildSendMsg(ctx, preMsg, false, rtDependence) + finalAnswer := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) + finalAnswer.Content = fullContent.String() finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String()) - hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence) + hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence, preFinalAnswerMsg) if hfErr != nil { err = hfErr return @@ -553,14 +554,14 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent continue } if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) { - preMsg, err = c.handlerPreAnswer(ctx, rtDependence) + preFinalAnswerMsg, err = c.PreCreateFinalAnswer(ctx, rtDependence) if err != nil { return } createPreMsg = false } - sendMsg := c.buildSendMsg(ctx, preMsg, false, rtDependence) + sendMsg := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) reasoningContent.WriteString(streamMsg.ReasoningContent) sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent) @@ -590,7 +591,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE if err != nil { return err } - preMsg, err := c.handlerPreAnswer(ctx, rtDependence) + preMsg, err := c.PreCreateFinalAnswer(ctx, rtDependence) if err != nil { return err } @@ -612,7 +613,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw) finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem) - err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence) + err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg) if err != nil { return err } @@ -718,7 +719,7 @@ func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWrit }) } -func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { +func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { arm := rtDependence.runMeta msgMeta := &msgEntity.Message{ ConversationID: arm.ConversationID, @@ -747,10 +748,10 @@ func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDep } msgMeta.Ext = arm.Ext - return crossmessage.DefaultSVC().Create(ctx, msgMeta) + return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta) } -func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence) error { +func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preFinalAnswerMsg *msgEntity.Message) error { if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 { return nil @@ -786,15 +787,12 @@ func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessa if err != nil { return err } - editMsg := &msgEntity.Message{ - ID: msg.ID, - Content: msg.Content, - ContentType: msg.ContentType, - ModelContent: string(mc), - ReasoningContent: ptr.From(msg.ReasoningContent), - Ext: msg.Ext, - } - _, err = crossmessage.DefaultSVC().Edit(ctx, editMsg) + preFinalAnswerMsg.Content = msg.Content + preFinalAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent) + preFinalAnswerMsg.Ext = msg.Ext + preFinalAnswerMsg.ContentType = msg.ContentType + preFinalAnswerMsg.ModelContent = string(mc) + _, err = crossmessage.DefaultSVC().Create(ctx, preFinalAnswerMsg) if err != nil { return err } diff --git a/backend/domain/conversation/message/internal/dal/message.go b/backend/domain/conversation/message/internal/dal/message.go index 0e18f56c..6a28242c 100644 --- a/backend/domain/conversation/message/internal/dal/message.go +++ b/backend/domain/conversation/message/internal/dal/message.go @@ -48,6 +48,14 @@ func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO { } } +func (dao *MessageDAO) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) { + poData, err := dao.messageDO2PO(ctx, msg) + if err != nil { + return nil, err + } + return dao.messagePO2DO(poData), nil +} + func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) { poData, err := dao.messageDO2PO(ctx, msg) if err != nil { @@ -205,33 +213,44 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6 } func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) { - id, gErr := dao.idgen.GenID(ctx) - if gErr != nil { - return nil, gErr + var id int64 + if msgDo.ID > 0 { + id = msgDo.ID + } else { + genID, gErr := dao.idgen.GenID(ctx) + if gErr != nil { + return nil, gErr + } + id = genID } msgPO := &model.Message{ - ID: id, - ConversationID: msgDo.ConversationID, - RunID: msgDo.RunID, - AgentID: msgDo.AgentID, - SectionID: msgDo.SectionID, - UserID: msgDo.UserID, - Role: string(msgDo.Role), - ContentType: string(msgDo.ContentType), - MessageType: string(msgDo.MessageType), - DisplayContent: msgDo.DisplayContent, - Content: msgDo.Content, - BrokenPosition: msgDo.Position, - Status: int32(entity.MessageStatusAvailable), - CreatedAt: time.Now().UnixMilli(), - UpdatedAt: time.Now().UnixMilli(), + ID: id, + ConversationID: msgDo.ConversationID, + RunID: msgDo.RunID, + AgentID: msgDo.AgentID, + SectionID: msgDo.SectionID, + UserID: msgDo.UserID, + Role: string(msgDo.Role), + ContentType: string(msgDo.ContentType), + MessageType: string(msgDo.MessageType), + DisplayContent: msgDo.DisplayContent, + Content: msgDo.Content, + BrokenPosition: msgDo.Position, + Status: int32(entity.MessageStatusAvailable), + CreatedAt: time.Now().UnixMilli(), + UpdatedAt: time.Now().UnixMilli(), + ReasoningContent: msgDo.ReasoningContent, } - mc, err := dao.buildModelContent(msgDo) - if err != nil { - return nil, err + if msgDo.ModelContent != "" { + msgPO.ModelContent = msgDo.ModelContent + } else { + mc, err := dao.buildModelContent(msgDo) + if err != nil { + return nil, err + } + msgPO.ModelContent = mc } - msgPO.ModelContent = mc ext, err := json.Marshal(msgDo.Ext) if err != nil { @@ -267,11 +286,13 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) one.Type = schema.ChatMessagePartTypeImageURL one.ImageURL = &schema.ChatMessageImageURL{ URL: contentData.FileData[0].Url, + URI: contentData.FileData[0].URI, } case message.InputTypeFile: one.Type = schema.ChatMessagePartTypeFileURL one.FileURL = &schema.ChatMessageFileURL{ URL: contentData.FileData[0].Url, + URI: contentData.FileData[0].URI, } case message.InputTypeVideo: one.Type = schema.ChatMessagePartTypeVideoURL @@ -282,6 +303,7 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error) one.Type = schema.ChatMessagePartTypeFileURL one.AudioURL = &schema.ChatMessageAudioURL{ URL: contentData.FileData[0].Url, + URI: contentData.FileData[0].URI, } } multiContent = append(multiContent, one) diff --git a/backend/domain/conversation/message/repository/repository.go b/backend/domain/conversation/message/repository/repository.go index 97268d10..b595e9c0 100644 --- a/backend/domain/conversation/message/repository/repository.go +++ b/backend/domain/conversation/message/repository/repository.go @@ -32,6 +32,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo { } type MessageRepo interface { + PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) List(ctx context.Context, conversationID int64, limit int, cursor int64, direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) diff --git a/backend/domain/conversation/message/service/message.go b/backend/domain/conversation/message/service/message.go index c8b509aa..d1398e67 100644 --- a/backend/domain/conversation/message/service/message.go +++ b/backend/domain/conversation/message/service/message.go @@ -24,6 +24,7 @@ import ( type Message interface { List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) + PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error) Create(ctx context.Context, req *entity.Message) (*entity.Message, error) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) GetByID(ctx context.Context, id int64) (*entity.Message, error) diff --git a/backend/domain/conversation/message/service/message_impl.go b/backend/domain/conversation/message/service/message_impl.go index 8dfcd17f..edeec508 100644 --- a/backend/domain/conversation/message/service/message_impl.go +++ b/backend/domain/conversation/message/service/message_impl.go @@ -39,13 +39,14 @@ func NewService(c *Components) Message { } } +func (m *messageImpl) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) { + // create message + return m.MessageRepo.PreCreate(ctx, msg) +} + func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) { // create message - msg, err := m.MessageRepo.Create(ctx, msg) - if err != nil { - return nil, err - } - return msg, nil + return m.MessageRepo.Create(ctx, msg) } func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) { @@ -82,12 +83,7 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L } func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) { - messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC") - if err != nil { - return nil, err - } - - return messageList, nil + return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC") } func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) { @@ -100,21 +96,11 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me } func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error { - err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs) - if err != nil { - return err - } - - return nil + return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs) } func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) { - msg, err := m.MessageRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - return msg, nil + return m.MessageRepo.GetByID(ctx, id) } func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error { diff --git a/backend/domain/conversation/message/service/message_test.go b/backend/domain/conversation/message/service/message_test.go index f172303f..50451327 100644 --- a/backend/domain/conversation/message/service/message_test.go +++ b/backend/domain/conversation/message/service/message_test.go @@ -19,6 +19,7 @@ package message import ( "context" "testing" + "time" "github.com/cloudwego/eino/schema" "github.com/stretchr/testify/assert" @@ -74,8 +75,10 @@ func TestCreateMessage(t *testing.T) { ctrl := gomock.NewController(t) idGen := mock.NewMockIDGenerator(ctrl) - idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1) - + idGen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(_ context.Context) (int64, error) { + newID := time.Now().UnixNano() + return newID, nil + }).AnyTimes() mockDBGen := orm.NewMockDB() mockDBGen.AddTable(&model.Message{}) mockDB, err := mockDBGen.DB() @@ -92,10 +95,12 @@ func TestCreateMessage(t *testing.T) { imageInput := &message.FileData{ Url: "https://xxxxx.xxxx/image", Name: "test_img", + URI: "", } fileInput := &message.FileData{ Url: "https://xxxxx.xxxx/file", Name: "test_file", + URI: "", } content := []*message.InputMetaData{ { diff --git a/backend/infra/impl/storage/minio/minio.go b/backend/infra/impl/storage/minio/minio.go index 8ce841a8..85faad6f 100644 --- a/backend/infra/impl/storage/minio/minio.go +++ b/backend/infra/impl/storage/minio/minio.go @@ -201,7 +201,7 @@ func (m *minioClient) GetObjectUrl(ctx context.Context, objectKey string, opts . } if option.Expire == 0 { - option.Expire = 3600 * 24 + option.Expire = 3600 * 24 * 7 } reqParams := make(url.Values)