feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
@@ -0,0 +1,293 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package agentflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Agent *entity.SingleAgent
|
||||
UserID string
|
||||
Identity *entity.AgentIdentity
|
||||
ModelFactory chatmodel.Factory
|
||||
CPStore compose.CheckPointStore
|
||||
}
|
||||
|
||||
const (
|
||||
keyOfPersonRender = "persona_render"
|
||||
keyOfKnowledgeRetriever = "knowledge_retriever"
|
||||
keyOfKnowledgeRetrieverPack = "knowledge_retriever_pack"
|
||||
keyOfPromptVariables = "prompt_variables"
|
||||
keyOfPromptTemplate = "prompt_template"
|
||||
keyOfReActAgent = "react_agent"
|
||||
keyOfReActAgentToolsNode = "agent_tool"
|
||||
keyOfReActAgentChatModel = "re_act_chat_model"
|
||||
keyOfLLM = "llm"
|
||||
keyOfToolsPreRetriever = "tools_pre_retriever"
|
||||
)
|
||||
|
||||
func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
|
||||
persona := conf.Agent.Prompt.GetPrompt()
|
||||
|
||||
avConf := &variableConf{
|
||||
Agent: conf.Agent,
|
||||
UserID: conf.UserID,
|
||||
ConnectorID: conf.Identity.ConnectorID,
|
||||
}
|
||||
avs, err := loadAgentVariables(ctx, avConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
promptVars := &promptVariables{
|
||||
Agent: conf.Agent,
|
||||
avs: avs,
|
||||
}
|
||||
|
||||
personaVars := &personaRender{
|
||||
personaVariableNames: extractJinja2Placeholder(persona),
|
||||
persona: persona,
|
||||
variables: avs,
|
||||
}
|
||||
|
||||
kr, err := newKnowledgeRetriever(ctx, &retrieverConfig{
|
||||
knowledgeConfig: conf.Agent.Knowledge,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelInfo, err := loadModelInfo(ctx, ptr.From(conf.Agent.ModelInfo.ModelId))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chatModel, err := newChatModel(ctx, &config{
|
||||
modelFactory: conf.ModelFactory,
|
||||
modelInfo: modelInfo,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requireCheckpoint := false
|
||||
pluginTools, err := newPluginTools(ctx, &toolConfig{
|
||||
spaceID: conf.Agent.SpaceID,
|
||||
userID: conf.UserID,
|
||||
agentIdentity: conf.Identity,
|
||||
toolConf: conf.Agent.Plugin,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr := newPreToolRetriever(&toolPreCallConf{})
|
||||
|
||||
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
|
||||
wfInfos: conf.Agent.Workflow,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dbTools []tool.InvokableTool
|
||||
if len(conf.Agent.Database) > 0 {
|
||||
dbTools, err = newDatabaseTools(ctx, &databaseConfig{
|
||||
spaceID: conf.Agent.SpaceID,
|
||||
userID: conf.UserID,
|
||||
agentIdentity: conf.Identity,
|
||||
databaseConf: conf.Agent.Database,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var avTools []tool.InvokableTool
|
||||
if len(avs) > 0 {
|
||||
avTools, err = newAgentVariableTools(ctx, avConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
containWfTool := false
|
||||
|
||||
if len(wfTools) > 0 {
|
||||
containWfTool = true
|
||||
}
|
||||
agentTools := make([]tool.BaseTool, 0, len(pluginTools)+len(wfTools)+len(dbTools)+len(avTools))
|
||||
agentTools = append(agentTools, slices.Transform(pluginTools, func(a tool.InvokableTool) tool.BaseTool {
|
||||
return a
|
||||
})...)
|
||||
agentTools = append(agentTools, slices.Transform(wfTools, func(a workflow.ToolFromWorkflow) tool.BaseTool { return a.(tool.BaseTool) })...)
|
||||
agentTools = append(agentTools, slices.Transform(dbTools, func(a tool.InvokableTool) tool.BaseTool {
|
||||
return a
|
||||
})...)
|
||||
|
||||
agentTools = append(agentTools, slices.Transform(avTools, func(a tool.InvokableTool) tool.BaseTool {
|
||||
return a
|
||||
})...)
|
||||
|
||||
var isReActAgent bool
|
||||
if len(agentTools) > 0 {
|
||||
isReActAgent = true
|
||||
requireCheckpoint = true
|
||||
if modelInfo.Meta.Capability != nil && !modelInfo.Meta.Capability.FunctionCall {
|
||||
return nil, fmt.Errorf("model %v does not support function call", modelInfo.Meta.Name)
|
||||
}
|
||||
}
|
||||
|
||||
var agentGraph compose.AnyGraph
|
||||
var agentNodeOpts []compose.GraphAddNodeOpt
|
||||
var agentNodeName string
|
||||
if isReActAgent {
|
||||
agent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||
ToolCallingModel: chatModel,
|
||||
ToolsConfig: compose.ToolsNodeConfig{
|
||||
Tools: agentTools,
|
||||
},
|
||||
ToolReturnDirectly: toolsReturnDirectly,
|
||||
ModelNodeName: keyOfReActAgentChatModel,
|
||||
ToolsNodeName: keyOfReActAgentToolsNode,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agentGraph, agentNodeOpts = agent.ExportGraph()
|
||||
|
||||
agentNodeName = keyOfReActAgent
|
||||
} else {
|
||||
agentNodeName = keyOfLLM
|
||||
}
|
||||
|
||||
suggestGraph, nsg := newSuggestGraph(ctx, conf, chatModel)
|
||||
|
||||
g := compose.NewGraph[*AgentRequest, *schema.Message](
|
||||
compose.WithGenLocalState(func(ctx context.Context) (state *AgentState) {
|
||||
return &AgentState{}
|
||||
}))
|
||||
|
||||
_ = g.AddLambdaNode(keyOfPersonRender,
|
||||
compose.InvokableLambda[*AgentRequest, string](personaVars.RenderPersona),
|
||||
compose.WithStatePreHandler(func(ctx context.Context, ar *AgentRequest, state *AgentState) (*AgentRequest, error) {
|
||||
state.UserInput = ar.Input
|
||||
return ar, nil
|
||||
}),
|
||||
compose.WithOutputKey(placeholderOfPersona))
|
||||
|
||||
_ = g.AddLambdaNode(keyOfPromptVariables,
|
||||
compose.InvokableLambda[*AgentRequest, map[string]any](promptVars.AssemblePromptVariables))
|
||||
|
||||
_ = g.AddLambdaNode(keyOfKnowledgeRetriever,
|
||||
compose.InvokableLambda[*AgentRequest, []*schema.Document](kr.Retrieve),
|
||||
compose.WithNodeName(keyOfKnowledgeRetriever))
|
||||
|
||||
_ = g.AddLambdaNode(keyOfToolsPreRetriever,
|
||||
compose.InvokableLambda[*AgentRequest, []*schema.Message](tr.toolPreRetrieve),
|
||||
compose.WithOutputKey(keyOfToolsPreRetriever),
|
||||
compose.WithNodeName(keyOfToolsPreRetriever),
|
||||
)
|
||||
_ = g.AddLambdaNode(keyOfKnowledgeRetrieverPack,
|
||||
compose.InvokableLambda[[]*schema.Document, string](kr.PackRetrieveResultInfo),
|
||||
compose.WithOutputKey(placeholderOfKnowledge),
|
||||
)
|
||||
_ = g.AddChatTemplateNode(keyOfPromptTemplate, chatPrompt)
|
||||
|
||||
agentNodeOpts = append(agentNodeOpts, compose.WithNodeName(agentNodeName))
|
||||
|
||||
if isReActAgent {
|
||||
_ = g.AddGraphNode(agentNodeName, agentGraph, agentNodeOpts...)
|
||||
} else {
|
||||
_ = g.AddChatModelNode(agentNodeName, chatModel, agentNodeOpts...)
|
||||
}
|
||||
|
||||
if nsg {
|
||||
_ = g.AddLambdaNode(keyOfSuggestPreInputParse, compose.ToList[*schema.Message](),
|
||||
compose.WithStatePostHandler(func(ctx context.Context, out []*schema.Message, state *AgentState) ([]*schema.Message, error) {
|
||||
out = append(out, state.UserInput)
|
||||
return out, nil
|
||||
}),
|
||||
)
|
||||
_ = g.AddGraphNode(keyOfSuggestGraph, suggestGraph)
|
||||
}
|
||||
|
||||
_ = g.AddEdge(compose.START, keyOfPersonRender)
|
||||
_ = g.AddEdge(compose.START, keyOfPromptVariables)
|
||||
_ = g.AddEdge(compose.START, keyOfKnowledgeRetriever)
|
||||
_ = g.AddEdge(compose.START, keyOfToolsPreRetriever)
|
||||
|
||||
_ = g.AddEdge(keyOfPersonRender, keyOfPromptTemplate)
|
||||
_ = g.AddEdge(keyOfPromptVariables, keyOfPromptTemplate)
|
||||
_ = g.AddEdge(keyOfKnowledgeRetriever, keyOfKnowledgeRetrieverPack)
|
||||
_ = g.AddEdge(keyOfKnowledgeRetrieverPack, keyOfPromptTemplate)
|
||||
_ = g.AddEdge(keyOfToolsPreRetriever, keyOfPromptTemplate)
|
||||
|
||||
_ = g.AddEdge(keyOfPromptTemplate, agentNodeName)
|
||||
|
||||
if nsg {
|
||||
_ = g.AddEdge(agentNodeName, keyOfSuggestPreInputParse)
|
||||
_ = g.AddEdge(keyOfSuggestPreInputParse, keyOfSuggestGraph)
|
||||
_ = g.AddEdge(keyOfSuggestGraph, compose.END)
|
||||
} else {
|
||||
_ = g.AddEdge(agentNodeName, compose.END)
|
||||
}
|
||||
|
||||
var opts []compose.GraphCompileOption
|
||||
if requireCheckpoint {
|
||||
opts = append(opts, compose.WithCheckPointStore(conf.CPStore))
|
||||
}
|
||||
opts = append(opts, compose.WithNodeTriggerMode(compose.AllPredecessor))
|
||||
runner, err := g.Compile(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AgentRunner{
|
||||
runner: runner,
|
||||
requireCheckpoint: requireCheckpoint,
|
||||
modelInfo: modelInfo,
|
||||
containWfTool: containWfTool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func extractJinja2Placeholder(persona string) (variableNames []string) {
|
||||
re := regexp.MustCompile(`{{([^}]*)}}`)
|
||||
matches := re.FindAllStringSubmatch(persona, -1)
|
||||
variables := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
val := strings.TrimSpace(match[1])
|
||||
if val != "" {
|
||||
variables = append(variables, match[1])
|
||||
}
|
||||
}
|
||||
return variables
|
||||
}
|
||||
Reference in New Issue
Block a user