feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/agentrun"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/singleagent"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossmodelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/crossdomain/contract/crossworkflow"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
)
|
||||
|
||||
type AgentState struct {
|
||||
Messages []*schema.Message
|
||||
UserInput *schema.Message
|
||||
ReturnDirectlyToolCallID string
|
||||
}
|
||||
|
||||
type AgentRequest struct {
|
||||
UserID string
|
||||
Input *schema.Message
|
||||
History []*schema.Message
|
||||
|
||||
Identity *singleagent.AgentIdentity
|
||||
|
||||
ResumeInfo *singleagent.InterruptInfo
|
||||
PreCallTools []*agentrun.ToolsRetriever
|
||||
Variables map[string]string
|
||||
}
|
||||
|
||||
type AgentRunner struct {
|
||||
runner compose.Runnable[*AgentRequest, *schema.Message]
|
||||
requireCheckpoint bool
|
||||
|
||||
containWfTool bool
|
||||
modelInfo *crossmodelmgr.Model
|
||||
}
|
||||
|
||||
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
|
||||
sr *schema.StreamReader[*entity.AgentEvent], err error,
|
||||
) {
|
||||
executeID := uuid.New()
|
||||
|
||||
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if pe := recover(); pe != nil {
|
||||
logs.CtxErrorf(ctx, "[AgentRunner] StreamExecute recover, err: %v", pe)
|
||||
|
||||
sw.Send(nil, errors.New("internal server error"))
|
||||
}
|
||||
sw.Close()
|
||||
}()
|
||||
|
||||
var composeOpts []compose.Option
|
||||
composeOpts = append(composeOpts, compose.WithCallbacks(hdl))
|
||||
_ = compose.RegisterSerializableType[*AgentState]("agent_state")
|
||||
if r.requireCheckpoint {
|
||||
|
||||
defaultCheckPointID := executeID.String()
|
||||
if req.ResumeInfo != nil {
|
||||
resumeInfo := req.ResumeInfo
|
||||
if resumeInfo.InterruptType != singleagent.InterruptEventType_OauthPlugin {
|
||||
defaultCheckPointID = resumeInfo.InterruptID
|
||||
opts := crossworkflow.DefaultSVC().WithResumeToolWorkflow(resumeInfo.AllWfInterruptData[resumeInfo.ToolCallID], req.Input.Content, resumeInfo.AllWfInterruptData)
|
||||
composeOpts = append(composeOpts, opts)
|
||||
}
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(defaultCheckPointID))
|
||||
}
|
||||
if r.containWfTool {
|
||||
cfReq := crossworkflow.ExecuteConfig{
|
||||
AgentID: &req.Identity.AgentID,
|
||||
ConnectorUID: req.UserID,
|
||||
ConnectorID: req.Identity.ConnectorID,
|
||||
BizType: crossworkflow.BizTypeAgent,
|
||||
}
|
||||
if req.Identity.IsDraft {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeDebug
|
||||
} else {
|
||||
cfReq.Mode = crossworkflow.ExecuteModeRelease
|
||||
}
|
||||
wfConfig := crossworkflow.DefaultSVC().WithExecuteConfig(cfReq)
|
||||
composeOpts = append(composeOpts, wfConfig)
|
||||
}
|
||||
_, _ = r.runner.Stream(ctx, req, composeOpts...)
|
||||
}()
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (r *AgentRunner) PreHandlerReq(ctx context.Context, req *AgentRequest) *AgentRequest {
|
||||
req.Input = r.preHandlerInput(req.Input)
|
||||
req.History = r.preHandlerHistory(req.History)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (r *AgentRunner) preHandlerInput(input *schema.Message) *schema.Message {
|
||||
var multiContent []schema.ChatMessagePart
|
||||
for _, v := range input.MultiContent {
|
||||
switch v.Type {
|
||||
case schema.ChatMessagePartTypeImageURL:
|
||||
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalImage) {
|
||||
input.Content = input.Content + ": " + v.ImageURL.URL
|
||||
} else {
|
||||
multiContent = append(multiContent, v)
|
||||
}
|
||||
case schema.ChatMessagePartTypeFileURL:
|
||||
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalFile) {
|
||||
input.Content = input.Content + ": " + v.FileURL.URL
|
||||
} else {
|
||||
multiContent = append(multiContent, v)
|
||||
}
|
||||
case schema.ChatMessagePartTypeAudioURL:
|
||||
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalAudio) {
|
||||
input.Content = input.Content + ": " + v.FileURL.URL
|
||||
} else {
|
||||
multiContent = append(multiContent, v)
|
||||
}
|
||||
case schema.ChatMessagePartTypeVideoURL:
|
||||
if !slices.Contains(r.modelInfo.Meta.Capability.InputModal, modelmgr.ModalVideo) {
|
||||
input.Content = input.Content + ": " + v.FileURL.URL
|
||||
} else {
|
||||
multiContent = append(multiContent, v)
|
||||
}
|
||||
case schema.ChatMessagePartTypeText:
|
||||
break
|
||||
|
||||
default:
|
||||
multiContent = append(multiContent, v)
|
||||
}
|
||||
}
|
||||
input.MultiContent = multiContent
|
||||
return input
|
||||
}
|
||||
|
||||
func (r *AgentRunner) preHandlerHistory(history []*schema.Message) []*schema.Message {
|
||||
var hm []*schema.Message
|
||||
for _, msg := range history {
|
||||
if msg.Role == schema.User {
|
||||
msg = r.preHandlerInput(msg)
|
||||
}
|
||||
hm = append(hm, msg)
|
||||
}
|
||||
return hm
|
||||
}
|
||||
Reference in New Issue
Block a user