fix(singleagent): workflow as tool return directly (#526)
This commit is contained in:
parent
38b63f00a3
commit
0367e66eca
|
|
@ -38,6 +38,7 @@ type EventType string
|
|||
|
||||
const (
|
||||
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
||||
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
|
||||
EventTypeOfToolsMessage EventType = "tools_message"
|
||||
EventTypeOfFuncCall EventType = "func_call"
|
||||
EventTypeOfSuggest EventType = "suggest"
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
|||
}
|
||||
tr := newPreToolRetriever(&toolPreCallConf{})
|
||||
|
||||
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
|
||||
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
|
||||
wfInfos: conf.Agent.Workflow,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
|||
ToolsConfig: compose.ToolsNodeConfig{
|
||||
Tools: agentTools,
|
||||
},
|
||||
ToolReturnDirectly: toolsReturnDirectly,
|
||||
ToolReturnDirectly: returnDirectlyTools,
|
||||
ModelNodeName: keyOfReActAgentChatModel,
|
||||
ToolsNodeName: keyOfReActAgentToolsNode,
|
||||
})
|
||||
|
|
@ -277,6 +277,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
|||
requireCheckpoint: requireCheckpoint,
|
||||
modelInfo: modelInfo,
|
||||
containWfTool: containWfTool,
|
||||
returnDirectlyTools: returnDirectlyTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ type AgentRunner struct {
|
|||
runner compose.Runnable[*AgentRequest, *schema.Message]
|
||||
requireCheckpoint bool
|
||||
|
||||
returnDirectlyTools map[string]struct{}
|
||||
containWfTool bool
|
||||
modelInfo *modelmgr.Model
|
||||
}
|
||||
|
|
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
|||
) {
|
||||
executeID := uuid.New()
|
||||
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ import (
|
|||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
|
||||
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
|
||||
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
||||
) {
|
||||
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
||||
|
|
@ -46,6 +46,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
|||
rcc := &replyChunkCallback{
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
returnDirectlyTools: returnDirectlyTools,
|
||||
}
|
||||
|
||||
clb = callbacks.NewHandlerBuilder().
|
||||
|
|
@ -61,6 +62,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
|||
type replyChunkCallback struct {
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
returnDirectlyTools map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
|
|
@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
|
|||
}, nil)
|
||||
return ctx
|
||||
case compose.ComponentOfToolsNode:
|
||||
toolsMessage, err := concatToolsNodeOutput(ctx, output)
|
||||
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
|
||||
if err != nil {
|
||||
r.sw.Send(nil, err)
|
||||
return ctx
|
||||
|
|
@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
|||
return interruptEventType
|
||||
}
|
||||
|
||||
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
defer output.Close()
|
||||
toolsMsgChunks := make([][]*schema.Message, 0, 5)
|
||||
var toolsMsgChunks [][]*schema.Message
|
||||
var sr *schema.StreamReader[*schema.Message]
|
||||
var sw *schema.StreamWriter[*schema.Message]
|
||||
defer func() {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
}
|
||||
}()
|
||||
var streamInitialized bool
|
||||
returnDirectToolsMap := make(map[int]bool)
|
||||
isReturnDirectToolsFirstCheck := true
|
||||
isToolsMsgChunksInit := false
|
||||
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
|
|
@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
if sw != nil {
|
||||
sw.Send(nil, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs := convToolsNodeCallbackOutput(cbOut)
|
||||
|
||||
for _, msg := range msgs {
|
||||
if !isToolsMsgChunksInit {
|
||||
isToolsMsgChunksInit = true
|
||||
toolsMsgChunks = make([][]*schema.Message, len(msgs))
|
||||
}
|
||||
|
||||
for mIndex, msg := range msgs {
|
||||
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
findSameMsg := false
|
||||
for i, msgChunks := range toolsMsgChunks {
|
||||
if msg.ToolCallID == msgChunks[0].ToolCallID {
|
||||
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
|
||||
findSameMsg = true
|
||||
break
|
||||
if len(r.returnDirectlyTools) > 0 {
|
||||
if isReturnDirectToolsFirstCheck {
|
||||
isReturnDirectToolsFirstCheck = false
|
||||
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
|
||||
returnDirectToolsMap[mIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
if !findSameMsg {
|
||||
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
|
||||
if _, ok := returnDirectToolsMap[mIndex]; ok {
|
||||
if !streamInitialized {
|
||||
sr, sw = schema.Pipe[*schema.Message](5)
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||
ChatModelAnswer: sr,
|
||||
}, nil)
|
||||
streamInitialized = true
|
||||
}
|
||||
sw.Send(msg, nil)
|
||||
}
|
||||
}
|
||||
if toolsMsgChunks[mIndex] == nil {
|
||||
toolsMsgChunks[mIndex] = []*schema.Message{msg}
|
||||
} else {
|
||||
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
|
|||
return message.MessageTypeKnowledge, nil
|
||||
case singleagent.EventTypeOfToolsMessage:
|
||||
return message.MessageTypeToolResponse, nil
|
||||
case singleagent.EventTypeOfChatModelAnswer:
|
||||
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
|
||||
return message.MessageTypeAnswer, nil
|
||||
case singleagent.EventTypeOfSuggest:
|
||||
return message.MessageTypeFlowUp, nil
|
||||
|
|
|
|||
Loading…
Reference in New Issue