Fix/content reasoning
* fix: concat * fix: concat * fix: concat sep * fix: concat content & url * fix: opimized * fix: test * fix: test * fix: multi content parse with uri * fix: multi content combine * fix: remove unused func * fix: multi content with content & reasoning content * fix: multi content with content & reasoning content See merge request: !886
This commit is contained in:
@@ -473,6 +473,10 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
|
||||
}
|
||||
}()
|
||||
|
||||
reasoningContent := bytes.NewBuffer([]byte{})
|
||||
var createPreMsg = true
|
||||
var preFinalAnswerMsg *msgEntity.Message
|
||||
|
||||
for {
|
||||
chunk, ok := <-mainChan
|
||||
if !ok || chunk == nil {
|
||||
@@ -505,11 +509,7 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
|
||||
}
|
||||
case message.MessageTypeAnswer:
|
||||
fullContent := bytes.NewBuffer([]byte{})
|
||||
reasoningContent := bytes.NewBuffer([]byte{})
|
||||
|
||||
var preMsg *msgEntity.Message
|
||||
var usage *msgEntity.UsageExt
|
||||
var createPreMsg = true
|
||||
var isToolCalls = false
|
||||
|
||||
for {
|
||||
@@ -518,14 +518,15 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
|
||||
if receErr != nil {
|
||||
if errors.Is(receErr, io.EOF) {
|
||||
|
||||
if isToolCalls && reasoningContent.String() == "" {
|
||||
if isToolCalls {
|
||||
break
|
||||
}
|
||||
|
||||
finalAnswer := c.buildSendMsg(ctx, preMsg, false, rtDependence)
|
||||
finalAnswer := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence)
|
||||
|
||||
finalAnswer.Content = fullContent.String()
|
||||
finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String())
|
||||
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence)
|
||||
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence, preFinalAnswerMsg)
|
||||
if hfErr != nil {
|
||||
err = hfErr
|
||||
return
|
||||
@@ -553,14 +554,14 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
|
||||
continue
|
||||
}
|
||||
if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) {
|
||||
preMsg, err = c.handlerPreAnswer(ctx, rtDependence)
|
||||
preFinalAnswerMsg, err = c.PreCreateFinalAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
createPreMsg = false
|
||||
}
|
||||
|
||||
sendMsg := c.buildSendMsg(ctx, preMsg, false, rtDependence)
|
||||
sendMsg := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence)
|
||||
reasoningContent.WriteString(streamMsg.ReasoningContent)
|
||||
sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
|
||||
|
||||
@@ -590,7 +591,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
preMsg, err := c.handlerPreAnswer(ctx, rtDependence)
|
||||
preMsg, err := c.PreCreateFinalAnswer(ctx, rtDependence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -612,7 +613,7 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
|
||||
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw)
|
||||
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
|
||||
|
||||
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence)
|
||||
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -718,7 +719,7 @@ func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWrit
|
||||
})
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
|
||||
func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
|
||||
arm := rtDependence.runMeta
|
||||
msgMeta := &msgEntity.Message{
|
||||
ConversationID: arm.ConversationID,
|
||||
@@ -747,10 +748,10 @@ func (c *runImpl) handlerPreAnswer(ctx context.Context, rtDependence *runtimeDep
|
||||
}
|
||||
|
||||
msgMeta.Ext = arm.Ext
|
||||
return crossmessage.DefaultSVC().Create(ctx, msgMeta)
|
||||
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
|
||||
}
|
||||
|
||||
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence) error {
|
||||
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preFinalAnswerMsg *msgEntity.Message) error {
|
||||
|
||||
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
|
||||
return nil
|
||||
@@ -786,15 +787,12 @@ func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessa
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
editMsg := &msgEntity.Message{
|
||||
ID: msg.ID,
|
||||
Content: msg.Content,
|
||||
ContentType: msg.ContentType,
|
||||
ModelContent: string(mc),
|
||||
ReasoningContent: ptr.From(msg.ReasoningContent),
|
||||
Ext: msg.Ext,
|
||||
}
|
||||
_, err = crossmessage.DefaultSVC().Edit(ctx, editMsg)
|
||||
preFinalAnswerMsg.Content = msg.Content
|
||||
preFinalAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
|
||||
preFinalAnswerMsg.Ext = msg.Ext
|
||||
preFinalAnswerMsg.ContentType = msg.ContentType
|
||||
preFinalAnswerMsg.ModelContent = string(mc)
|
||||
_, err = crossmessage.DefaultSVC().Create(ctx, preFinalAnswerMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -48,6 +48,14 @@ func NewMessageDAO(db *gorm.DB, idgen idgen.IDGenerator) *MessageDAO {
|
||||
}
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
poData, err := dao.messageDO2PO(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dao.messagePO2DO(poData), nil
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
poData, err := dao.messageDO2PO(ctx, msg)
|
||||
if err != nil {
|
||||
@@ -205,33 +213,44 @@ func (dao *MessageDAO) Delete(ctx context.Context, msgIDs []int64, runIDs []int6
|
||||
}
|
||||
|
||||
func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message) (*model.Message, error) {
|
||||
id, gErr := dao.idgen.GenID(ctx)
|
||||
if gErr != nil {
|
||||
return nil, gErr
|
||||
var id int64
|
||||
if msgDo.ID > 0 {
|
||||
id = msgDo.ID
|
||||
} else {
|
||||
genID, gErr := dao.idgen.GenID(ctx)
|
||||
if gErr != nil {
|
||||
return nil, gErr
|
||||
}
|
||||
id = genID
|
||||
}
|
||||
msgPO := &model.Message{
|
||||
ID: id,
|
||||
ConversationID: msgDo.ConversationID,
|
||||
RunID: msgDo.RunID,
|
||||
AgentID: msgDo.AgentID,
|
||||
SectionID: msgDo.SectionID,
|
||||
UserID: msgDo.UserID,
|
||||
Role: string(msgDo.Role),
|
||||
ContentType: string(msgDo.ContentType),
|
||||
MessageType: string(msgDo.MessageType),
|
||||
DisplayContent: msgDo.DisplayContent,
|
||||
Content: msgDo.Content,
|
||||
BrokenPosition: msgDo.Position,
|
||||
Status: int32(entity.MessageStatusAvailable),
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
ID: id,
|
||||
ConversationID: msgDo.ConversationID,
|
||||
RunID: msgDo.RunID,
|
||||
AgentID: msgDo.AgentID,
|
||||
SectionID: msgDo.SectionID,
|
||||
UserID: msgDo.UserID,
|
||||
Role: string(msgDo.Role),
|
||||
ContentType: string(msgDo.ContentType),
|
||||
MessageType: string(msgDo.MessageType),
|
||||
DisplayContent: msgDo.DisplayContent,
|
||||
Content: msgDo.Content,
|
||||
BrokenPosition: msgDo.Position,
|
||||
Status: int32(entity.MessageStatusAvailable),
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
ReasoningContent: msgDo.ReasoningContent,
|
||||
}
|
||||
|
||||
mc, err := dao.buildModelContent(msgDo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if msgDo.ModelContent != "" {
|
||||
msgPO.ModelContent = msgDo.ModelContent
|
||||
} else {
|
||||
mc, err := dao.buildModelContent(msgDo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgPO.ModelContent = mc
|
||||
}
|
||||
msgPO.ModelContent = mc
|
||||
|
||||
ext, err := json.Marshal(msgDo.Ext)
|
||||
if err != nil {
|
||||
@@ -267,11 +286,13 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
|
||||
one.Type = schema.ChatMessagePartTypeImageURL
|
||||
one.ImageURL = &schema.ChatMessageImageURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
URI: contentData.FileData[0].URI,
|
||||
}
|
||||
case message.InputTypeFile:
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.FileURL = &schema.ChatMessageFileURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
URI: contentData.FileData[0].URI,
|
||||
}
|
||||
case message.InputTypeVideo:
|
||||
one.Type = schema.ChatMessagePartTypeVideoURL
|
||||
@@ -282,6 +303,7 @@ func (dao *MessageDAO) buildModelContent(msgDO *entity.Message) (string, error)
|
||||
one.Type = schema.ChatMessagePartTypeFileURL
|
||||
one.AudioURL = &schema.ChatMessageAudioURL{
|
||||
URL: contentData.FileData[0].Url,
|
||||
URI: contentData.FileData[0].URI,
|
||||
}
|
||||
}
|
||||
multiContent = append(multiContent, one)
|
||||
|
||||
@@ -32,6 +32,7 @@ func NewMessageRepo(db *gorm.DB, idGen idgen.IDGenerator) MessageRepo {
|
||||
}
|
||||
|
||||
type MessageRepo interface {
|
||||
PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, msg *entity.Message) (*entity.Message, error)
|
||||
List(ctx context.Context, conversationID int64, limit int, cursor int64,
|
||||
direction entity.ScrollPageDirection, messageType *message.MessageType) ([]*entity.Message, bool, error)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
type Message interface {
|
||||
List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error)
|
||||
PreCreate(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
Create(ctx context.Context, req *entity.Message) (*entity.Message, error)
|
||||
GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error)
|
||||
GetByID(ctx context.Context, id int64) (*entity.Message, error)
|
||||
|
||||
@@ -39,13 +39,14 @@ func NewService(c *Components) Message {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *messageImpl) PreCreate(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
// create message
|
||||
return m.MessageRepo.PreCreate(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *messageImpl) Create(ctx context.Context, msg *entity.Message) (*entity.Message, error) {
|
||||
// create message
|
||||
msg, err := m.MessageRepo.Create(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
return m.MessageRepo.Create(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.ListResult, error) {
|
||||
@@ -82,12 +83,7 @@ func (m *messageImpl) List(ctx context.Context, req *entity.ListMeta) (*entity.L
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByRunIDs(ctx context.Context, conversationID int64, runIDs []int64) ([]*entity.Message, error) {
|
||||
messageList, err := m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return messageList, nil
|
||||
return m.MessageRepo.GetByRunIDs(ctx, runIDs, "ASC")
|
||||
}
|
||||
|
||||
func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Message, error) {
|
||||
@@ -100,21 +96,11 @@ func (m *messageImpl) Edit(ctx context.Context, req *entity.Message) (*entity.Me
|
||||
}
|
||||
|
||||
func (m *messageImpl) Delete(ctx context.Context, req *entity.DeleteMeta) error {
|
||||
err := m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return m.MessageRepo.Delete(ctx, req.MessageIDs, req.RunIDs)
|
||||
}
|
||||
|
||||
func (m *messageImpl) GetByID(ctx context.Context, id int64) (*entity.Message, error) {
|
||||
msg, err := m.MessageRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
return m.MessageRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (m *messageImpl) Broken(ctx context.Context, req *entity.BrokenMeta) error {
|
||||
|
||||
@@ -19,6 +19,7 @@ package message
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -74,8 +75,10 @@ func TestCreateMessage(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
idGen := mock.NewMockIDGenerator(ctrl)
|
||||
idGen.EXPECT().GenID(gomock.Any()).Return(int64(10), nil).Times(1)
|
||||
|
||||
idGen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(_ context.Context) (int64, error) {
|
||||
newID := time.Now().UnixNano()
|
||||
return newID, nil
|
||||
}).AnyTimes()
|
||||
mockDBGen := orm.NewMockDB()
|
||||
mockDBGen.AddTable(&model.Message{})
|
||||
mockDB, err := mockDBGen.DB()
|
||||
@@ -92,10 +95,12 @@ func TestCreateMessage(t *testing.T) {
|
||||
imageInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/image",
|
||||
Name: "test_img",
|
||||
URI: "",
|
||||
}
|
||||
fileInput := &message.FileData{
|
||||
Url: "https://xxxxx.xxxx/file",
|
||||
Name: "test_file",
|
||||
URI: "",
|
||||
}
|
||||
content := []*message.InputMetaData{
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user