fix(singleagent): workflow as tool return directly (#526)

This commit is contained in:
junwen-lee 2025-08-04 21:58:37 +08:00 committed by GitHub
parent 38b63f00a3
commit 0367e66eca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 34 deletions

View File

@ -38,6 +38,7 @@ type EventType string
const ( const (
EventTypeOfChatModelAnswer EventType = "chatmodel_answer" EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
EventTypeOfToolsMessage EventType = "tools_message" EventTypeOfToolsMessage EventType = "tools_message"
EventTypeOfFuncCall EventType = "func_call" EventTypeOfFuncCall EventType = "func_call"
EventTypeOfSuggest EventType = "suggest" EventTypeOfSuggest EventType = "suggest"

View File

@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
} }
tr := newPreToolRetriever(&toolPreCallConf{}) tr := newPreToolRetriever(&toolPreCallConf{})
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{ wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
wfInfos: conf.Agent.Workflow, wfInfos: conf.Agent.Workflow,
}) })
if err != nil { if err != nil {
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
ToolsConfig: compose.ToolsNodeConfig{ ToolsConfig: compose.ToolsNodeConfig{
Tools: agentTools, Tools: agentTools,
}, },
ToolReturnDirectly: toolsReturnDirectly, ToolReturnDirectly: returnDirectlyTools,
ModelNodeName: keyOfReActAgentChatModel, ModelNodeName: keyOfReActAgentChatModel,
ToolsNodeName: keyOfReActAgentToolsNode, ToolsNodeName: keyOfReActAgentToolsNode,
}) })
@ -277,6 +277,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
requireCheckpoint: requireCheckpoint, requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo, modelInfo: modelInfo,
containWfTool: containWfTool, containWfTool: containWfTool,
returnDirectlyTools: returnDirectlyTools,
}, nil }, nil
} }

View File

@ -57,6 +57,7 @@ type AgentRunner struct {
runner compose.Runnable[*AgentRequest, *schema.Message] runner compose.Runnable[*AgentRequest, *schema.Message]
requireCheckpoint bool requireCheckpoint bool
returnDirectlyTools map[string]struct{}
containWfTool bool containWfTool bool
modelInfo *modelmgr.Model modelInfo *modelmgr.Model
} }
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
) { ) {
executeID := uuid.New() executeID := uuid.New()
hdl, sr, sw := newReplyCallback(ctx, executeID.String()) hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() { go func() {
defer func() { defer func() {

View File

@ -38,7 +38,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/logs"
) )
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler, func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent], sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
) { ) {
sr, sw = schema.Pipe[*entity.AgentEvent](10) sr, sw = schema.Pipe[*entity.AgentEvent](10)
@ -46,6 +46,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
rcc := &replyChunkCallback{ rcc := &replyChunkCallback{
sw: sw, sw: sw,
executeID: executeID, executeID: executeID,
returnDirectlyTools: returnDirectlyTools,
} }
clb = callbacks.NewHandlerBuilder(). clb = callbacks.NewHandlerBuilder().
@ -61,6 +62,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
type replyChunkCallback struct { type replyChunkCallback struct {
sw *schema.StreamWriter[*entity.AgentEvent] sw *schema.StreamWriter[*entity.AgentEvent]
executeID string executeID string
returnDirectlyTools map[string]struct{}
} }
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
}, nil) }, nil)
return ctx return ctx
case compose.ComponentOfToolsNode: case compose.ComponentOfToolsNode:
toolsMessage, err := concatToolsNodeOutput(ctx, output) toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
if err != nil { if err != nil {
r.sw.Send(nil, err) r.sw.Send(nil, err)
return ctx return ctx
@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
return interruptEventType return interruptEventType
} }
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) { func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close() defer output.Close()
toolsMsgChunks := make([][]*schema.Message, 0, 5) var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message]
defer func() {
if sw != nil {
sw.Close()
}
}()
var streamInitialized bool
returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false
for { for {
cbOut, err := output.Recv() cbOut, err := output.Recv()
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
} }
if err != nil { if err != nil {
if sw != nil {
sw.Send(nil, err)
}
return nil, err return nil, err
} }
msgs := convToolsNodeCallbackOutput(cbOut) msgs := convToolsNodeCallbackOutput(cbOut)
for _, msg := range msgs { if !isToolsMsgChunksInit {
isToolsMsgChunksInit = true
toolsMsgChunks = make([][]*schema.Message, len(msgs))
}
for mIndex, msg := range msgs {
if msg == nil { if msg == nil {
continue continue
} }
if len(r.returnDirectlyTools) > 0 {
findSameMsg := false if isReturnDirectToolsFirstCheck {
for i, msgChunks := range toolsMsgChunks { isReturnDirectToolsFirstCheck = false
if msg.ToolCallID == msgChunks[0].ToolCallID { if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg) returnDirectToolsMap[mIndex] = true
findSameMsg = true
break
} }
} }
if !findSameMsg { if _, ok := returnDirectToolsMap[mIndex]; ok {
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg}) if !streamInitialized {
sr, sw = schema.Pipe[*schema.Message](5)
r.sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ChatModelAnswer: sr,
}, nil)
streamInitialized = true
}
sw.Send(msg, nil)
}
}
if toolsMsgChunks[mIndex] == nil {
toolsMsgChunks[mIndex] = []*schema.Message{msg}
} else {
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
} }
} }
} }

View File

@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
return message.MessageTypeKnowledge, nil return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage: case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer: case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeAnswer, nil return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfSuggest: case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil return message.MessageTypeFlowUp, nil