Fix/content reasoning

* fix: concat
* fix: concat
* fix: concat sep
* fix: concat content & url
* fix: opimized
* fix: test
* fix: test
* fix: multi content parse with uri
* fix: multi content combine
* fix: remove unused func
* fix: multi content with content & reasoning content
* fix: multi content with content & reasoning content

See merge request: !886
This commit is contained in:
李俊文 2025-07-21 11:23:48 +00:00
parent 82e55ffdb8
commit 249c23c64b
16 changed files with 263 additions and 92 deletions

View File

@ -56,6 +56,7 @@ const (
type FileData struct { type FileData struct {
Url string `json:"url"` Url string `json:"url"`
URI string `json:"uri"`
Name string `json:"name"` Name string `json:"name"`
} }

View File

@ -137,7 +137,7 @@ func Init(ctx context.Context) (err error) {
crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC)) crossconversation.SetDefaultSVC(conversationImpl.InitDomainService(complexServices.conversationSVC.ConversationDomainSVC))
crossmessage.SetDefaultSVC(messageImpl.InitDomainService(complexServices.conversationSVC.MessageDomainSVC)) crossmessage.SetDefaultSVC(messageImpl.InitDomainService(complexServices.conversationSVC.MessageDomainSVC))
crossagentrun.SetDefaultSVC(agentrunImpl.InitDomainService(complexServices.conversationSVC.AgentRunDomainSVC)) crossagentrun.SetDefaultSVC(agentrunImpl.InitDomainService(complexServices.conversationSVC.AgentRunDomainSVC))
crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC)) crossagent.SetDefaultSVC(singleagentImpl.InitDomainService(complexServices.singleAgentSVC.DomainSVC, infra.ImageXClient))
crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC)) crossuser.SetDefaultSVC(crossuserImpl.InitDomainService(basicServices.userSVC.DomainSVC))
crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra)) crossdatacopy.SetDefaultSVC(dataCopyImpl.InitDomainService(basicServices.infra))
crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC)) crosssearch.SetDefaultSVC(searchImpl.InitDomainService(complexServices.searchSVC.DomainSVC))

View File

@ -418,11 +418,12 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context,
resourceUrl, err := c.getUrlByUri(ctx, item.Image.Key) resourceUrl, err := c.getUrlByUri(ctx, item.Image.Key)
if err != nil { if err != nil {
logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err)
continue continue
} }
if err != nil { if resourceUrl == "" {
logs.CtxErrorf(ctx, "failed to unescape resource url, err is %v", err) logs.CtxErrorf(ctx, "failed to unescape resource url, uri is %v", item.Image.Key)
continue continue
} }
@ -434,6 +435,7 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context,
FileData: []*crossDomainMessage.FileData{ FileData: []*crossDomainMessage.FileData{
{ {
Url: resourceUrl, Url: resourceUrl,
URI: item.Image.Key,
}, },
}, },
}) })
@ -451,6 +453,7 @@ func (c *ConversationApplicationService) parseMultiContent(ctx context.Context,
FileData: []*crossDomainMessage.FileData{ FileData: []*crossDomainMessage.FileData{
{ {
Url: resourceUrl, Url: resourceUrl,
URI: item.File.FileKey,
}, },
}, },
}) })

View File

@ -18,10 +18,12 @@ package conversation
import ( import (
"context" "context"
"encoding/json"
"strconv" "strconv"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/common" "github.com/coze-dev/coze-studio/backend/api/model/conversation/common"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/message" "github.com/coze-dev/coze-studio/backend/api/model/conversation/message"
"github.com/coze-dev/coze-studio/backend/api/model/conversation/run"
model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message" model "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/message"
"github.com/coze-dev/coze-studio/backend/application/base/ctxutil" "github.com/coze-dev/coze-studio/backend/application/base/ctxutil"
singleAgentEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" singleAgentEntity "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
@ -217,7 +219,7 @@ func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Co
} }
if dm.ContentType == model.ContentTypeMix && dm.DisplayContent != "" { if dm.ContentType == model.ContentTypeMix && dm.DisplayContent != "" {
cm.Content = dm.DisplayContent cm.Content = c.buildParseMessageURI(ctx, dm.DisplayContent)
} }
if dm.MessageType != model.MessageTypeQuestion { if dm.MessageType != model.MessageTypeQuestion {
@ -227,6 +229,46 @@ func (c *ConversationApplicationService) buildDomainMsg2VOMessage(ctx context.Co
return cm return cm
} }
func (c *ConversationApplicationService) buildParseMessageURI(ctx context.Context, msgContent string) string {
if msgContent == "" {
return msgContent
}
var mc *run.MixContentModel
err := json.Unmarshal([]byte(msgContent), &mc)
if err != nil {
return msgContent
}
for k, item := range mc.ItemList {
switch item.Type {
case run.ContentTypeImage:
url, pErr := c.appContext.ImageX.GetResourceURL(ctx, item.Image.Key)
if pErr == nil {
mc.ItemList[k].Image.ImageThumb.URL = url.URL
mc.ItemList[k].Image.ImageOri.URL = url.URL
}
case run.ContentTypeFile, run.ContentTypeAudio, run.ContentTypeVideo:
url, pErr := c.appContext.ImageX.GetResourceURL(ctx, item.File.FileKey)
if pErr == nil {
mc.ItemList[k].File.FileURL = url.URL
}
default:
}
}
jsonMsg, err := json.Marshal(mc)
if err != nil {
return msgContent
}
return string(jsonMsg)
}
func buildDExt2ApiExt(extra map[string]string) *message.ExtraInfo { func buildDExt2ApiExt(extra map[string]string) *message.ExtraInfo {
return &message.ExtraInfo{ return &message.ExtraInfo{
InputTokens: extra["input_tokens"], InputTokens: extra["input_tokens"],

View File

@ -24,6 +24,7 @@ import (
type Message interface { type Message interface {
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*message.Message, error)
PreCreate(ctx context.Context, msg *message.Message) (*message.Message, error)
Create(ctx context.Context, msg *message.Message) (*message.Message, error) Create(ctx context.Context, msg *message.Message) (*message.Message, error)
Edit(ctx context.Context, msg *message.Message) (*message.Message, error) Edit(ctx context.Context, msg *message.Message) (*message.Message, error)
} }

View File

@ -49,3 +49,7 @@ func (c *impl) Create(ctx context.Context, msg *model.Message) (*model.Message,
func (c *impl) Edit(ctx context.Context, msg *model.Message) (*model.Message, error) { func (c *impl) Edit(ctx context.Context, msg *model.Message) (*model.Message, error) {
return c.DomainSVC.Edit(ctx, msg) return c.DomainSVC.Edit(ctx, msg)
} }
func (c *impl) PreCreate(ctx context.Context, msg *model.Message) (*model.Message, error) {
return c.DomainSVC.PreCreate(ctx, msg)
}

View File

@ -28,6 +28,7 @@ import (
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossagent"
singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service" singleagent "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/service"
"github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity" "github.com/coze-dev/coze-studio/backend/domain/conversation/message/entity"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
@ -37,11 +38,13 @@ var defaultSVC crossagent.SingleAgent
type impl struct { type impl struct {
DomainSVC singleagent.SingleAgent DomainSVC singleagent.SingleAgent
ImagexSVC imagex.ImageX
} }
func InitDomainService(c singleagent.SingleAgent) crossagent.SingleAgent { func InitDomainService(c singleagent.SingleAgent, imagexClient imagex.ImageX) crossagent.SingleAgent {
defaultSVC = &impl{ defaultSVC = &impl{
DomainSVC: c, DomainSVC: c,
ImagexSVC: imagexClient,
} }
return defaultSVC return defaultSVC
@ -64,14 +67,18 @@ func (c *impl) buildSingleAgentStreamExecuteReq(ctx context.Context, historyMsg
input *message.Message, agentRuntime *model.AgentRuntime, input *message.Message, agentRuntime *model.AgentRuntime,
) *model.ExecuteRequest { ) *model.ExecuteRequest {
identity := c.buildIdentity(input, agentRuntime) identity := c.buildIdentity(input, agentRuntime)
inputBuild := c.buildSchemaMessage([]*message.Message{input}) inputBuild := c.buildSchemaMessage(ctx, []*message.Message{input})
history := c.buildSchemaMessage(historyMsg) var inputSM *schema.Message
if len(inputBuild) > 0 {
inputSM = inputBuild[0]
}
history := c.buildSchemaMessage(ctx, historyMsg)
resumeInfo := c.checkResumeInfo(ctx, historyMsg) resumeInfo := c.checkResumeInfo(ctx, historyMsg)
return &model.ExecuteRequest{ return &model.ExecuteRequest{
Identity: identity, Identity: identity,
Input: inputBuild[0], Input: inputSM,
History: history, History: history,
UserID: input.UserID, UserID: input.UserID,
PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever { PreCallTools: slices.Transform(agentRuntime.PreRetrieveTools, func(tool *agentrun.Tool) *agentrun.ToolsRetriever {
@ -143,7 +150,7 @@ func (c *impl) checkResumeInfo(_ context.Context, historyMsg []*message.Message)
return resumeInfo return resumeInfo
} }
func (c *impl) buildSchemaMessage(msgs []*message.Message) []*schema.Message { func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message) []*schema.Message {
schemaMessage := make([]*schema.Message, 0, len(msgs)) schemaMessage := make([]*schema.Message, 0, len(msgs))
for _, msgOne := range msgs { for _, msgOne := range msgs {
@ -158,12 +165,52 @@ func (c *impl) buildSchemaMessage(msgs []*message.Message) []*schema.Message {
if err != nil { if err != nil {
continue continue
} }
schemaMessage = append(schemaMessage, sm) schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm))
} }
return schemaMessage return schemaMessage
} }
func (c *impl) parseMessageURI(ctx context.Context, mcMsg *schema.Message) *schema.Message {
if mcMsg.MultiContent == nil {
return mcMsg
}
for k, one := range mcMsg.MultiContent {
switch one.Type {
case schema.ChatMessagePartTypeImageURL:
if one.ImageURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.ImageURL.URI)
if err == nil {
mcMsg.MultiContent[k].ImageURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeFileURL:
if one.FileURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.FileURL.URI)
if err == nil {
mcMsg.MultiContent[k].FileURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeAudioURL:
if one.AudioURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.AudioURL.URI)
if err == nil {
mcMsg.MultiContent[k].AudioURL.URL = url.URL
}
}
case schema.ChatMessagePartTypeVideoURL:
if one.VideoURL.URI != "" {
url, err := c.ImagexSVC.GetResourceURL(ctx, one.VideoURL.URI)
if err == nil {
mcMsg.MultiContent[k].VideoURL.URL = url.URL
}
}
}
}
return mcMsg
}
func (c *impl) buildIdentity(input *message.Message, agentRuntime *model.AgentRuntime) *model.AgentIdentity { func (c *impl) buildIdentity(input *message.Message, agentRuntime *model.AgentRuntime) *model.AgentIdentity {
return &model.AgentIdentity{ return &model.AgentIdentity{
AgentID: input.AgentID, AgentID: input.AgentID,

View File

@ -32,6 +32,7 @@ import (
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity" "github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
) )
@ -119,48 +120,91 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest { func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
req.Input = r.preHandlerInput(req.Input) req.Input = r.preHandlerInput(req.Input)
req.History = r.preHandlerHistory(req.History) req.History = r.preHandlerHistory(req.History)
logs.CtxInfof(ctx, "[AgentRunner] PreHandlerReq, req: %v", conv.DebugJsonToStr(req))
return req return req
} }
func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message { func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message {
var multiContent []schema.ChatMessagePart var multiContent []schema.ChatMessagePart
if len(input.MultiContent) == 0 {
return input
}
unSupportMultiPart := make([]schema.ChatMessagePart, 0, len(input.MultiContent))
for _, v := range input.MultiContent { for _, v := range input.MultiContent {
switch v.Type { switch v.Type {
case schema.ChatMessagePartTypeImageURL: case schema.ChatMessagePartTypeImageURL:
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage) { if !r.isSupportImage() {
input.Content = input.Content + ": " + v.ImageURL.URL unSupportMultiPart = append(unSupportMultiPart, v)
} else { } else {
multiContent = append(multiContent, v) multiContent = append(multiContent, v)
} }
case schema.ChatMessagePartTypeFileURL: case schema.ChatMessagePartTypeFileURL:
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile) { if !r.isSupportFile() {
input.Content = input.Content + ": " + v.FileURL.URL unSupportMultiPart = append(unSupportMultiPart, v)
} else { } else {
multiContent = append(multiContent, v) multiContent = append(multiContent, v)
} }
case schema.ChatMessagePartTypeAudioURL: case schema.ChatMessagePartTypeAudioURL:
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio) { if !r.isSupportAudio() {
input.Content = input.Content + ": " + v.FileURL.URL unSupportMultiPart = append(unSupportMultiPart, v)
} else { } else {
multiContent = append(multiContent, v) multiContent = append(multiContent, v)
} }
case schema.ChatMessagePartTypeVideoURL: case schema.ChatMessagePartTypeVideoURL:
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo) { if !r.isSupportVideo() {
input.Content = input.Content + ": " + v.FileURL.URL unSupportMultiPart = append(unSupportMultiPart, v)
} else { } else {
multiContent = append(multiContent, v) multiContent = append(multiContent, v)
} }
case schema.ChatMessagePartTypeText: case schema.ChatMessagePartTypeText:
break
default: default:
multiContent = append(multiContent, v) multiContent = append(multiContent, v)
} }
} }
for _, v := range input.MultiContent {
if v.Type != schema.ChatMessagePartTypeText {
continue
}
if r.isSupportMultiContent() {
if len(multiContent) > 0 {
v.Text = concatContentString(v.Text, unSupportMultiPart)
multiContent = append(multiContent, v)
} else {
input.Content = concatContentString(v.Text, unSupportMultiPart)
}
} else {
input.Content = concatContentString(v.Text, unSupportMultiPart)
}
}
input.MultiContent = multiContent input.MultiContent = multiContent
return input return input
} }
func concatContentString(textContent string, unSupportTypeURL []schema.ChatMessagePart) string {
if len(unSupportTypeURL) == 0 {
return textContent
}
for _, v := range unSupportTypeURL {
switch v.Type {
case schema.ChatMessagePartTypeImageURL:
textContent += " this is a image:" + v.ImageURL.URL
case schema.ChatMessagePartTypeFileURL:
textContent += " this is a file:" + v.FileURL.URL
case schema.ChatMessagePartTypeAudioURL:
textContent += " this is a audio:" + v.AudioURL.URL
case schema.ChatMessagePartTypeVideoURL:
textContent += " this is a video:" + v.VideoURL.URL
default:
}
}
return textContent
}
func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message { func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message {
var hm []*schema.Message var hm []*schema.Message
@ -172,3 +216,19 @@ func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Mes
} }
return hm return hm
} }
func (r *AgentRunner) isSupportMultiContent() bool {
return len(r.modelInfo.Meta.Capability.InputModal) > 1
}
func (r *AgentRunner) isSupportImage() bool {
return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage)
}
func (r *AgentRunner) isSupportFile() bool {
return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile)
}
func (r *AgentRunner) isSupportAudio() bool {
return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio)
}
func (r *AgentRunner) isSupportVideo() bool {
return slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo)
}

View File

@ -94,7 +94,7 @@ func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInf
r.sw.Send(interruptEvent, nil) r.sw.Send(interruptEvent, nil)
} else { } else {
logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%w", logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%v",
info.Component, info.Name, err) info.Component, info.Name, err)
var customErr errorx.StatusError var customErr errorx.StatusError
errMsg := "Internal server error" errMsg := "Internal server error"

View File

@ -473,6 +473,10 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
} }
}() }()
reasoningContent := bytes.NewBuffer([]byte{})
var createPreMsg = true
var preFinalAnswerMsg *msgEntity.Message
for { for {
chunk, ok := <-mainChan chunk, ok := <-mainChan
if !ok || chunk == nil { if !ok || chunk == nil {
@ -505,11 +509,7 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
} }
case message.MessageTypeAnswer: case message.MessageTypeAnswer:
fullContent := bytes.NewBuffer([]byte{}) fullContent := bytes.NewBuffer([]byte{})
reasoningContent := bytes.NewBuffer([]byte{})
var preMsg *msgEntity.Message
var usage *msgEntity.UsageExt var usage *msgEntity.UsageExt
var createPreMsg = true
var isToolCalls = false var isToolCalls = false
for { for {
@ -518,14 +518,15 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
if receErr != nil { if receErr != nil {
if errors.Is(receErr, io.EOF) { if errors.Is(receErr, io.EOF) {
if isToolCalls && reasoningContent.String() == "" { if isToolCalls {
break break
} }
finalAnswer := c.buildSendMsg(ctx, preMsg, false, rtDependence) finalAnswer := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence)
finalAnswer.Content = fullContent.String() finalAnswer.Content = fullContent.String()
finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String()) finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String())
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence) hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence, preFinalAnswerMsg)
if hfErr != nil { if hfErr != nil {
err = hfErr err = hfErr
return return
@ -553,14 +554,14 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
continue continue
} }
if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) { if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) {
preMsg, err = c.handlerPreAnswer(ctx, rtDependence) preFinalAnswerMsg, err = c.PreCreateFinalAnswer(ctx, rtDependence)
if err != nil { if err != nil {
return return
} }
createPreMsg = false createPreMsg = false
} }
sendMsg := c.buildSendMsg(ctx, preMsg, false, rtDependence) sendMsg := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence)
reasoningContent.WriteString(streamMsg.ReasoningContent) reasoningContent.WriteString(streamMsg.ReasoningContent)
sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent) sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
@ -590,7 +591,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
if err != nil { if err != nil {
return err return err
} }
preMsg, err := c.handlerPreAnswer(ctx, rtDependence) preMsg, err := c.PreCreateFinalAnswer(ctx, rtDependence)
if err != nil { if err != nil {
return err return err
} }
@ -612,7 +613,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw) c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem) finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence) err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg)
if err != nil { if err != nil {
return err return err
} }
@ -718,7 +719,7 @@ func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWrit
}) })
} }
func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
arm := rtDependence.runMeta arm := rtDependence.runMeta
msgMeta := &msgEntity.Message{ msgMeta := &msgEntity.Message{
ConversationID: arm.ConversationID, ConversationID: arm.ConversationID,
@ -747,10 +748,10 @@ func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDep
} }
msgMeta.Ext = arm.Ext msgMeta.Ext = arm.Ext
return crossmessage.DefaultSVC().Create(ctx, msgMeta) return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
} }
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence) error { func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preFinalAnswerMsg *msgEntity.Message) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 { if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil return nil
@ -786,15 +787,12 @@ func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessa
if err != nil { if err != nil {
return err return err
} }
editMsg := &msgEntity.Message{ preFinalAnswerMsg.Content = msg.Content
ID: msg.ID, preFinalAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
Content: msg.Content, preFinalAnswerMsg.Ext = msg.Ext
ContentType: msg.ContentType, preFinalAnswerMsg.ContentType = msg.ContentType
ModelContent: string(mc), preFinalAnswerMsg.ModelContent = string(mc)
ReasoningContent: ptr.From(msg.ReasoningContent), _, err = crossmessage.DefaultSVC().Create(ctx, preFinalAnswerMsg)
Ext: msg.Ext,
}
_, err = crossmessage.DefaultSVC().Edit(ctx, editMsg)
if err != nil { if err != nil {
return err return err
} }

View File

@ -48,6 +48,14 @@ func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
} }
} }
func (dao *MessageDAO) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
poData, err := dao.messageDO2PO(ctx, msg)
if err != nil {
return nil, err
}
return dao.messagePO2DO(poData), nil
}
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) { func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
poData, err := dao.messageDO2PO(ctx, msg) poData, err := dao.messageDO2PO(ctx, msg)
if err != nil { if err != nil {
@ -205,10 +213,16 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
} }
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) { func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
id, gErr := dao.idgen.GenID(ctx) var id int64
if msgDo.ID > 0 {
id = msgDo.ID
} else {
genID, gErr := dao.idgen.GenID(ctx)
if gErr != nil { if gErr != nil {
return nil, gErr return nil, gErr
} }
id = genID
}
msgPO := &model.Message{ msgPO := &model.Message{
ID: id, ID: id,
ConversationID: msgDo.ConversationID, ConversationID: msgDo.ConversationID,
@ -225,13 +239,18 @@ func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message)
Status: int32(entity.MessageStatusAvailable), Status: int32(entity.MessageStatusAvailable),
CreatedAt: time.Now().UnixMilli(), CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(),
ReasoningContent: msgDo.ReasoningContent,
} }
if msgDo.ModelContent != "" {
msgPO.ModelContent = msgDo.ModelContent
} else {
mc, err := dao.buildModelContent(msgDo) mc, err := dao.buildModelContent(msgDo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msgPO.ModelContent = mc msgPO.ModelContent = mc
}
ext, err := json.Marshal(msgDo.Ext) ext, err := json.Marshal(msgDo.Ext)
if err != nil { if err != nil {
@ -267,11 +286,13 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
one.Type = schema.ChatMessagePartTypeImageURL one.Type = schema.ChatMessagePartTypeImageURL
one.ImageURL = &schema.ChatMessageImageURL{ one.ImageURL = &schema.ChatMessageImageURL{
URL: contentData.FileData[0].Url, URL: contentData.FileData[0].Url,
URI: contentData.FileData[0].URI,
} }
case message.InputTypeFile: case message.InputTypeFile:
one.Type = schema.ChatMessagePartTypeFileURL one.Type = schema.ChatMessagePartTypeFileURL
one.FileURL = &schema.ChatMessageFileURL{ one.FileURL = &schema.ChatMessageFileURL{
URL: contentData.FileData[0].Url, URL: contentData.FileData[0].Url,
URI: contentData.FileData[0].URI,
} }
case message.InputTypeVideo: case message.InputTypeVideo:
one.Type = schema.ChatMessagePartTypeVideoURL one.Type = schema.ChatMessagePartTypeVideoURL
@ -282,6 +303,7 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
one.Type = schema.ChatMessagePartTypeFileURL one.Type = schema.ChatMessagePartTypeFileURL
one.AudioURL = &schema.ChatMessageAudioURL{ one.AudioURL = &schema.ChatMessageAudioURL{
URL: contentData.FileData[0].Url, URL: contentData.FileData[0].Url,
URI: contentData.FileData[0].URI,
} }
} }
multiContent = append(multiContent, one) multiContent = append(multiContent, one)

View File

@ -32,6 +32,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
} }
type MessageRepo interface { type MessageRepo interface {
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
List(ctx context.Context, conversationID int64, limit int, cursor int64, List(ctx context.Context, conversationID int64, limit int, cursor int64,
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error) direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)

View File

@ -24,6 +24,7 @@ import (
type Message interface { type Message interface {
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
Create(ctx context.Context, req *entity.Message) (*entity.Message, error) Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
GetByID(ctx context.Context, id int64) (*entity.Message, error) GetByID(ctx context.Context, id int64) (*entity.Message, error)

View File

@ -39,13 +39,14 @@ func NewService(c *Components) Message {
} }
} }
func (m *messageImpl) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
// create message
return m.MessageRepo.PreCreate(ctx, msg)
}
func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) { func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
// create message // create message
msg, err := m.MessageRepo.Create(ctx, msg) return m.MessageRepo.Create(ctx, msg)
if err != nil {
return nil, err
}
return msg, nil
} }
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) { func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
@ -82,12 +83,7 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
} }
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) { func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC") return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
if err != nil {
return nil, err
}
return messageList, nil
} }
func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) { func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) {
@ -100,21 +96,11 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
} }
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error { func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs) return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
if err != nil {
return err
}
return nil
} }
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) { func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
msg, err := m.MessageRepo.GetByID(ctx, id) return m.MessageRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return msg, nil
} }
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error { func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {

View File

@ -19,6 +19,7 @@ package message
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -74,8 +75,10 @@ func TestCreateMessage(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
idGen := mock.NewMockIDGenerator(ctrl) idGen := mock.NewMockIDGenerator(ctrl)
idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1) idGen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(_ context.Context) (int64, error) {
newID := time.Now().UnixNano()
return newID, nil
}).AnyTimes()
mockDBGen := orm.NewMockDB() mockDBGen := orm.NewMockDB()
mockDBGen.AddTable(&model.Message{}) mockDBGen.AddTable(&model.Message{})
mockDB, err := mockDBGen.DB() mockDB, err := mockDBGen.DB()
@ -92,10 +95,12 @@ func TestCreateMessage(t *testing.T) {
imageInput := &message.FileData{ imageInput := &message.FileData{
Url: "https://xxxxx.xxxx/image", Url: "https://xxxxx.xxxx/image",
Name: "test_img", Name: "test_img",
URI: "",
} }
fileInput := &message.FileData{ fileInput := &message.FileData{
Url: "https://xxxxx.xxxx/file", Url: "https://xxxxx.xxxx/file",
Name: "test_file", Name: "test_file",
URI: "",
} }
content := []*message.InputMetaData{ content := []*message.InputMetaData{
{ {

View File

@ -201,7 +201,7 @@ func (m *minioClient) GetObjectUrl(ctx context.Context, objectKey string, opts .
} }
if option.Expire == 0 { if option.Expire == 0 {
option.Expire = 3600 * 24 option.Expire = 3600 * 24 * 7
} }
reqParams := make(url.Values) reqParams := make(url.Values)