fix(singleagent): workflow as tool return directly (#526)

This commit is contained in:
junwen-lee 2025-08-04 21:58:37 +08:00 committed by GitHub
parent 38b63f00a3
commit 0367e66eca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 34 deletions

View File

@ -38,6 +38,7 @@ type EventType string
const (
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
EventTypeOfToolsMessage EventType = "tools_message"
EventTypeOfFuncCall EventType = "func_call"
EventTypeOfSuggest EventType = "suggest"

View File

@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
}
tr := newPreToolRetriever(&toolPreCallConf{})
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
wfInfos: conf.Agent.Workflow,
})
if err != nil {
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
ToolsConfig: compose.ToolsNodeConfig{
Tools: agentTools,
},
ToolReturnDirectly: toolsReturnDirectly,
ToolReturnDirectly: returnDirectlyTools,
ModelNodeName: keyOfReActAgentChatModel,
ToolsNodeName: keyOfReActAgentToolsNode,
})
@ -277,6 +277,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo,
containWfTool: containWfTool,
returnDirectlyTools: returnDirectlyTools,
}, nil
}

View File

@ -57,6 +57,7 @@ type AgentRunner struct {
runner compose.Runnable[*AgentRequest, *schema.Message]
requireCheckpoint bool
returnDirectlyTools map[string]struct{}
containWfTool bool
modelInfo *modelmgr.Model
}
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
) {
executeID := uuid.New()
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() {
defer func() {

View File

@ -38,7 +38,7 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/logs"
)
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
) {
sr, sw = schema.Pipe[*entity.AgentEvent](10)
@ -46,6 +46,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
rcc := &replyChunkCallback{
sw: sw,
executeID: executeID,
returnDirectlyTools: returnDirectlyTools,
}
clb = callbacks.NewHandlerBuilder().
@ -61,6 +62,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
type replyChunkCallback struct {
sw *schema.StreamWriter[*entity.AgentEvent]
executeID string
returnDirectlyTools map[string]struct{}
}
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
}, nil)
return ctx
case compose.ComponentOfToolsNode:
toolsMessage, err := concatToolsNodeOutput(ctx, output)
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
if err != nil {
r.sw.Send(nil, err)
return ctx
@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
return interruptEventType
}
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close()
toolsMsgChunks := make([][]*schema.Message, 0, 5)
var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message]
defer func() {
if sw != nil {
sw.Close()
}
}()
var streamInitialized bool
returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false
for {
cbOut, err := output.Recv()
if errors.Is(err, io.EOF) {
@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
}
if err != nil {
if sw != nil {
sw.Send(nil, err)
}
return nil, err
}
msgs := convToolsNodeCallbackOutput(cbOut)
for _, msg := range msgs {
if !isToolsMsgChunksInit {
isToolsMsgChunksInit = true
toolsMsgChunks = make([][]*schema.Message, len(msgs))
}
for mIndex, msg := range msgs {
if msg == nil {
continue
}
findSameMsg := false
for i, msgChunks := range toolsMsgChunks {
if msg.ToolCallID == msgChunks[0].ToolCallID {
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
findSameMsg = true
break
if len(r.returnDirectlyTools) > 0 {
if isReturnDirectToolsFirstCheck {
isReturnDirectToolsFirstCheck = false
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
returnDirectToolsMap[mIndex] = true
}
}
if !findSameMsg {
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
if _, ok := returnDirectToolsMap[mIndex]; ok {
if !streamInitialized {
sr, sw = schema.Pipe[*schema.Message](5)
r.sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ChatModelAnswer: sr,
}, nil)
streamInitialized = true
}
sw.Send(msg, nil)
}
}
if toolsMsgChunks[mIndex] == nil {
toolsMsgChunks[mIndex] = []*schema.Message{msg}
} else {
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
}
}
}

View File

@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
return message.MessageTypeKnowledge, nil
case singleagent.EventTypeOfToolsMessage:
return message.MessageTypeToolResponse, nil
case singleagent.EventTypeOfChatModelAnswer:
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
return message.MessageTypeAnswer, nil
case singleagent.EventTypeOfSuggest:
return message.MessageTypeFlowUp, nil