fix(singleagent): workflow as tool return directly (#526)

This commit is contained in:
junwen-lee 2025-08-04 21:58:37 +08:00 committed by GitHub
parent 38b63f00a3
commit 0367e66eca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 34 deletions

View File

@ -37,12 +37,13 @@ type AgentRuntime struct {
type EventType string type EventType string
const ( const (
EventTypeOfChatModelAnswer EventType = "chatmodel_answer" EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
EventTypeOfToolsMessage EventType = "tools_message" EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
EventTypeOfFuncCall EventType = "func_call" EventTypeOfToolsMessage EventType = "tools_message"
EventTypeOfSuggest EventType = "suggest" EventTypeOfFuncCall EventType = "func_call"
EventTypeOfKnowledge EventType = "knowledge" EventTypeOfSuggest EventType = "suggest"
EventTypeOfInterrupt EventType = "interrupt" EventTypeOfKnowledge EventType = "knowledge"
EventTypeOfInterrupt EventType = "interrupt"
) )
type AgentEvent struct { type AgentEvent struct {

View File

@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
} }
tr := newPreToolRetriever(&toolPreCallConf{}) tr := newPreToolRetriever(&toolPreCallConf{})
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{ wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
wfInfos: conf.Agent.Workflow, wfInfos: conf.Agent.Workflow,
}) })
if err != nil { if err != nil {
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
ToolsConfig: compose.ToolsNodeConfig{ ToolsConfig: compose.ToolsNodeConfig{
Tools: agentTools, Tools: agentTools,
}, },
ToolReturnDirectly: toolsReturnDirectly, ToolReturnDirectly: returnDirectlyTools,
ModelNodeName: keyOfReActAgentChatModel, ModelNodeName: keyOfReActAgentChatModel,
ToolsNodeName: keyOfReActAgentToolsNode, ToolsNodeName: keyOfReActAgentToolsNode,
}) })
@ -273,10 +273,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
} }
return &AgentRunner{ return &AgentRunner{
runner: runner, runner: runner,
requireCheckpoint: requireCheckpoint, requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo, modelInfo: modelInfo,
containWfTool: containWfTool, containWfTool: containWfTool,
returnDirectlyTools: returnDirectlyTools,
}, nil }, nil
} }

View File

@ -57,8 +57,9 @@ type AgentRunner struct {
runner compose.Runnable[*AgentRequest, *schema.Message] runner compose.Runnable[*AgentRequest, *schema.Message]
requireCheckpoint bool requireCheckpoint bool
containWfTool bool returnDirectlyTools map[string]struct{}
modelInfo *modelmgr.Model containWfTool bool
modelInfo *modelmgr.Model
} }
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) ( func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
) { ) {
executeID := uuid.New() executeID := uuid.New()
hdl, sr, sw := newReplyCallback(ctx, executeID.String()) hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() { go func() {
defer func() { defer func() {

View File

@ -38,14 +38,15 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
) )
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler, func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent], sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
) { ) {
sr, sw = schema.Pipe[*entity.AgentEvent](10) sr, sw = schema.Pipe[*entity.AgentEvent](10)
rcc := &replyChunkCallback{ rcc := &replyChunkCallback{
sw: sw, sw: sw,
executeID: executeID, executeID: executeID,
returnDirectlyTools: returnDirectlyTools,
} }
clb = callbacks.NewHandlerBuilder(). clb = callbacks.NewHandlerBuilder().
@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
} }
type replyChunkCallback struct { type replyChunkCallback struct {
sw *schema.StreamWriter[*entity.AgentEvent] sw *schema.StreamWriter[*entity.AgentEvent]
executeID string executeID string
returnDirectlyTools map[string]struct{}
} }
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
}, nil) }, nil)
return ctx return ctx
case compose.ComponentOfToolsNode: case compose.ComponentOfToolsNode:
toolsMessage, err := concatToolsNodeOutput(ctx, output) toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
if err != nil { if err != nil {
r.sw.Send(nil, err) r.sw.Send(nil, err)
return ctx return ctx
@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
return interruptEventType return interruptEventType
} }
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close() defer output.Close()
toolsMsgChunks := make([][]*schema.Message, 0, 5) var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message]
defer func() {
if sw != nil {
sw.Close()
}
}()
var streamInitialized bool
returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false
for { for {
cbOut, err := output.Recv() cbOut, err := output.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
} }
if err != nil { if err != nil {
if sw != nil {
sw.Send(nil, err)
}
return nil, err return nil, err
} }
msgs := convToolsNodeCallbackOutput(cbOut) msgs := convToolsNodeCallbackOutput(cbOut)
for _, msg := range msgs { if !isToolsMsgChunksInit {
isToolsMsgChunksInit = true
toolsMsgChunks = make([][]*schema.Message, len(msgs))
}
for mIndex, msg := range msgs {
if msg == nil { if msg == nil {
continue continue
} }
if len(r.returnDirectlyTools) > 0 {
if isReturnDirectToolsFirstCheck {
isReturnDirectToolsFirstCheck = false
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
returnDirectToolsMap[mIndex] = true
}
}
findSameMsg := false if _, ok := returnDirectToolsMap[mIndex]; ok {
for i, msgChunks := range toolsMsgChunks { if !streamInitialized {
if msg.ToolCallID == msgChunks[0].ToolCallID { sr, sw = schema.Pipe[*schema.Message](5)
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg) r.sw.Send(&entity.AgentEvent{
findSameMsg = true EventType: singleagent.EventTypeOfToolsAsChatModelStream,
break ChatModelAnswer: sr,
}, nil)
streamInitialized = true
}
sw.Send(msg, nil)
} }
} }
if toolsMsgChunks[mIndex] == nil {
if !findSameMsg { toolsMsgChunks[mIndex] = []*schema.Message{msg}
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg}) } else {
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
} }
} }
} }

View File

@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
return message.MessageTypeKnowledge, nil return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage: case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer: case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeAnswer, nil return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfSuggest: case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil return message.MessageTypeFlowUp, nil