fix(singleagent): workflow as tool return directly (#526)
This commit is contained in:
@@ -38,14 +38,15 @@ import (
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
|
||||
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
|
||||
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
||||
) {
|
||||
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
||||
|
||||
rcc := &replyChunkCallback{
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
returnDirectlyTools: returnDirectlyTools,
|
||||
}
|
||||
|
||||
clb = callbacks.NewHandlerBuilder().
|
||||
@@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
|
||||
}
|
||||
|
||||
type replyChunkCallback struct {
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
returnDirectlyTools map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
@@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
|
||||
}, nil)
|
||||
return ctx
|
||||
case compose.ComponentOfToolsNode:
|
||||
toolsMessage, err := concatToolsNodeOutput(ctx, output)
|
||||
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
|
||||
if err != nil {
|
||||
r.sw.Send(nil, err)
|
||||
return ctx
|
||||
@@ -270,9 +272,21 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
||||
return interruptEventType
|
||||
}
|
||||
|
||||
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
defer output.Close()
|
||||
toolsMsgChunks := make([][]*schema.Message, 0, 5)
|
||||
var toolsMsgChunks [][]*schema.Message
|
||||
var sr *schema.StreamReader[*schema.Message]
|
||||
var sw *schema.StreamWriter[*schema.Message]
|
||||
defer func() {
|
||||
if sw != nil {
|
||||
sw.Close()
|
||||
}
|
||||
}()
|
||||
var streamInitialized bool
|
||||
returnDirectToolsMap := make(map[int]bool)
|
||||
isReturnDirectToolsFirstCheck := true
|
||||
isToolsMsgChunksInit := false
|
||||
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -280,27 +294,48 @@ func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[call
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if sw != nil {
|
||||
sw.Send(nil, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs := convToolsNodeCallbackOutput(cbOut)
|
||||
|
||||
for _, msg := range msgs {
|
||||
if !isToolsMsgChunksInit {
|
||||
isToolsMsgChunksInit = true
|
||||
toolsMsgChunks = make([][]*schema.Message, len(msgs))
|
||||
}
|
||||
|
||||
for mIndex, msg := range msgs {
|
||||
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if len(r.returnDirectlyTools) > 0 {
|
||||
if isReturnDirectToolsFirstCheck {
|
||||
isReturnDirectToolsFirstCheck = false
|
||||
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
|
||||
returnDirectToolsMap[mIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
findSameMsg := false
|
||||
for i, msgChunks := range toolsMsgChunks {
|
||||
if msg.ToolCallID == msgChunks[0].ToolCallID {
|
||||
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
|
||||
findSameMsg = true
|
||||
break
|
||||
if _, ok := returnDirectToolsMap[mIndex]; ok {
|
||||
if !streamInitialized {
|
||||
sr, sw = schema.Pipe[*schema.Message](5)
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
||||
ChatModelAnswer: sr,
|
||||
}, nil)
|
||||
streamInitialized = true
|
||||
}
|
||||
sw.Send(msg, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if !findSameMsg {
|
||||
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
|
||||
if toolsMsgChunks[mIndex] == nil {
|
||||
toolsMsgChunks[mIndex] = []*schema.Message{msg}
|
||||
} else {
|
||||
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user