fix(singleagent): support workflow output node (#662)
This commit is contained in:
@@ -19,6 +19,7 @@ package agentflow
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"slices"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
)
|
||||
|
||||
type AgentState struct {
|
||||
@@ -69,7 +71,49 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
|
||||
|
||||
go func() {
|
||||
var composeOpts []compose.Option
|
||||
var pipeMsgOpt compose.Option
|
||||
var workflowMsgSr *schema.StreamReader[*crossworkflow.WorkflowMessage]
|
||||
if r.containWfTool {
|
||||
cfReq := crossworkflow.ExecuteConfig{
|
||||
AgentID: &req.Identity.AgentID,
|
||||
ConnectorUID: req.UserID,
|
||||
ConnectorID: req.Identity.ConnectorID,
|
||||
BizType: crossworkflow.BizTypeAgent,
|
||||
}
|
||||
if req.Identity.IsDraft {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeDebug
|
||||
} else {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeRelease
|
||||
}
|
||||
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
|
||||
composeOpts = append(composeOpts, wfConfig)
|
||||
pipeMsgOpt, workflowMsgSr = crossworkflow.DefaultSVC().WithMessagePipe()
|
||||
composeOpts = append(composeOpts, pipeMsgOpt)
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
|
||||
_ = compose.RegisterSerializableType[*AgentState]("agent_state")
|
||||
if r.requireCheckpoint {
|
||||
|
||||
defaultCheckPointID := executeID.String()
|
||||
if req.ResumeInfo != nil {
|
||||
resumeInfo := req.ResumeInfo
|
||||
if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
|
||||
defaultCheckPointID = resumeInfo.InterruptID
|
||||
opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
|
||||
composeOpts = append(composeOpts, opts)
|
||||
}
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
|
||||
}
|
||||
if r.containWfTool && workflowMsgSr != nil {
|
||||
safego.Go(ctx, func() {
|
||||
r.processWfMidAnswerStream(ctx, sw, workflowMsgSr)
|
||||
})
|
||||
}
|
||||
safego.Go(ctx, func() {
|
||||
defer func() {
|
||||
if pe := recover(); pe != nil {
|
||||
logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
|
||||
@@ -78,45 +122,58 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
}
|
||||
sw.Close()
|
||||
}()
|
||||
|
||||
var composeOpts []compose.Option
|
||||
composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
|
||||
_ = compose.RegisterSerializableType[*AgentState]("agent_state")
|
||||
if r.requireCheckpoint {
|
||||
|
||||
defaultCheckPointID := executeID.String()
|
||||
if req.ResumeInfo != nil {
|
||||
resumeInfo := req.ResumeInfo
|
||||
if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
|
||||
defaultCheckPointID = resumeInfo.InterruptID
|
||||
opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
|
||||
composeOpts = append(composeOpts, opts)
|
||||
}
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
|
||||
}
|
||||
if r.containWfTool {
|
||||
cfReq := crossworkflow.ExecuteConfig{
|
||||
AgentID: &req.Identity.AgentID,
|
||||
ConnectorUID: req.UserID,
|
||||
ConnectorID: req.Identity.ConnectorID,
|
||||
BizType: crossworkflow.BizTypeAgent,
|
||||
}
|
||||
if req.Identity.IsDraft {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeDebug
|
||||
} else {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeRelease
|
||||
}
|
||||
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
|
||||
composeOpts = append(composeOpts, wfConfig)
|
||||
}
|
||||
_, _ = r.runner.Stream(ctx, req, composeOpts...)
|
||||
}()
|
||||
})
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (r *AgentRunner) processWfMidAnswerStream(_ context.Context, sw *schema.StreamWriter[*entity.AgentEvent], wfStream *schema.StreamReader[*crossworkflow.WorkflowMessage]) {
|
||||
streamInitialized := false
|
||||
var srT *schema.StreamReader[*schema.Message]
|
||||
var swT *schema.StreamWriter[*schema.Message]
|
||||
defer func() {
|
||||
if swT != nil {
|
||||
swT.Close()
|
||||
}
|
||||
}()
|
||||
for {
|
||||
msg, err := wfStream.Recv()
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if msg == nil || msg.DataMessage == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.DataMessage.NodeType != crossworkflow.NodeTypeOutputEmitter {
|
||||
continue
|
||||
}
|
||||
if !streamInitialized {
|
||||
streamInitialized = true
|
||||
srT, swT = schema.Pipe[*schema.Message](5)
|
||||
sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolMidAnswer,
|
||||
ToolMidAnswer: srT,
|
||||
}, nil)
|
||||
}
|
||||
swT.Send(&schema.Message{
|
||||
Role: msg.DataMessage.Role,
|
||||
Content: msg.DataMessage.Content,
|
||||
Extra: func(msg *crossworkflow.WorkflowMessage) map[string]any {
|
||||
|
||||
extra := make(map[string]any)
|
||||
extra["workflow_node_name"] = msg.NodeTitle
|
||||
if msg.DataMessage.Last {
|
||||
extra["is_finish"] = true
|
||||
}
|
||||
return extra
|
||||
}(msg),
|
||||
}, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
|
||||
req.Input = r.preHandlerInput(req.Input)
|
||||
req.History = r.preHandlerHistory(req.History)
|
||||
|
||||
@@ -267,7 +267,6 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
defer output.Close()
|
||||
var toolsMsgChunks [][]*schema.Message
|
||||
var sr *schema.StreamReader[*schema.Message]
|
||||
var sw *schema.StreamWriter[*schema.Message]
|
||||
@@ -280,7 +279,6 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
|
||||
returnDirectToolsMap := make(map[int]bool)
|
||||
isReturnDirectToolsFirstCheck := true
|
||||
isToolsMsgChunksInit := false
|
||||
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -318,8 +316,8 @@ func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *
|
||||
if !streamInitialized {
|
||||
sr, sw = schema.Pipe[*schema.Message](5)
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||
ChatModelAnswer: sr,
|
||||
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||
ToolAsChatModelAnswer: sr,
|
||||
}, nil)
|
||||
streamInitialized = true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user