fix(singleagent): support workflow output node (#662)

This commit is contained in:
junwen-lee 2025-08-11 10:49:51 +08:00 committed by GitHub
parent a21e41b89d
commit efc6e55fe5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 391 additions and 101 deletions

1
.gitignore vendored
View File

@ -34,6 +34,7 @@ output/*
# Vscode files # Vscode files
.vscode/settings.json .vscode/settings.json
.vscode/launch.json
/patches /patches
/oldimpl /oldimpl

View File

@ -29,6 +29,7 @@ import (
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/api/model/app/bot_open_api" "github.com/coze-dev/coze-studio/backend/api/model/app/bot_open_api"
) )

View File

@ -23,6 +23,7 @@ import (
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/api/model/data/database/table" "github.com/coze-dev/coze-studio/backend/api/model/data/database/table"
"github.com/coze-dev/coze-studio/backend/api/model/data/knowledge" "github.com/coze-dev/coze-studio/backend/api/model/data/knowledge"
"github.com/coze-dev/coze-studio/backend/application/memory" "github.com/coze-dev/coze-studio/backend/application/memory"

View File

@ -24,6 +24,7 @@ import (
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/api/model/app/intelligence" "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence"
"github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common" "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/common"
project "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/project" project "github.com/coze-dev/coze-studio/backend/api/model/app/intelligence/project"

View File

@ -23,6 +23,7 @@ import (
"github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/coze-dev/coze-studio/backend/api/model/playground" "github.com/coze-dev/coze-studio/backend/api/model/playground"
appApplication "github.com/coze-dev/coze-studio/backend/application/app" appApplication "github.com/coze-dev/coze-studio/backend/application/app"
"github.com/coze-dev/coze-studio/backend/application/prompt" "github.com/coze-dev/coze-studio/backend/application/prompt"

View File

@ -101,4 +101,7 @@ const (
MessageTypeFlowUp MessageType = "follow_up" MessageTypeFlowUp MessageType = "follow_up"
MessageTypeInterrupt MessageType = "interrupt" MessageTypeInterrupt MessageType = "interrupt"
MessageTypeVerbose MessageType = "verbose" MessageTypeVerbose MessageType = "verbose"
MessageTypeToolAsAnswer MessageType = "tool_as_answer"
MessageTypeToolMidAnswer MessageType = "tool_mid_answer"
) )

View File

@ -39,6 +39,7 @@ type EventType string
const ( const (
EventTypeOfChatModelAnswer EventType = "chatmodel_answer" EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer" EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
EventTypeOfToolMidAnswer EventType = "tool_mid_answer"
EventTypeOfToolsMessage EventType = "tools_message" EventTypeOfToolsMessage EventType = "tools_message"
EventTypeOfFuncCall EventType = "func_call" EventTypeOfFuncCall EventType = "func_call"
EventTypeOfSuggest EventType = "suggest" EventTypeOfSuggest EventType = "suggest"
@ -49,6 +50,9 @@ const (
type AgentEvent struct { type AgentEvent struct {
EventType EventType EventType EventType
ToolMidAnswer *schema.StreamReader[*schema.Message]
ToolAsChatModelAnswer *schema.StreamReader[*schema.Message]
ChatModelAnswer *schema.StreamReader[*schema.Message] ChatModelAnswer *schema.StreamReader[*schema.Message]
ToolsMessage []*schema.Message ToolsMessage []*schema.Message
FuncCall *schema.Message FuncCall *schema.Message

View File

@ -19,9 +19,12 @@ package crossworkflow
import ( import (
"context" "context"
"github.com/cloudwego/eino/compose"
einoCompose "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
) )
@ -37,10 +40,18 @@ type Workflow interface {
GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error)
SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error) SyncExecuteWorkflow(ctx context.Context, config vo.ExecuteConfig, input map[string]any) (*workflowEntity.WorkflowExecution, vo.TerminatePlan, error)
WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option
WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message])
} }
type ExecuteConfig = vo.ExecuteConfig type ExecuteConfig = vo.ExecuteConfig
type ExecuteMode = vo.ExecuteMode type ExecuteMode = vo.ExecuteMode
type NodeType = entity.NodeType
type WorkflowMessage = entity.Message
const (
NodeTypeOutputEmitter NodeType = "OutputEmitter"
)
const ( const (
ExecuteModeDebug ExecuteMode = "debug" ExecuteModeDebug ExecuteMode = "debug"

View File

@ -165,6 +165,10 @@ func (c *impl) buildSchemaMessage(ctx context.Context, msgs []*message.Message)
if err != nil { if err != nil {
continue continue
} }
if len(sm.ReasoningContent) > 0 {
sm.ReasoningContent = ""
}
schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm)) schemaMessage = append(schemaMessage, c.parseMessageURI(ctx, sm))
} }

View File

@ -19,10 +19,13 @@ package workflow
import ( import (
"context" "context"
"github.com/cloudwego/eino/compose"
einoCompose "github.com/cloudwego/eino/compose" einoCompose "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow" "github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" workflowEntity "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
@ -72,6 +75,10 @@ func (i *impl) WithExecuteConfig(cfg vo.ExecuteConfig) einoCompose.Option {
return i.DomainSVC.WithExecuteConfig(cfg) return i.DomainSVC.WithExecuteConfig(cfg)
} }
func (i *impl) WithMessagePipe() (compose.Option, *schema.StreamReader[*entity.Message]) {
return i.DomainSVC.WithMessagePipe()
}
func (i *impl) GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) { func (i *impl) GetWorkflowIDsByAppID(ctx context.Context, appID int64) ([]int64, error) {
metas, _, err := i.DomainSVC.MGet(ctx, &vo.MGetPolicy{ metas, _, err := i.DomainSVC.MGet(ctx, &vo.MGetPolicy{
MetaQuery: vo.MetaQuery{ MetaQuery: vo.MetaQuery{

View File

@ -19,10 +19,11 @@ package plugin
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common" common "github.com/coze-dev/coze-studio/backend/api/model/plugin_develop/common"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow" workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr" "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/stretchr/testify/assert"
) )
func TestToWorkflowAPIParameter(t *testing.T) { func TestToWorkflowAPIParameter(t *testing.T) {

View File

@ -19,6 +19,7 @@ package agentflow
import ( import (
"context" "context"
"errors" "errors"
"io"
"slices" "slices"
"github.com/google/uuid" "github.com/google/uuid"
@ -33,6 +34,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv" "github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
) )
type AgentState struct { type AgentState struct {
@ -69,17 +71,27 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools) hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() {
defer func() {
if pe := recover(); pe != nil {
logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
sw.Send(nil, errors.New("internal server error"))
}
sw.Close()
}()
var composeOpts []compose.Option var composeOpts []compose.Option
var pipeMsgOpt compose.Option
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
if r.containWfTool {
cfReq := crossworkflow.ExecuteConfig{
AgentID: &req.Identity.AgentID,
ConnectorUID: req.UserID,
ConnectorID: req.Identity.ConnectorID,
BizType: crossworkflow.BizTypeAgent,
}
if req.Identity.IsDraft {
cfReq.Mode = crossworkflow.ExecuteModeDebug
} else {
cfReq.Mode = crossworkflow.ExecuteModeRelease
}
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
composeOpts = append(composeOpts, wfConfig)
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
composeOpts = append(composeOpts, pipeMsgOpt)
}
composeOpts = append(composeOpts, compose.WithCallbacks(hdl)) composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
_ = compose.RegisterSerializableType[*AgentState]("agent_state") _ = compose.RegisterSerializableType[*AgentState]("agent_state")
if r.requireCheckpoint { if r.requireCheckpoint {
@ -96,27 +108,72 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID)) composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
} }
if r.containWfTool { if r.containWfTool && workflowMsgSr != nil {
cfReq := crossworkflow.ExecuteConfig{ safego.Go(ctx, func() {
AgentID: &req.Identity.AgentID, r.processWfMidAnswerStream(ctx, sw, workflowMsgSr)
ConnectorUID: req.UserID, })
ConnectorID: req.Identity.ConnectorID,
BizType: crossworkflow.BizTypeAgent,
} }
if req.Identity.IsDraft { safego.Go(ctx, func() {
cfReq.Mode = crossworkflow.ExecuteModeDebug defer func() {
} else { if pe := recover(); pe != nil {
cfReq.Mode = crossworkflow.ExecuteModeRelease logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
sw.Send(nil, errors.New("internal server error"))
} }
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq) sw.Close()
composeOpts = append(composeOpts, wfConfig)
}
_, _ = r.runner.Stream(ctx, req, composeOpts...)
}() }()
_, _ = r.runner.Stream(ctx, req, composeOpts...)
})
return sr, nil return sr, nil
} }
func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) {
streamInitialized := false
var srT *schema.StreamReader[*schema.Message]
var swT *schema.StreamWriter[*schema.Message]
defer func() {
if swT != nil {
swT.Close()
}
}()
for {
msg, err := wfStream.Recv()
if err == io.EOF {
break
}
if msg == nil || msg.DataMessage == nil {
continue
}
if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter {
continue
}
if !streamInitialized {
streamInitialized = true
srT, swT = schema.Pipe[*schema.Message](5)
sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolMidAnswer,
ToolMidAnswer: srT,
}, nil)
}
swT.Send(&schema.Message{
Role: msg.DataMessage.Role,
Content: msg.DataMessage.Content,
Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any {
extra := make(map[string]any)
extra["workflow_node_name"] = msg.NodeTitle
if msg.DataMessage.Last {
extra["is_finish"] = true
}
return extra
}(msg),
}, nil)
}
}
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest { func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
req.Input = r.preHandlerInput(req.Input) req.Input = r.preHandlerInput(req.Input)
req.History = r.preHandlerHistory(req.History) req.History = r.preHandlerHistory(req.History)

View File

@ -267,7 +267,6 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
} }
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close()
var toolsMsgChunks [][]*schema.Message var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message] var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message] var sw *schema.StreamWriter[*schema.Message]
@ -280,7 +279,6 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
returnDirectToolsMap := make(map[int]bool) returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false isToolsMsgChunksInit := false
for { for {
cbOut, err := output.Recv() cbOut, err := output.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -319,7 +317,7 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
sr, sw = schema.Pipe[*schema.Message](5) sr, sw = schema.Pipe[*schema.Message](5)
r.sw.Send(&entity.AgentEvent{ r.sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolsAsChatModelStream, EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ChatModelAnswer: sr, ToolAsChatModelAnswer: sr,
}, nil) }, nil)
streamInitialized = true streamInitialized = true
} }

View File

@ -141,8 +141,10 @@ type AgentRunResponse struct {
} }
type AgentRespEvent struct { type AgentRespEvent struct {
EventType message.MessageType EventType message.MessageType `json:"event_type"`
ToolMidAnswer *schema.StreamReader[*schema.Message]
ToolAsAnswer *schema.StreamReader[*schema.Message]
ModelAnswer *schema.StreamReader[*schema.Message] ModelAnswer *schema.StreamReader[*schema.Message]
ToolsMessage []*schema.Message ToolsMessage []*schema.Message
FuncCall *schema.Message FuncCall *schema.Message

View File

@ -203,8 +203,12 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
return message.MessageTypeKnowledge, nil return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage: case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream: case singleagent.EventTypeOfChatModelAnswer:
return message.MessageTypeAnswer, nil return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeToolAsAnswer, nil
case singleagent.EventTypeOfToolMidAnswer:
return message.MessageTypeToolMidAnswer, nil
case singleagent.EventTypeOfSuggest: case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil return message.MessageTypeFlowUp, nil
case singleagent.EventTypeOfInterrupt: case singleagent.EventTypeOfInterrupt:
@ -241,12 +245,12 @@ func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.Ag
buildExt = arm.Ext buildExt = arm.Ext
msg.DisplayContent = arm.DisplayContent msg.DisplayContent = arm.DisplayContent
case message.MessageTypeAnswer: case message.MessageTypeAnswer, message.MessageTypeToolAsAnswer:
msg.Role = schema.Assistant msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText msg.ContentType = message.ContentTypeText
case message.MessageTypeToolResponse: case message.MessageTypeToolResponse:
msg.Role = schema.Tool msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText msg.ContentType = message.ContentTypeText
msg.Content = chunk.ToolsMessage[0].Content msg.Content = chunk.ToolsMessage[0].Content
@ -261,7 +265,7 @@ func (c *runImpl) buildAgentMessage2Create(ctx context.Context, chunk *entity.Ag
msg.Role = schema.Assistant msg.Role = schema.Assistant
msg.ContentType = message.ContentTypeText msg.ContentType = message.ContentTypeText
knowledgeContent := c.buildKnowledge(ctx, arm, chunk) knowledgeContent := c.buildKnowledge(ctx, chunk)
if knowledgeContent != nil { if knowledgeContent != nil {
knInfo, err := json.Marshal(knowledgeContent) knInfo, err := json.Marshal(knowledgeContent)
if err == nil { if err == nil {
@ -461,6 +465,9 @@ func (c *runImpl) pull(_ context.Context, mainChan chan *entity.AgentRespEvent,
Knowledge: rm.Knowledge, Knowledge: rm.Knowledge,
Suggest: rm.Suggest, Suggest: rm.Suggest,
Interrupt: rm.Interrupt, Interrupt: rm.Interrupt,
ToolMidAnswer: rm.ToolMidAnswer,
ToolAsAnswer: rm.ToolAsChatModelAnswer,
} }
mainChan <- respChunk mainChan <- respChunk
@ -478,9 +485,12 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
}() }()
reasoningContent := bytes.NewBuffer([]byte{}) reasoningContent := bytes.NewBuffer([]byte{})
var createPreMsg = true
var preFinalAnswerMsg *msgEntity.Message
var firstAnswerMsg *msgEntity.Message
var reasoningMsg *msgEntity.Message
isSendFinishAnswer := false
var preToolResponseMsg *msgEntity.Message
toolResponseMsgContent := bytes.NewBuffer([]byte{})
for { for {
chunk, ok := <-mainChan chunk, ok := <-mainChan
if !ok || chunk == nil { if !ok || chunk == nil {
@ -489,6 +499,19 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
logs.CtxInfof(ctx, "hanlder event:%v,err:%v", conv.DebugJsonToStr(chunk), chunk.Err) logs.CtxInfof(ctx, "hanlder event:%v,err:%v", conv.DebugJsonToStr(chunk), chunk.Err)
if chunk.Err != nil { if chunk.Err != nil {
if errors.Is(chunk.Err, io.EOF) { if errors.Is(chunk.Err, io.EOF) {
if !isSendFinishAnswer {
isSendFinishAnswer = true
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
c.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
reasoningContent.Reset()
}
finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if finishErr != nil {
err = finishErr
return
}
}
return return
} }
c.handlerErr(ctx, chunk.Err, sw) c.handlerErr(ctx, chunk.Err, sw)
@ -501,45 +524,156 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
if err != nil { if err != nil {
return return
} }
if preToolResponseMsg == nil {
var cErr error
preToolResponseMsg, cErr = c.PreCreateAnswer(ctx, rtDependence)
if cErr != nil {
err = cErr
return
}
}
case message.MessageTypeToolResponse: case message.MessageTypeToolResponse:
err = c.handlerTooResponse(ctx, chunk, sw, rtDependence) err = c.handlerTooResponse(ctx, chunk, sw, rtDependence, preToolResponseMsg, toolResponseMsgContent.String())
if err != nil { if err != nil {
return return
} }
preToolResponseMsg = nil // reset
case message.MessageTypeKnowledge: case message.MessageTypeKnowledge:
err = c.handlerKnowledge(ctx, chunk, sw, rtDependence) err = c.handlerKnowledge(ctx, chunk, sw, rtDependence)
if err != nil { if err != nil {
return return
} }
case message.MessageTypeToolMidAnswer:
fullMidAnswerContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt
toolMidAnswerMsg, cErr := c.PreCreateAnswer(ctx, rtDependence)
if cErr != nil {
err = cErr
return
}
var preMsgIsFinish = false
for {
streamMsg, receErr := chunk.ToolMidAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
break
}
err = receErr
return
}
if preMsgIsFinish {
toolMidAnswerMsg, cErr = c.PreCreateAnswer(ctx, rtDependence)
if cErr != nil {
err = cErr
return
}
preMsgIsFinish = false
}
if streamMsg == nil {
continue
}
if firstAnswerMsg == nil && len(streamMsg.Content) > 0 {
if reasoningMsg != nil {
toolMidAnswerMsg = deepcopy.Copy(reasoningMsg).(*msgEntity.Message)
}
firstAnswerMsg = deepcopy.Copy(toolMidAnswerMsg).(*msgEntity.Message)
}
if streamMsg.Extra != nil {
if val, ok := streamMsg.Extra["workflow_node_name"]; ok && val != nil {
toolMidAnswerMsg.Ext["message_title"] = val.(string)
}
}
sendMidAnswerMsg := c.buildSendMsg(ctx, toolMidAnswerMsg, false, rtDependence)
sendMidAnswerMsg.Content = streamMsg.Content
toolResponseMsgContent.WriteString(streamMsg.Content)
fullMidAnswerContent.WriteString(streamMsg.Content)
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMidAnswerMsg, sw)
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = c.handlerUsage(streamMsg.ResponseMeta)
}
if streamMsg.Extra["is_finish"] == true {
preMsgIsFinish = true
sendMidAnswerMsg := c.buildSendMsg(ctx, toolMidAnswerMsg, false, rtDependence)
sendMidAnswerMsg.Content = fullMidAnswerContent.String()
fullMidAnswerContent.Reset()
hfErr := c.handlerAnswer(ctx, sendMidAnswerMsg, sw, usage, rtDependence, toolMidAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
}
}
case message.MessageTypeToolAsAnswer:
var usage *msgEntity.UsageExt
fullContent := bytes.NewBuffer([]byte{})
toolAsAnswerMsg, cErr := c.PreCreateAnswer(ctx, rtDependence)
if cErr != nil {
err = cErr
return
}
if firstAnswerMsg == nil {
firstAnswerMsg = toolAsAnswerMsg
}
for {
streamMsg, receErr := chunk.ToolAsAnswer.Recv()
if receErr != nil {
if errors.Is(receErr, io.EOF) {
answer := c.buildSendMsg(ctx, toolAsAnswerMsg, false, rtDependence)
answer.Content = fullContent.String()
hfErr := c.handlerAnswer(ctx, answer, sw, usage, rtDependence, toolAsAnswerMsg)
if hfErr != nil {
err = hfErr
return
}
break
}
err = receErr
return
}
if streamMsg != nil && streamMsg.ResponseMeta != nil {
usage = c.handlerUsage(streamMsg.ResponseMeta)
}
sendMsg := c.buildSendMsg(ctx, toolAsAnswerMsg, false, rtDependence)
fullContent.WriteString(streamMsg.Content)
sendMsg.Content = streamMsg.Content
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw)
}
case message.MessageTypeAnswer: case message.MessageTypeAnswer:
fullContent := bytes.NewBuffer([]byte{}) fullContent := bytes.NewBuffer([]byte{})
var usage *msgEntity.UsageExt var usage *msgEntity.UsageExt
var isToolCalls = false var isToolCalls = false
var modelAnswerMsg *msgEntity.Message
for { for {
streamMsg, receErr := chunk.ModelAnswer.Recv() streamMsg, receErr := chunk.ModelAnswer.Recv()
if receErr != nil { if receErr != nil {
if errors.Is(receErr, io.EOF) { if errors.Is(receErr, io.EOF) {
if isToolCalls { if isToolCalls {
break break
} }
if modelAnswerMsg == nil {
finalAnswer := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) break
}
finalAnswer.Content = fullContent.String() answer := c.buildSendMsg(ctx, modelAnswerMsg, false, rtDependence)
finalAnswer.ReasoningContent = ptr.Of(reasoningContent.String()) answer.Content = fullContent.String()
hfErr := c.handlerFinalAnswer(ctx, finalAnswer, sw, usage, rtDependence, preFinalAnswerMsg) hfErr := c.handlerAnswer(ctx, answer, sw, usage, rtDependence, modelAnswerMsg)
if hfErr != nil { if hfErr != nil {
err = hfErr err = hfErr
return return
} }
finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if finishErr != nil {
err = finishErr
return
}
break break
} }
err = receErr err = receErr
@ -557,32 +691,64 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 { if streamMsg != nil && len(streamMsg.ReasoningContent) == 0 && len(streamMsg.Content) == 0 {
continue continue
} }
if createPreMsg && (len(streamMsg.ReasoningContent) > 0 || len(streamMsg.Content) > 0) {
preFinalAnswerMsg, err = c.PreCreateFinalAnswer(ctx, rtDependence) if len(streamMsg.ReasoningContent) > 0 {
if reasoningMsg == nil {
reasoningMsg, err = c.PreCreateAnswer(ctx, rtDependence)
if err != nil { if err != nil {
return return
} }
createPreMsg = false
} }
sendMsg := c.buildSendMsg(ctx, preFinalAnswerMsg, false, rtDependence) sendReasoningMsg := c.buildSendMsg(ctx, reasoningMsg, false, rtDependence)
reasoningContent.WriteString(streamMsg.ReasoningContent) reasoningContent.WriteString(streamMsg.ReasoningContent)
sendMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent) sendReasoningMsg.ReasoningContent = ptr.Of(streamMsg.ReasoningContent)
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendReasoningMsg, sw)
}
if len(streamMsg.Content) > 0 {
if modelAnswerMsg == nil {
modelAnswerMsg, err = c.PreCreateAnswer(ctx, rtDependence)
if err != nil {
return
}
if firstAnswerMsg == nil {
if reasoningMsg != nil {
modelAnswerMsg.ID = reasoningMsg.ID
}
firstAnswerMsg = modelAnswerMsg
}
}
sendAnswerMsg := c.buildSendMsg(ctx, modelAnswerMsg, false, rtDependence)
fullContent.WriteString(streamMsg.Content) fullContent.WriteString(streamMsg.Content)
sendMsg.Content = streamMsg.Content sendAnswerMsg.Content = streamMsg.Content
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendAnswerMsg, sw)
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, sendMsg, sw) }
} }
case message.MessageTypeFlowUp: case message.MessageTypeFlowUp:
if isSendFinishAnswer {
if firstAnswerMsg != nil && len(reasoningContent.String()) > 0 {
c.saveReasoningContent(ctx, firstAnswerMsg, reasoningContent.String())
}
isSendFinishAnswer = true
finishErr := c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if finishErr != nil {
err = finishErr
return
}
}
err = c.handlerSuggest(ctx, chunk, sw, rtDependence) err = c.handlerSuggest(ctx, chunk, sw, rtDependence)
if err != nil { if err != nil {
return return
} }
case message.MessageTypeInterrupt: case message.MessageTypeInterrupt:
err = c.handlerInterrupt(ctx, chunk, sw, rtDependence) err = c.handlerInterrupt(ctx, chunk, sw, rtDependence, firstAnswerMsg, reasoningContent.String())
if err != nil { if err != nil {
return return
} }
@ -590,12 +756,22 @@ func (c *runImpl) push(ctx context.Context, mainChan chan *entity.AgentRespEvent
} }
} }
func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error { func (c *runImpl) saveReasoningContent(ctx context.Context, firstAnswerMsg *msgEntity.Message, reasoningContent string) {
_, err := crossmessage.DefaultSVC().Edit(ctx, &message.Message{
ID: firstAnswerMsg.ID,
ReasoningContent: reasoningContent,
})
if err != nil {
logs.CtxInfof(ctx, "save reasoning content failed, err: %v", err)
}
}
func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence, firstAnswerMsg *msgEntity.Message, reasoningCOntent string) error {
interruptData, cType, err := c.parseInterruptData(ctx, chunk.Interrupt) interruptData, cType, err := c.parseInterruptData(ctx, chunk.Interrupt)
if err != nil { if err != nil {
return err return err
} }
preMsg, err := c.PreCreateFinalAnswer(ctx, rtDependence) preMsg, err := c.PreCreateAnswer(ctx, rtDependence)
if err != nil { if err != nil {
return err return err
} }
@ -616,8 +792,10 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw) c.runEvent.SendMsgEvent(entity.RunEventMessageDelta, deltaAnswer, sw)
finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem) finalAnswer := deepcopy.Copy(deltaAnswer).(*entity.ChunkMessageItem)
if len(reasoningCOntent) > 0 && firstAnswerMsg == nil {
err = c.handlerFinalAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg) finalAnswer.ReasoningContent = ptr.Of(reasoningCOntent)
}
err = c.handlerAnswer(ctx, finalAnswer, sw, nil, rtDependence, preMsg)
if err != nil { if err != nil {
return err return err
} }
@ -626,11 +804,6 @@ func (c *runImpl) handlerInterrupt(ctx context.Context, chunk *entity.AgentRespE
if err != nil { if err != nil {
return err return err
} }
err = c.handlerFinalAnswerFinish(ctx, sw, rtDependence)
if err != nil {
return err
}
return nil return nil
} }
@ -733,7 +906,7 @@ func (c *runImpl) handlerErr(_ context.Context, err error, sw *schema.StreamWrit
}) })
} }
func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) { func (c *runImpl) PreCreateAnswer(ctx context.Context, rtDependence *runtimeDependence) (*msgEntity.Message, error) {
arm := rtDependence.runMeta arm := rtDependence.runMeta
msgMeta := &msgEntity.Message{ msgMeta := &msgEntity.Message{
ConversationID: arm.ConversationID, ConversationID: arm.ConversationID,
@ -765,7 +938,7 @@ func (c *runImpl) PreCreateFinalAnswer(ctx context.Context, rtDependence *runtim
return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta) return crossmessage.DefaultSVC().PreCreate(ctx, msgMeta)
} }
func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preFinalAnswerMsg *msgEntity.Message) error { func (c *runImpl) handlerAnswer(ctx context.Context, msg *entity.ChunkMessageItem, sw *schema.StreamWriter[*entity.AgentRunResponse], usage *msgEntity.UsageExt, rtDependence *runtimeDependence, preAnswerMsg *msgEntity.Message) error {
if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 { if len(msg.Content) == 0 && len(ptr.From(msg.ReasoningContent)) == 0 {
return nil return nil
@ -801,12 +974,15 @@ func (c *runImpl) handlerFinalAnswer(ctx context.Context, msg *entity.ChunkMessa
if err != nil { if err != nil {
return err return err
} }
preFinalAnswerMsg.Content = msg.Content preAnswerMsg.Content = msg.Content
preFinalAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent) preAnswerMsg.ReasoningContent = ptr.From(msg.ReasoningContent)
preFinalAnswerMsg.Ext = msg.Ext preAnswerMsg.Ext = msg.Ext
preFinalAnswerMsg.ContentType = msg.ContentType preAnswerMsg.ContentType = msg.ContentType
preFinalAnswerMsg.ModelContent = string(mc) preAnswerMsg.ModelContent = string(mc)
_, err = crossmessage.DefaultSVC().Create(ctx, preFinalAnswerMsg) preAnswerMsg.CreatedAt = 0
preAnswerMsg.UpdatedAt = 0
_, err = crossmessage.DefaultSVC().Create(ctx, preAnswerMsg)
if err != nil { if err != nil {
return err return err
} }
@ -860,9 +1036,23 @@ func (c *runImpl) handlerAckMessage(_ context.Context, input *msgEntity.Message,
return nil return nil
} }
func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence) error { func (c *runImpl) handlerTooResponse(ctx context.Context, chunk *entity.AgentRespEvent, sw *schema.StreamWriter[*entity.AgentRunResponse], rtDependence *runtimeDependence, preToolResponseMsg *msgEntity.Message, toolResponseMsgContent string) error {
cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence) cm := c.buildAgentMessage2Create(ctx, chunk, message.MessageTypeToolResponse, rtDependence)
cmData, err := crossmessage.DefaultSVC().Create(ctx, cm)
var cmData *message.Message
var err error
if preToolResponseMsg != nil {
cm.ID = preToolResponseMsg.ID
cm.CreatedAt = preToolResponseMsg.CreatedAt
cm.UpdatedAt = preToolResponseMsg.UpdatedAt
if len(toolResponseMsgContent) > 0 {
cm.Content = toolResponseMsgContent + "\n" + cm.Content
}
}
cmData, err = crossmessage.DefaultSVC().Create(ctx, cm)
if err != nil { if err != nil {
return err return err
} }
@ -902,7 +1092,7 @@ func (c *runImpl) handlerKnowledge(ctx context.Context, chunk *entity.AgentRespE
return nil return nil
} }
func (c *runImpl) buildKnowledge(_ context.Context, arm *entity.AgentRunMeta, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo { func (c *runImpl) buildKnowledge(_ context.Context, chunk *entity.AgentRespEvent) *msgEntity.VerboseInfo {
var recallDatas []msgEntity.RecallDataInfo var recallDatas []msgEntity.RecallDataInfo
for _, kOne := range chunk.Knowledge { for _, kOne := range chunk.Knowledge {
recallDatas = append(recallDatas, msgEntity.RecallDataInfo{ recallDatas = append(recallDatas, msgEntity.RecallDataInfo{

View File

@ -241,6 +241,12 @@ func (dao *MessageDAO) messageDO2PO(ctx context.Context, msgDo *entity.Message)
UpdatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(),
ReasoningContent: msgDo.ReasoningContent, ReasoningContent: msgDo.ReasoningContent,
} }
if msgDo.CreatedAt > 0 {
msgPO.CreatedAt = msgDo.CreatedAt
}
if msgDo.UpdatedAt > 0 {
msgPO.UpdatedAt = msgDo.UpdatedAt
}
if msgDo.ModelContent != "" { if msgDo.ModelContent != "" {
msgPO.ModelContent = msgDo.ModelContent msgPO.ModelContent = msgDo.ModelContent

View File

@ -21,6 +21,7 @@ import (
"github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/coze-dev/coze-studio/backend/api/model/workflow" "github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"

View File

@ -28,6 +28,7 @@ import (
"time" "time"
opt "github.com/cloudwego/eino/components/embedding" opt "github.com/cloudwego/eino/components/embedding"
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices" "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"