410 lines
11 KiB
Go
410 lines
11 KiB
Go
/*
|
|
* Copyright 2025 coze-dev Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package agentflow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
"github.com/cloudwego/eino/components"
|
|
"github.com/cloudwego/eino/components/model"
|
|
"github.com/cloudwego/eino/components/retriever"
|
|
"github.com/cloudwego/eino/compose"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
|
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
|
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
|
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
|
)
|
|
|
|
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
|
|
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
|
) {
|
|
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
|
|
|
rcc := &replyChunkCallback{
|
|
sw: sw,
|
|
executeID: executeID,
|
|
returnDirectlyTools: returnDirectlyTools,
|
|
}
|
|
|
|
clb = callbacks.NewHandlerBuilder().
|
|
OnStartFn(rcc.OnStart).
|
|
OnEndFn(rcc.OnEnd).
|
|
OnEndWithStreamOutputFn(rcc.OnEndWithStreamOutput).
|
|
OnErrorFn(rcc.OnError).
|
|
Build()
|
|
|
|
return clb, sr, sw
|
|
}
|
|
|
|
type replyChunkCallback struct {
|
|
sw *schema.StreamWriter[*entity.AgentEvent]
|
|
executeID string
|
|
returnDirectlyTools map[string]struct{}
|
|
}
|
|
|
|
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
|
logs.CtxInfof(ctx, "info-OnError, info=%v, err=%v", conv.DebugJsonToStr(info), err)
|
|
|
|
switch info.Component {
|
|
case compose.ComponentOfGraph:
|
|
if interruptInfo, ok := compose.ExtractInterruptInfo(err); ok {
|
|
if info.Name != "" {
|
|
return ctx
|
|
}
|
|
interruptData := convInterruptInfo(ctx, interruptInfo)
|
|
interruptData.InterruptID = r.executeID
|
|
|
|
toolMessageEvent := &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfToolsMessage,
|
|
ToolsMessage: []*schema.Message{
|
|
{
|
|
Role: schema.Tool,
|
|
Content: "directly streaming reply",
|
|
ToolCallID: interruptData.ToolCallID,
|
|
},
|
|
},
|
|
}
|
|
r.sw.Send(toolMessageEvent, nil)
|
|
|
|
interruptEvent := &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfInterrupt,
|
|
Interrupt: interruptData,
|
|
}
|
|
r.sw.Send(interruptEvent, nil)
|
|
|
|
} else {
|
|
logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%v",
|
|
info.Component, info.Name, err)
|
|
var customErr errorx.StatusError
|
|
errMsg := "Internal server error"
|
|
if errors.As(err, &customErr) && customErr.Code() != 0 {
|
|
errMsg = customErr.Msg()
|
|
}
|
|
r.sw.Send(nil, errors.New(errMsg))
|
|
}
|
|
|
|
}
|
|
|
|
return ctx
|
|
}
|
|
|
|
func (r *replyChunkCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
|
logs.CtxInfof(ctx, "info-OnStart, info=%v, input=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(input))
|
|
|
|
switch info.Component {
|
|
case compose.ComponentOfToolsNode:
|
|
if info.Name != keyOfReActAgentToolsNode {
|
|
return ctx
|
|
}
|
|
ae := &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfFuncCall,
|
|
FuncCall: convToolsNodeCallbackInput(input),
|
|
}
|
|
r.sw.Send(ae, nil)
|
|
}
|
|
|
|
return ctx
|
|
}
|
|
|
|
func (r *replyChunkCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
|
logs.CtxInfof(ctx, "info-OnEnd, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
|
|
switch info.Name {
|
|
case keyOfKnowledgeRetriever:
|
|
knowledgeEvent := &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfKnowledge,
|
|
Knowledge: retriever.ConvCallbackOutput(output).Docs,
|
|
}
|
|
|
|
if knowledgeEvent.Knowledge != nil {
|
|
r.sw.Send(knowledgeEvent, nil)
|
|
}
|
|
case keyOfToolsPreRetriever:
|
|
result := convToolsPreRetrieverCallbackInput(output)
|
|
|
|
if len(result) > 0 {
|
|
for _, item := range result {
|
|
var event *entity.AgentEvent
|
|
if item.Role == schema.Tool {
|
|
event = &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfToolsMessage,
|
|
ToolsMessage: []*schema.Message{item},
|
|
}
|
|
} else {
|
|
event = &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfFuncCall,
|
|
FuncCall: item,
|
|
}
|
|
}
|
|
r.sw.Send(event, nil)
|
|
}
|
|
}
|
|
|
|
case keyOfSuggestParser:
|
|
sg := convSuggestionNodeCallbackOutput(output)
|
|
|
|
if len(sg) > 0 {
|
|
for _, item := range sg {
|
|
suggestionEvent := &entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfSuggest,
|
|
Suggest: item,
|
|
}
|
|
r.sw.Send(suggestionEvent, nil)
|
|
}
|
|
}
|
|
|
|
default:
|
|
return ctx
|
|
}
|
|
|
|
return ctx
|
|
}
|
|
|
|
func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
|
|
output *schema.StreamReader[callbacks.CallbackOutput],
|
|
) context.Context {
|
|
logs.CtxInfof(ctx, "info-OnEndWithStreamOutput, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
|
|
switch info.Component {
|
|
case compose.ComponentOfGraph, components.ComponentOfChatModel:
|
|
if info.Name != keyOfReActAgentChatModel && info.Name != keyOfLLM {
|
|
output.Close()
|
|
return ctx
|
|
}
|
|
sr := schema.StreamReaderWithConvert(output, func(t callbacks.CallbackOutput) (*schema.Message, error) {
|
|
cbOut := model.ConvCallbackOutput(t)
|
|
return cbOut.Message, nil
|
|
})
|
|
|
|
r.sw.Send(&entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfChatModelAnswer,
|
|
ChatModelAnswer: sr,
|
|
}, nil)
|
|
return ctx
|
|
case compose.ComponentOfToolsNode:
|
|
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
|
|
if err != nil {
|
|
r.sw.Send(nil, err)
|
|
return ctx
|
|
}
|
|
|
|
r.sw.Send(&entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfToolsMessage,
|
|
ToolsMessage: toolsMessage,
|
|
}, nil)
|
|
return ctx
|
|
default:
|
|
return ctx
|
|
}
|
|
}
|
|
|
|
func convInterruptInfo(ctx context.Context, interruptInfo *compose.InterruptInfo) *singleagent.InterruptInfo {
|
|
var output *compose.InterruptInfo
|
|
output = interruptInfo.SubGraphs[keyOfReActAgent]
|
|
var extra any
|
|
|
|
for i := range output.RerunNodesExtra {
|
|
extra = output.RerunNodesExtra[i]
|
|
break
|
|
}
|
|
toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
|
|
logs.CtxInfof(ctx, "toolsNodeExtra=%v, err=%v", toolsNodeExtra, ok)
|
|
|
|
var toolCallID string
|
|
|
|
wfResumeData := make(map[string]*crossworkflow.ToolInterruptEvent)
|
|
toolResultData := make(map[string]*plugin.ToolInterruptEvent)
|
|
var interruptEventType singleagent.InterruptEventType
|
|
for k, v := range toolsNodeExtra.RerunExtraMap {
|
|
toolCallID = k
|
|
|
|
interruptEventType = convInterruptEventType(v)
|
|
|
|
if interruptEventType == singleagent.InterruptEventType_OauthPlugin {
|
|
toolResultData[k] = v.(*plugin.ToolInterruptEvent)
|
|
} else {
|
|
wfResumeData[k] = v.(*crossworkflow.ToolInterruptEvent)
|
|
}
|
|
break
|
|
}
|
|
|
|
interrupt := &singleagent.InterruptInfo{
|
|
AllToolInterruptData: toolResultData,
|
|
AllWfInterruptData: wfResumeData,
|
|
ToolCallID: toolCallID,
|
|
InterruptType: interruptEventType,
|
|
}
|
|
return interrupt
|
|
}
|
|
|
|
func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
|
var interruptEventType singleagent.InterruptEventType
|
|
|
|
switch t := interruptEvent.(type) {
|
|
case *crossworkflow.ToolInterruptEvent:
|
|
interruptEventType = singleagent.InterruptEventType(int64(t.EventType))
|
|
case *plugin.ToolInterruptEvent:
|
|
if t.Event == plugin.InterruptEventTypeOfToolNeedOAuth {
|
|
interruptEventType = singleagent.InterruptEventType_OauthPlugin
|
|
}
|
|
}
|
|
return interruptEventType
|
|
}
|
|
|
|
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
|
defer output.Close()
|
|
var toolsMsgChunks [][]*schema.Message
|
|
var sr *schema.StreamReader[*schema.Message]
|
|
var sw *schema.StreamWriter[*schema.Message]
|
|
defer func() {
|
|
if sw != nil {
|
|
sw.Close()
|
|
}
|
|
}()
|
|
var streamInitialized bool
|
|
returnDirectToolsMap := make(map[int]bool)
|
|
isReturnDirectToolsFirstCheck := true
|
|
isToolsMsgChunksInit := false
|
|
|
|
for {
|
|
cbOut, err := output.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
if sw != nil {
|
|
sw.Send(nil, err)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
msgs := convToolsNodeCallbackOutput(cbOut)
|
|
|
|
if !isToolsMsgChunksInit {
|
|
isToolsMsgChunksInit = true
|
|
toolsMsgChunks = make([][]*schema.Message, len(msgs))
|
|
}
|
|
|
|
for mIndex, msg := range msgs {
|
|
|
|
if msg == nil {
|
|
continue
|
|
}
|
|
if len(r.returnDirectlyTools) > 0 {
|
|
if isReturnDirectToolsFirstCheck {
|
|
isReturnDirectToolsFirstCheck = false
|
|
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
|
|
returnDirectToolsMap[mIndex] = true
|
|
}
|
|
}
|
|
|
|
if _, ok := returnDirectToolsMap[mIndex]; ok {
|
|
if !streamInitialized {
|
|
sr, sw = schema.Pipe[*schema.Message](5)
|
|
r.sw.Send(&entity.AgentEvent{
|
|
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
|
|
ChatModelAnswer: sr,
|
|
}, nil)
|
|
streamInitialized = true
|
|
}
|
|
sw.Send(msg, nil)
|
|
}
|
|
}
|
|
if toolsMsgChunks[mIndex] == nil {
|
|
toolsMsgChunks[mIndex] = []*schema.Message{msg}
|
|
} else {
|
|
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
toolMessages := make([]*schema.Message, 0, len(toolsMsgChunks))
|
|
|
|
for _, msgChunks := range toolsMsgChunks {
|
|
msg, err := schema.ConcatMessages(msgChunks)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
toolMessages = append(toolMessages, msg)
|
|
}
|
|
|
|
return toolMessages, nil
|
|
}
|
|
|
|
func convToolsNodeCallbackInput(input callbacks.CallbackInput) *schema.Message {
|
|
switch t := input.(type) {
|
|
case *schema.Message:
|
|
return t
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func convToolsNodeCallbackOutput(output callbacks.CallbackOutput) []*schema.Message {
|
|
switch t := output.(type) {
|
|
case []*schema.Message:
|
|
return t
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func convToolsPreRetrieverCallbackInput(output callbacks.CallbackOutput) []*schema.Message {
|
|
switch t := output.(type) {
|
|
case []*schema.Message:
|
|
return t
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func convSuggestionNodeCallbackOutput(output callbacks.CallbackInput) []*schema.Message {
|
|
var sg []*schema.Message
|
|
|
|
switch so := output.(type) {
|
|
case *schema.Message:
|
|
if so.Content != "" {
|
|
var suggestions []string
|
|
|
|
err := json.Unmarshal([]byte(so.Content), &suggestions)
|
|
|
|
if err == nil && len(suggestions) > 0 {
|
|
for _, suggestion := range suggestions {
|
|
sm := &schema.Message{
|
|
Role: so.Role,
|
|
Content: suggestion,
|
|
ResponseMeta: so.ResponseMeta,
|
|
}
|
|
sg = append(sg, sm)
|
|
}
|
|
}
|
|
}
|
|
default:
|
|
return sg
|
|
}
|
|
|
|
return sg
|
|
}
|