fix(singleagent): support workflow output node (#662)

This commit is contained in:
junwen-lee
2025-08-11 10:49:51 +08:00
committed by GitHub
parent a21e41b89d
commit efc6e55fe5
18 changed files with 391 additions and 101 deletions

View File

@@ -19,6 +19,7 @@ package agentflow
import (
"context"
"errors"
"io"
"slices"
"github.com/google/uuid"
@@ -33,6 +34,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
)
type AgentState struct {
@@ -69,7 +71,49 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
go func() {
var composeOpts []compose.Option
var pipeMsgOpt compose.Option
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
if r.containWfTool {
cfReq := crossworkflow.ExecuteConfig{
AgentID: &req.Identity.AgentID,
ConnectorUID: req.UserID,
ConnectorID: req.Identity.ConnectorID,
BizType: crossworkflow.BizTypeAgent,
}
if req.Identity.IsDraft {
cfReq.Mode = crossworkflow.ExecuteModeDebug
} else {
cfReq.Mode = crossworkflow.ExecuteModeRelease
}
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
composeOpts = append(composeOpts, wfConfig)
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
composeOpts = append(composeOpts, pipeMsgOpt)
}
composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
_ = compose.RegisterSerializableType[*AgentState]("agent_state")
if r.requireCheckpoint {
defaultCheckPointID := executeID.String()
if req.ResumeInfo != nil {
resumeInfo := req.ResumeInfo
if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
defaultCheckPointID = resumeInfo.InterruptID
opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
composeOpts = append(composeOpts, opts)
}
}
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
}
if r.containWfTool && workflowMsgSr != nil {
safego.Go(ctx, func() {
r.processWfMidAnswerStream(ctx, sw, workflowMsgSr)
})
}
safego.Go(ctx, func() {
defer func() {
if pe := recover(); pe != nil {
logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
@@ -78,45 +122,58 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
}
sw.Close()
}()
var composeOpts []compose.Option
composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
_ = compose.RegisterSerializableType[*AgentState]("agent_state")
if r.requireCheckpoint {
defaultCheckPointID := executeID.String()
if req.ResumeInfo != nil {
resumeInfo := req.ResumeInfo
if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
defaultCheckPointID = resumeInfo.InterruptID
opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
composeOpts = append(composeOpts, opts)
}
}
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
}
if r.containWfTool {
cfReq := crossworkflow.ExecuteConfig{
AgentID: &req.Identity.AgentID,
ConnectorUID: req.UserID,
ConnectorID: req.Identity.ConnectorID,
BizType: crossworkflow.BizTypeAgent,
}
if req.Identity.IsDraft {
cfReq.Mode = crossworkflow.ExecuteModeDebug
} else {
cfReq.Mode = crossworkflow.ExecuteModeRelease
}
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
composeOpts = append(composeOpts, wfConfig)
}
_, _ = r.runner.Stream(ctx, req, composeOpts...)
}()
})
return sr, nil
}
func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) {
streamInitialized := false
var srT *schema.StreamReader[*schema.Message]
var swT *schema.StreamWriter[*schema.Message]
defer func() {
if swT != nil {
swT.Close()
}
}()
for {
msg, err := wfStream.Recv()
if err == io.EOF {
break
}
if msg == nil || msg.DataMessage == nil {
continue
}
if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter {
continue
}
if !streamInitialized {
streamInitialized = true
srT, swT = schema.Pipe[*schema.Message](5)
sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolMidAnswer,
ToolMidAnswer: srT,
}, nil)
}
swT.Send(&schema.Message{
Role: msg.DataMessage.Role,
Content: msg.DataMessage.Content,
Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any {
extra := make(map[string]any)
extra["workflow_node_name"] = msg.NodeTitle
if msg.DataMessage.Last {
extra["is_finish"] = true
}
return extra
}(msg),
}, nil)
}
}
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
req.Input = r.preHandlerInput(req.Input)
req.History = r.preHandlerHistory(req.History)

View File

@@ -267,7 +267,6 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
}
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
defer output.Close()
var toolsMsgChunks [][]*schema.Message
var sr *schema.StreamReader[*schema.Message]
var sw *schema.StreamWriter[*schema.Message]
@@ -280,7 +279,6 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
returnDirectToolsMap := make(map[int]bool)
isReturnDirectToolsFirstCheck := true
isToolsMsgChunksInit := false
for {
cbOut, err := output.Recv()
if errors.Is(err, io.EOF) {
@@ -318,8 +316,8 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
if !streamInitialized {
sr, sw = schema.Pipe[*schema.Message](5)
r.sw.Send(&entity.AgentEvent{
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ChatModelAnswer: sr,
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
ToolAsChatModelAnswer: sr,
}, nil)
streamInitialized = true
}