feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,374 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/plugin"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/errorx"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
|
||||
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
|
||||
) {
|
||||
sr, sw = schema.Pipe[*entity.AgentEvent](10)
|
||||
|
||||
rcc := &replyChunkCallback{
|
||||
sw: sw,
|
||||
executeID: executeID,
|
||||
}
|
||||
|
||||
clb = callbacks.NewHandlerBuilder().
|
||||
OnStartFn(rcc.OnStart).
|
||||
OnEndFn(rcc.OnEnd).
|
||||
OnEndWithStreamOutputFn(rcc.OnEndWithStreamOutput).
|
||||
OnErrorFn(rcc.OnError).
|
||||
Build()
|
||||
|
||||
return clb, sr, sw
|
||||
}
|
||||
|
||||
type replyChunkCallback struct {
|
||||
sw *schema.StreamWriter[*entity.AgentEvent]
|
||||
executeID string
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
logs.CtxInfof(ctx, "info-OnError, info=%v, err=%v", conv.DebugJsonToStr(info), err)
|
||||
|
||||
switch info.Component {
|
||||
case compose.ComponentOfGraph:
|
||||
if interruptInfo, ok := compose.ExtractInterruptInfo(err); ok {
|
||||
if info.Name != "" {
|
||||
return ctx
|
||||
}
|
||||
interruptData := convInterruptInfo(ctx, interruptInfo)
|
||||
interruptData.InterruptID = r.executeID
|
||||
|
||||
toolMessageEvent := &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsMessage,
|
||||
ToolsMessage: []*schema.Message{
|
||||
{
|
||||
Role: schema.Tool,
|
||||
Content: "directly streaming reply",
|
||||
ToolCallID: interruptData.ToolCallID,
|
||||
},
|
||||
},
|
||||
}
|
||||
r.sw.Send(toolMessageEvent, nil)
|
||||
|
||||
interruptEvent := &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfInterrupt,
|
||||
Interrupt: interruptData,
|
||||
}
|
||||
r.sw.Send(interruptEvent, nil)
|
||||
|
||||
} else {
|
||||
logs.CtxErrorf(ctx, "node execute failed, component=%v, name=%v, err=%w",
|
||||
info.Component, info.Name, err)
|
||||
var customErr errorx.StatusError
|
||||
errMsg := "Internal server error"
|
||||
if errors.As(err, &customErr) && customErr.Code() != 0 {
|
||||
errMsg = customErr.Msg()
|
||||
}
|
||||
r.sw.Send(nil, errors.New(errMsg))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
logs.CtxInfof(ctx, "info-OnStart, info=%v, input=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(input))
|
||||
|
||||
switch info.Component {
|
||||
case compose.ComponentOfToolsNode:
|
||||
if info.Name != keyOfReActAgentToolsNode {
|
||||
return ctx
|
||||
}
|
||||
ae := &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfFuncCall,
|
||||
FuncCall: convToolsNodeCallbackInput(input),
|
||||
}
|
||||
r.sw.Send(ae, nil)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
logs.CtxInfof(ctx, "info-OnEnd, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
|
||||
switch info.Name {
|
||||
case keyOfKnowledgeRetriever:
|
||||
knowledgeEvent := &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfKnowledge,
|
||||
Knowledge: retriever.ConvCallbackOutput(output).Docs,
|
||||
}
|
||||
|
||||
if knowledgeEvent.Knowledge != nil {
|
||||
r.sw.Send(knowledgeEvent, nil)
|
||||
}
|
||||
case keyOfToolsPreRetriever:
|
||||
result := convToolsPreRetrieverCallbackInput(output)
|
||||
|
||||
if len(result) > 0 {
|
||||
for _, item := range result {
|
||||
var event *entity.AgentEvent
|
||||
if item.Role == schema.Tool {
|
||||
event = &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsMessage,
|
||||
ToolsMessage: []*schema.Message{item},
|
||||
}
|
||||
} else {
|
||||
event = &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfFuncCall,
|
||||
FuncCall: item,
|
||||
}
|
||||
}
|
||||
r.sw.Send(event, nil)
|
||||
}
|
||||
}
|
||||
|
||||
case keyOfSuggestParser:
|
||||
sg := convSuggestionNodeCallbackOutput(output)
|
||||
|
||||
if len(sg) > 0 {
|
||||
for _, item := range sg {
|
||||
suggestionEvent := &entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfSuggest,
|
||||
Suggest: item,
|
||||
}
|
||||
r.sw.Send(suggestionEvent, nil)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return ctx
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
|
||||
output *schema.StreamReader[callbacks.CallbackOutput],
|
||||
) context.Context {
|
||||
logs.CtxInfof(ctx, "info-OnEndWithStreamOutput, info=%v, output=%v", conv.DebugJsonToStr(info), conv.DebugJsonToStr(output))
|
||||
switch info.Component {
|
||||
case compose.ComponentOfGraph, components.ComponentOfChatModel:
|
||||
if info.Name != keyOfReActAgentChatModel && info.Name != keyOfLLM {
|
||||
output.Close()
|
||||
return ctx
|
||||
}
|
||||
sr := schema.StreamReaderWithConvert(output, func(t callbacks.CallbackOutput) (*schema.Message, error) {
|
||||
cbOut := model.ConvCallbackOutput(t)
|
||||
return cbOut.Message, nil
|
||||
})
|
||||
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfChatModelAnswer,
|
||||
ChatModelAnswer: sr,
|
||||
}, nil)
|
||||
return ctx
|
||||
case compose.ComponentOfToolsNode:
|
||||
toolsMessage, err := concatToolsNodeOutput(ctx, output)
|
||||
if err != nil {
|
||||
r.sw.Send(nil, err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
r.sw.Send(&entity.AgentEvent{
|
||||
EventType: singleagent.EventTypeOfToolsMessage,
|
||||
ToolsMessage: toolsMessage,
|
||||
}, nil)
|
||||
return ctx
|
||||
default:
|
||||
return ctx
|
||||
}
|
||||
}
|
||||
|
||||
func convInterruptInfo(ctx context.Context, interruptInfo *compose.InterruptInfo) *singleagent.InterruptInfo {
|
||||
var output *compose.InterruptInfo
|
||||
output = interruptInfo.SubGraphs[keyOfReActAgent]
|
||||
var extra any
|
||||
|
||||
for i := range output.RerunNodesExtra {
|
||||
extra = output.RerunNodesExtra[i]
|
||||
break
|
||||
}
|
||||
toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
|
||||
logs.CtxInfof(ctx, "toolsNodeExtra=%v, err=%v", toolsNodeExtra, ok)
|
||||
|
||||
var toolCallID string
|
||||
|
||||
wfResumeData := make(map[string]*crossworkflow.ToolInterruptEvent)
|
||||
toolResultData := make(map[string]*plugin.ToolInterruptEvent)
|
||||
var interruptEventType singleagent.InterruptEventType
|
||||
for k, v := range toolsNodeExtra.RerunExtraMap {
|
||||
toolCallID = k
|
||||
|
||||
interruptEventType = convInterruptEventType(v)
|
||||
|
||||
if interruptEventType == singleagent.InterruptEventType_OauthPlugin {
|
||||
toolResultData[k] = v.(*plugin.ToolInterruptEvent)
|
||||
} else {
|
||||
wfResumeData[k] = v.(*crossworkflow.ToolInterruptEvent)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
interrupt := &singleagent.InterruptInfo{
|
||||
AllToolInterruptData: toolResultData,
|
||||
AllWfInterruptData: wfResumeData,
|
||||
ToolCallID: toolCallID,
|
||||
InterruptType: interruptEventType,
|
||||
}
|
||||
return interrupt
|
||||
}
|
||||
|
||||
func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
|
||||
var interruptEventType singleagent.InterruptEventType
|
||||
|
||||
switch t := interruptEvent.(type) {
|
||||
case *crossworkflow.ToolInterruptEvent:
|
||||
interruptEventType = singleagent.InterruptEventType(int64(t.EventType))
|
||||
case *plugin.ToolInterruptEvent:
|
||||
if t.Event == plugin.InterruptEventTypeOfToolNeedOAuth {
|
||||
interruptEventType = singleagent.InterruptEventType_OauthPlugin
|
||||
}
|
||||
}
|
||||
return interruptEventType
|
||||
}
|
||||
|
||||
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
|
||||
defer output.Close()
|
||||
toolsMsgChunks := make([][]*schema.Message, 0, 5)
|
||||
for {
|
||||
cbOut, err := output.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs := convToolsNodeCallbackOutput(cbOut)
|
||||
|
||||
for _, msg := range msgs {
|
||||
if msg == nil || msg.ToolCallID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
findSameMsg := false
|
||||
for i, msgChunks := range toolsMsgChunks {
|
||||
if msg.ToolCallID == msgChunks[0].ToolCallID {
|
||||
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
|
||||
findSameMsg = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !findSameMsg {
|
||||
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolMessages := make([]*schema.Message, 0, len(toolsMsgChunks))
|
||||
|
||||
for _, msgChunks := range toolsMsgChunks {
|
||||
msg, err := schema.ConcatMessages(msgChunks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolMessages = append(toolMessages, msg)
|
||||
}
|
||||
|
||||
return toolMessages, nil
|
||||
}
|
||||
|
||||
func convToolsNodeCallbackInput(input callbacks.CallbackInput) *schema.Message {
|
||||
switch t := input.(type) {
|
||||
case *schema.Message:
|
||||
return t
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func convToolsNodeCallbackOutput(output callbacks.CallbackOutput) []*schema.Message {
|
||||
switch t := output.(type) {
|
||||
case []*schema.Message:
|
||||
return t
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func convToolsPreRetrieverCallbackInput(output callbacks.CallbackOutput) []*schema.Message {
|
||||
switch t := output.(type) {
|
||||
case []*schema.Message:
|
||||
return t
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func convSuggestionNodeCallbackOutput(output callbacks.CallbackInput) []*schema.Message {
|
||||
var sg []*schema.Message
|
||||
|
||||
switch so := output.(type) {
|
||||
case *schema.Message:
|
||||
if so.Content != "" {
|
||||
var suggestions []string
|
||||
|
||||
err := json.Unmarshal([]byte(so.Content), &suggestions)
|
||||
|
||||
if err == nil && len(suggestions) > 0 {
|
||||
for _, suggestion := range suggestions {
|
||||
sm := &schema.Message{
|
||||
Role: so.Role,
|
||||
Content: suggestion,
|
||||
ResponseMeta: so.ResponseMeta,
|
||||
}
|
||||
sg = append(sg, sm)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return sg
|
||||
}
|
||||
|
||||
return sg
|
||||
}
|
||||
Reference in New Issue
Block a user