fix(singleagent): workflow as tool return directly (#526)
This commit is contained in:
parent
38b63f00a3
commit
0367e66eca
|
|
@ -38,6 +38,7 @@ type EventType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
|
||||||
|
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
|
||||||
EventTypeOfToolsMessage EventType = "tools_message"
|
EventTypeOfToolsMessage EventType = "tools_message"
|
||||||
EventTypeOfFuncCall EventType = "func_call"
|
EventTypeOfFuncCall EventType = "func_call"
|
||||||
EventTypeOfSuggest EventType = "suggest"
|
EventTypeOfSuggest EventType = "suggest"
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||||
}
|
}
|
||||||
tr := newPreToolRetriever(&toolPreCallConf{})
|
tr := newPreToolRetriever(&toolPreCallConf{})
|
||||||
|
|
||||||
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
|
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
|
||||||
wfInfos: conf.Agent.Workflow,
|
wfInfos: conf.Agent.Workflow,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||||
ToolsConfig: compose.ToolsNodeConfig{
|
ToolsConfig: compose.ToolsNodeConfig{
|
||||||
Tools: agentTools,
|
Tools: agentTools,
|
||||||
},
|
},
|
||||||
ToolReturnDirectly: toolsReturnDirectly,
|
ToolReturnDirectly: returnDirectlyTools,
|
||||||
ModelNodeName: keyOfReActAgentChatModel,
|
ModelNodeName: keyOfReActAgentChatModel,
|
||||||
ToolsNodeName: keyOfReActAgentToolsNode,
|
ToolsNodeName: keyOfReActAgentToolsNode,
|
||||||
})
|
})
|
||||||
|
|
@ -277,6 +277,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||||
requireCheckpoint: requireCheckpoint,
|
requireCheckpoint: requireCheckpoint,
|
||||||
modelInfo: modelInfo,
|
modelInfo: modelInfo,
|
||||||
containWfTool: containWfTool,
|
containWfTool: containWfTool,
|
||||||
|
returnDirectlyTools: returnDirectlyTools,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,7 @@ type AgentRunner struct {
|
||||||
runner compose.Runnable[*AgentRequest, *schema.Message]
|
runner compose.Runnable[*AgentRequest, *schema.Message]
|
||||||
requireCheckpoint bool
|
requireCheckpoint bool
|
||||||
|
|
||||||
|
returnDirectlyTools map[string]struct{}
|
||||||
containWfTool bool
|
containWfTool bool
|
||||||
modelInfo *modelmgr.Model
|
modelInfo *modelmgr.Model
|
||||||
}
|
}
|
||||||
|
|
@ -66,7 +67,7 @@ func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||||
) {
|
) {
|
||||||
executeID := uuid.New()
|
executeID := uuid.New()
|
||||||
|
|
||||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
|
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ import (
|
||||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
|
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
|
||||||
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
||||||
) {
|
) {
|
||||||
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
||||||
|
|
@ -46,6 +46,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
||||||
rcc := &replyChunkCallback{
|
rcc := &replyChunkCallback{
|
||||||
sw: sw,
|
sw: sw,
|
||||||
executeID: executeID,
|
executeID: executeID,
|
||||||
|
returnDirectlyTools: returnDirectlyTools,
|
||||||
}
|
}
|
||||||
|
|
||||||
clb = callbacks.NewHandlerBuilder().
|
clb = callbacks.NewHandlerBuilder().
|
||||||
|
|
@ -61,6 +62,7 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
||||||
type replyChunkCallback struct {
|
type replyChunkCallback struct {
|
||||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||||
executeID string
|
executeID string
|
||||||
|
returnDirectlyTools map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||||
|
|
@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
|
||||||
}, nil)
|
}, nil)
|
||||||
return ctx
|
return ctx
|
||||||
case compose.ComponentOfToolsNode:
|
case compose.ComponentOfToolsNode:
|
||||||
toolsMessage, err := concatToolsNodeOutput(ctx, output)
|
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.sw.Send(nil, err)
|
r.sw.Send(nil, err)
|
||||||
return ctx
|
return ctx
|
||||||
|
|
@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
||||||
return interruptEventType
|
return interruptEventType
|
||||||
}
|
}
|
||||||
|
|
||||||
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||||
defer output.Close()
|
defer output.Close()
|
||||||
toolsMsgChunks := make([][]*schema.Message, 0, 5)
|
var toolsMsgChunks [][]*schema.Message
|
||||||
|
var sr *schema.StreamReader[*schema.Message]
|
||||||
|
var sw *schema.StreamWriter[*schema.Message]
|
||||||
|
defer func() {
|
||||||
|
if sw != nil {
|
||||||
|
sw.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var streamInitialized bool
|
||||||
|
returnDirectToolsMap := make(map[int]bool)
|
||||||
|
isReturnDirectToolsFirstCheck := true
|
||||||
|
isToolsMsgChunksInit := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
cbOut, err := output.Recv()
|
cbOut, err := output.Recv()
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
|
|
@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if sw != nil {
|
||||||
|
sw.Send(nil, err)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs := convToolsNodeCallbackOutput(cbOut)
|
msgs := convToolsNodeCallbackOutput(cbOut)
|
||||||
|
|
||||||
for _, msg := range msgs {
|
if !isToolsMsgChunksInit {
|
||||||
|
isToolsMsgChunksInit = true
|
||||||
|
toolsMsgChunks = make([][]*schema.Message, len(msgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
for mIndex, msg := range msgs {
|
||||||
|
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if len(r.returnDirectlyTools) > 0 {
|
||||||
findSameMsg := false
|
if isReturnDirectToolsFirstCheck {
|
||||||
for i, msgChunks := range toolsMsgChunks {
|
isReturnDirectToolsFirstCheck = false
|
||||||
if msg.ToolCallID == msgChunks[0].ToolCallID {
|
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
|
||||||
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
|
returnDirectToolsMap[mIndex] = true
|
||||||
findSameMsg = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !findSameMsg {
|
if _, ok := returnDirectToolsMap[mIndex]; ok {
|
||||||
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
|
if !streamInitialized {
|
||||||
|
sr, sw = schema.Pipe[*schema.Message](5)
|
||||||
|
r.sw.Send(&entity.AgentEvent{
|
||||||
|
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||||
|
ChatModelAnswer: sr,
|
||||||
|
}, nil)
|
||||||
|
streamInitialized = true
|
||||||
|
}
|
||||||
|
sw.Send(msg, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolsMsgChunks[mIndex] == nil {
|
||||||
|
toolsMsgChunks[mIndex] = []*schema.Message{msg}
|
||||||
|
} else {
|
||||||
|
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
|
||||||
return message.MessageTypeKnowledge, nil
|
return message.MessageTypeKnowledge, nil
|
||||||
case singleagent.EventTypeOfToolsMessage:
|
case singleagent.EventTypeOfToolsMessage:
|
||||||
return message.MessageTypeToolResponse, nil
|
return message.MessageTypeToolResponse, nil
|
||||||
case singleagent.EventTypeOfChatModelAnswer:
|
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
|
||||||
return message.MessageTypeAnswer, nil
|
return message.MessageTypeAnswer, nil
|
||||||
case singleagent.EventTypeOfSuggest:
|
case singleagent.EventTypeOfSuggest:
|
||||||
return message.MessageTypeFlowUp, nil
|
return message.MessageTypeFlowUp, nil
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue