fix(singleagent): workflow as tool return directly (#526)

This commit is contained in:
junwen-lee
2025-08-04 21:58:37 +08:00
committed by GitHub
parent 38b63f00a3
commit 0367e66eca
5 changed files with 72 additions and 34 deletions

View File

@@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
}
tr := newPreToolRetriever(&toolPreCallConf{})
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
wfInfos: conf.Agent.Workflow,
})
if err != nil {
@@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
ToolsConfig: compose.ToolsNodeConfig{
Tools: agentTools,
},
ToolReturnDirectly: toolsReturnDirectly,
ToolReturnDirectly: returnDirectlyTools,
ModelNodeName: keyOfReActAgentChatModel,
ToolsNodeName: keyOfReActAgentToolsNode,
})
@@ -273,10 +273,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
}
return &AgentRunner{
runner: runner,
requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo,
containWfTool: containWfTool,
runner: runner,
requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo,
containWfTool: containWfTool,
returnDirectlyTools: returnDirectlyTools,
}, nil
}

View File

@@ -57,8 +57,9 @@ type AgentRunner struct {
runner compose.Runnable[*AgentRequest, *schema.Message]
requireCheckpoint bool
containWfTool bool
modelInfo *modelmgr.Model
returnDirectlyTools map[string]struct{}
containWfTool bool
modelInfo *modelmgr.Model
}
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
@@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
) {
executeID := uuid.New()
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() {
defer func() {

View File

@@ -38,14 +38,15 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
) {
sr, sw = schema.Pipe[*entity.AgentEvent](10)
rcc := &replyChunkCallback{
sw: sw,
executeID: executeID,
sw: sw,
executeID: executeID,
returnDirectlyTools: returnDirectlyTools,
}
clb = callbacks.NewHandlerBuilder().
@@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
}
type replyChunkCallback struct {
sw *schema.StreamWriter[*entity.AgentEvent]
executeID string
sw *schema.StreamWriter[*entity.AgentEvent]
executeID string
returnDirectlyTools map[string]struct{}
}
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
@@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
}, nil)
return ctx
case compose.ComponentOfToolsNode:
toolsMessage, err := concatToolsNodeOutput(ctx, output)
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
if err != nil {
r.sw.Send(nil, err)
return ctx
@@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
return interruptEventType
}
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close()
toolsMsgChunks := make([][]*schema.Message, 0, 5)
var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message]
defer func() {
if sw != nil {
sw.Close()
}
}()
var streamInitialized bool
returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false
for {
cbOut, err := output.Recv()
if errors.Is(err, io.EOF) {
@@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
}
if err != nil {
if sw != nil {
sw.Send(nil, err)
}
return nil, err
}
msgs := convToolsNodeCallbackOutput(cbOut)
for _, msg := range msgs {
if !isToolsMsgChunksInit {
isToolsMsgChunksInit = true
toolsMsgChunks = make([][]*schema.Message, len(msgs))
}
for mIndex, msg := range msgs {
if msg == nil {
continue
}
if len(r.returnDirectlyTools) > 0 {
if isReturnDirectToolsFirstCheck {
isReturnDirectToolsFirstCheck = false
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
returnDirectToolsMap[mIndex] = true
}
}
findSameMsg := false
for i, msgChunks := range toolsMsgChunks {
if msg.ToolCallID == msgChunks[0].ToolCallID {
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
findSameMsg = true
break
if _, ok := returnDirectToolsMap[mIndex]; ok {
if !streamInitialized {
sr, sw = schema.Pipe[*schema.Message](5)
r.sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ChatModelAnswer: sr,
}, nil)
streamInitialized = true
}
sw.Send(msg, nil)
}
}
if !findSameMsg {
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
if toolsMsgChunks[mIndex] == nil {
toolsMsgChunks[mIndex] = []*schema.Message{msg}
} else {
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
}
}
}