From 0367e66ecad7f02ed066e7a430c6978776f43a88 Mon Sep 17 00:00:00 2001 From: junwen-lee Date: Mon, 4 Aug 2025 21:58:37 +0800 Subject: [PATCH] fix(singleagent): workflow as tool return directly (#526) --- .../crossdomain/singleagent/single_agent.go | 13 ++-- .../internal/agentflow/agent_flow_builder.go | 13 ++-- .../internal/agentflow/agent_flow_runner.go | 7 +- .../agentflow/callback_reply_chunk.go | 71 ++++++++++++++----- .../agentrun/service/agent_run_impl.go | 2 +- 5 files changed, 72 insertions(+), 34 deletions(-) diff --git a/backend/api/model/crossdomain/singleagent/single_agent.go b/backend/api/model/crossdomain/singleagent/single_agent.go index 34372eeb..a58cb12a 100644 --- a/backend/api/model/crossdomain/singleagent/single_agent.go +++ b/backend/api/model/crossdomain/singleagent/single_agent.go @@ -37,12 +37,13 @@ type AgentRuntime struct { type EventType string const ( - EventTypeOfChatModelAnswer EventType = "chatmodel_answer" - EventTypeOfToolsMessage EventType = "tools_message" - EventTypeOfFuncCall EventType = "func_call" - EventTypeOfSuggest EventType = "suggest" - EventTypeOfKnowledge EventType = "knowledge" - EventTypeOfInterrupt EventType = "interrupt" + EventTypeOfChatModelAnswer EventType = "chatmodel_answer" + EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer" + EventTypeOfToolsMessage EventType = "tools_message" + EventTypeOfFuncCall EventType = "func_call" + EventTypeOfSuggest EventType = "suggest" + EventTypeOfKnowledge EventType = "knowledge" + EventTypeOfInterrupt EventType = "interrupt" ) type AgentEvent struct { diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go index 7b8dd179..35a13571 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go @@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) { } tr := newPreToolRetriever(&toolPreCallConf{}) - wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{ + wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{ wfInfos: conf.Agent.Workflow, }) if err != nil { @@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) { ToolsConfig: compose.ToolsNodeConfig{ Tools: agentTools, }, - ToolReturnDirectly: toolsReturnDirectly, + ToolReturnDirectly: returnDirectlyTools, ModelNodeName: keyOfReActAgentChatModel, ToolsNodeName: keyOfReActAgentToolsNode, }) @@ -273,10 +273,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) { } return &AgentRunner{ - runner: runner, - requireCheckpoint: requireCheckpoint, - modelInfo: modelInfo, - containWfTool: containWfTool, + runner: runner, + requireCheckpoint: requireCheckpoint, + modelInfo: modelInfo, + containWfTool: containWfTool, + returnDirectlyTools: returnDirectlyTools, }, nil } diff --git a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go index c7574d8f..9cb0bb7b 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go +++ b/backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go @@ -57,8 +57,9 @@ type AgentRunner struct { runner compose.Runnable[*AgentRequest, *schema.Message] requireCheckpoint bool - containWfTool bool - modelInfo *modelmgr.Model + returnDirectlyTools map[string]struct{} + containWfTool bool + modelInfo *modelmgr.Model } func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( @@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( ) { executeID := uuid.New() - hdl, sr, sw := newReplyCallback(ctx, executeID.String()) + hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools) go func() { defer func() { diff --git a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go index 49c7ca86..ffd48bda 100644 --- a/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go +++ b/backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go @@ -38,14 +38,15 @@ import ( "github.com/coze-dev/coze-studio/backend/pkg/logs" ) -func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler, +func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler, sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent], ) { sr, sw = schema.Pipe[*entity.AgentEvent](10) rcc := &replyChunkCallback{ - sw: sw, - executeID: executeID, + sw: sw, + executeID: executeID, + returnDirectlyTools: returnDirectlyTools, } clb = callbacks.NewHandlerBuilder(). @@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle } type replyChunkCallback struct { - sw *schema.StreamWriter[*entity.AgentEvent] - executeID string + sw *schema.StreamWriter[*entity.AgentEvent] + executeID string + returnDirectlyTools map[string]struct{} } func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { @@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca }, nil) return ctx case compose.ComponentOfToolsNode: - toolsMessage, err := concatToolsNodeOutput(ctx, output) + toolsMessage, err := r.concatToolsNodeOutput(ctx, output) if err != nil { r.sw.Send(nil, err) return ctx @@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType { return interruptEventType } -func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { +func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { defer output.Close() - toolsMsgChunks := make([][]*schema.Message, 0, 5) + var toolsMsgChunks [][]*schema.Message + var sr *schema.StreamReader[*schema.Message] + var sw *schema.StreamWriter[*schema.Message] + defer func() { + if sw != nil { + sw.Close() + } + }() + var streamInitialized bool + returnDirectToolsMap := make(map[int]bool) + isReturnDirectToolsFirstCheck := true + isToolsMsgChunksInit := false + for { cbOut, err := output.Recv() if errors.Is(err, io.EOF) { @@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call } if err != nil { + if sw != nil { + sw.Send(nil, err) + } return nil, err } msgs := convToolsNodeCallbackOutput(cbOut) - for _, msg := range msgs { + if !isToolsMsgChunksInit { + isToolsMsgChunksInit = true + toolsMsgChunks = make([][]*schema.Message, len(msgs)) + } + + for mIndex, msg := range msgs { + if msg == nil { continue } + if len(r.returnDirectlyTools) > 0 { + if isReturnDirectToolsFirstCheck { + isReturnDirectToolsFirstCheck = false + if _, ok := r.returnDirectlyTools[msg.ToolName]; ok { + returnDirectToolsMap[mIndex] = true + } + } - findSameMsg := false - for i, msgChunks := range toolsMsgChunks { - if msg.ToolCallID == msgChunks[0].ToolCallID { - toolsMsgChunks[i] = append(toolsMsgChunks[i], msg) - findSameMsg = true - break + if _, ok := returnDirectToolsMap[mIndex]; ok { + if !streamInitialized { + sr, sw = schema.Pipe[*schema.Message](5) + r.sw.Send(&entity.AgentEvent{ + EventType: singleagent.EventTypeOfToolsAsChatModelStream, + ChatModelAnswer: sr, + }, nil) + streamInitialized = true + } + sw.Send(msg, nil) } } - - if !findSameMsg { - toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg}) + if toolsMsgChunks[mIndex] == nil { + toolsMsgChunks[mIndex] = []*schema.Message{msg} + } else { + toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg) } } } diff --git a/backend/domain/conversation/agentrun/service/agent_run_impl.go b/backend/domain/conversation/agentrun/service/agent_run_impl.go index b0506edc..fbf19b37 100644 --- a/backend/domain/conversation/agentrun/service/agent_run_impl.go +++ b/backend/domain/conversation/agentrun/service/agent_run_impl.go @@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er return message.MessageTypeKnowledge, nil case singleagent.EventTypeOfToolsMessage: return message.MessageTypeToolResponse, nil - case singleagent.EventTypeOfChatModelAnswer: + case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream: return message.MessageTypeAnswer, nil case singleagent.EventTypeOfSuggest: return message.MessageTypeFlowUp, nil