coze-studio/backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go

295 lines
8.8 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package agentflow
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
)
// Config carries the inputs BuildAgent needs to assemble an agent runner.
type Config struct {
// Agent is the single-agent definition. BuildAgent reads its Prompt,
// Knowledge, ModelInfo, Plugin, Workflow, Database and SpaceID fields.
Agent *entity.SingleAgent
// UserID is forwarded to the agent-variable, plugin-tool and database-tool
// configurations.
UserID string
// Identity is forwarded to the plugin-tool and database-tool configurations;
// its ConnectorID is also used when loading agent variables.
Identity *entity.AgentIdentity
// ModelMgr is used to load the model info for the agent's configured model ID.
ModelMgr modelmgr.Manager
// ModelFactory is handed to newChatModel to construct the chat model.
ModelFactory chatmodel.Factory
// CPStore is installed as the graph's checkpoint store, but only when the
// built agent ends up with at least one tool.
CPStore compose.CheckPointStore
}
// Keys used by BuildAgent to name graph nodes (and, for some, output keys).
const (
// keyOfPersonRender is the lambda node that renders the persona text.
keyOfPersonRender = "persona_render"
// keyOfKnowledgeRetriever is the lambda node that retrieves knowledge documents.
keyOfKnowledgeRetriever = "knowledge_retriever"
// keyOfKnowledgeRetrieverPack is the lambda node that packs retrieved
// documents into a single string.
keyOfKnowledgeRetrieverPack = "knowledge_retriever_pack"
// keyOfPromptVariables is the lambda node that assembles the prompt variables.
keyOfPromptVariables = "prompt_variables"
// keyOfPromptTemplate is the chat-template node that all preparation nodes feed into.
keyOfPromptTemplate = "prompt_template"
// keyOfReActAgent is the agent node name used when the agent has tools.
keyOfReActAgent = "react_agent"
// keyOfReActAgentToolsNode is passed to the ReAct agent as its tools node name.
keyOfReActAgentToolsNode = "agent_tool"
// keyOfReActAgentChatModel is passed to the ReAct agent as its model node name.
keyOfReActAgentChatModel = "re_act_chat_model"
// keyOfLLM is the agent node name used when the agent has no tools.
keyOfLLM = "llm"
// keyOfToolsPreRetriever is the node key, node name and output key of the
// tool pre-retrieval lambda node.
keyOfToolsPreRetriever = "tools_pre_retriever"
)
// BuildAgent assembles and compiles the execution graph for a single agent and
// wraps it in an AgentRunner.
//
// It loads the agent variables, knowledge retriever, model info and chat model,
// collects plugin, workflow, database and agent-variable tools, and then wires
// a graph in which persona rendering, prompt-variable assembly, knowledge
// retrieval and tool pre-retrieval all start from START and feed the prompt
// template. The template feeds either a ReAct agent sub-graph (when at least
// one tool exists) or a plain chat-model node (when there are none), optionally
// followed by a suggestion sub-graph.
//
// An error is returned when any dependency fails to load, when tools are
// present but the model's declared capability lacks function calling, or when
// the graph fails to compile.
//
// NOTE(review): conf, conf.Agent, conf.Identity and conf.Agent.ModelInfo are
// dereferenced without nil checks; callers are presumably expected to supply
// them — confirm.
func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
persona := conf.Agent.Prompt.GetPrompt()
avConf := &variableConf{
Agent: conf.Agent,
UserID: conf.UserID,
ConnectorID: conf.Identity.ConnectorID,
}
avs, err := loadAgentVariables(ctx, avConf)
if err != nil {
return nil, err
}
promptVars := &promptVariables{
Agent: conf.Agent,
avs: avs,
}
// The placeholder names are extracted from the persona once, at build time.
personaVars := &personaRender{
personaVariableNames: extractJinja2Placeholder(persona),
persona: persona,
variables: avs,
}
kr, err := newKnowledgeRetriever(ctx, &retrieverConfig{
knowledgeConfig: conf.Agent.Knowledge,
})
if err != nil {
return nil, err
}
modelInfo, err := loadModelInfo(ctx, conf.ModelMgr, ptr.From(conf.Agent.ModelInfo.ModelId))
if err != nil {
return nil, err
}
chatModel, err := newChatModel(ctx, &config{
modelFactory: conf.ModelFactory,
modelInfo: modelInfo,
})
if err != nil {
return nil, err
}
// Flipped to true below as soon as the agent has at least one tool.
requireCheckpoint := false
pluginTools, err := newPluginTools(ctx, &toolConfig{
spaceID: conf.Agent.SpaceID,
userID: conf.UserID,
agentIdentity: conf.Identity,
toolConf: conf.Agent.Plugin,
})
if err != nil {
return nil, err
}
tr := newPreToolRetriever(&toolPreCallConf{})
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
wfInfos: conf.Agent.Workflow,
})
if err != nil {
return nil, err
}
// Database tools are only built when the agent has databases configured.
var dbTools []tool.InvokableTool
if len(conf.Agent.Database) > 0 {
dbTools, err = newDatabaseTools(ctx, &databaseConfig{
spaceID: conf.Agent.SpaceID,
userID: conf.UserID,
agentIdentity: conf.Identity,
databaseConf: conf.Agent.Database,
})
if err != nil {
return nil, err
}
}
// Agent-variable tools are only built when the agent has variables.
var avTools []tool.InvokableTool
if len(avs) > 0 {
avTools, err = newAgentVariableTools(ctx, avConf)
if err != nil {
return nil, err
}
}
containWfTool := false
if len(wfTools) > 0 {
containWfTool = true
}
// Flatten every tool kind into one []tool.BaseTool for the agent.
agentTools := make([]tool.BaseTool, 0, len(pluginTools)+len(wfTools)+len(dbTools)+len(avTools))
agentTools = append(agentTools, slices.Transform(pluginTools, func(a tool.InvokableTool) tool.BaseTool {
return a
})...)
// NOTE(review): unchecked type assertion — panics if a workflow tool does not
// implement tool.BaseTool; presumably every ToolFromWorkflow does. Confirm.
agentTools = append(agentTools, slices.Transform(wfTools, func(a workflow.ToolFromWorkflow) tool.BaseTool { return a.(tool.BaseTool) })...)
agentTools = append(agentTools, slices.Transform(dbTools, func(a tool.InvokableTool) tool.BaseTool {
return a
})...)
agentTools = append(agentTools, slices.Transform(avTools, func(a tool.InvokableTool) tool.BaseTool {
return a
})...)
// Any tool at all switches the agent to ReAct mode and enables checkpointing.
var isReActAgent bool
if len(agentTools) > 0 {
isReActAgent = true
requireCheckpoint = true
// A model with no declared capability (nil) is let through; only an explicit
// FunctionCall == false is rejected.
if modelInfo.Meta.Capability != nil && !modelInfo.Meta.Capability.FunctionCall {
return nil, fmt.Errorf("model %v does not support function call", modelInfo.Meta.Name)
}
}
var agentGraph compose.AnyGraph
var agentNodeOpts []compose.GraphAddNodeOpt
var agentNodeName string
if isReActAgent {
// err is intentionally scoped to this branch; it is returned explicitly below.
agent, err := react.NewAgent(ctx, &react.AgentConfig{
ToolCallingModel: chatModel,
ToolsConfig: compose.ToolsNodeConfig{
Tools: agentTools,
},
ToolReturnDirectly: toolsReturnDirectly,
ModelNodeName: keyOfReActAgentChatModel,
ToolsNodeName: keyOfReActAgentToolsNode,
})
if err != nil {
return nil, err
}
agentGraph, agentNodeOpts = agent.ExportGraph()
agentNodeName = keyOfReActAgent
} else {
agentNodeName = keyOfLLM
}
// nsg is used below as the switch for wiring in the suggestion sub-graph
// (presumably "need suggest graph" — confirm against newSuggestGraph).
suggestGraph, nsg := newSuggestGraph(ctx, conf, chatModel)
g := compose.NewGraph[*AgentRequest, *schema.Message](
compose.WithGenLocalState(func(ctx context.Context) (state *AgentState) {
return &AgentState{}
}))
// NOTE(review): the errors returned by the Add* calls below are discarded;
// presumably the graph records them and reports them from Compile — confirm
// against the eino compose documentation.
_ = g.AddLambdaNode(keyOfPersonRender,
compose.InvokableLambda[*AgentRequest, string](personaVars.RenderPersona),
// Stash the user input in graph state so the suggestion stage can read it later.
compose.WithStatePreHandler(func(ctx context.Context, ar *AgentRequest, state *AgentState) (*AgentRequest, error) {
state.UserInput = ar.Input
return ar, nil
}),
compose.WithOutputKey(placeholderOfPersona))
_ = g.AddLambdaNode(keyOfPromptVariables,
compose.InvokableLambda[*AgentRequest, map[string]any](promptVars.AssemblePromptVariables))
_ = g.AddLambdaNode(keyOfKnowledgeRetriever,
compose.InvokableLambda[*AgentRequest, []*schema.Document](kr.Retrieve),
compose.WithNodeName(keyOfKnowledgeRetriever))
_ = g.AddLambdaNode(keyOfToolsPreRetriever,
compose.InvokableLambda[*AgentRequest, []*schema.Message](tr.toolPreRetrieve),
compose.WithOutputKey(keyOfToolsPreRetriever),
compose.WithNodeName(keyOfToolsPreRetriever),
)
_ = g.AddLambdaNode(keyOfKnowledgeRetrieverPack,
compose.InvokableLambda[[]*schema.Document, string](kr.PackRetrieveResultInfo),
compose.WithOutputKey(placeholderOfKnowledge),
)
_ = g.AddChatTemplateNode(keyOfPromptTemplate, chatPrompt)
agentNodeOpts = append(agentNodeOpts, compose.WithNodeName(agentNodeName))
// The agent node is either the exported ReAct sub-graph or the bare chat model.
if isReActAgent {
_ = g.AddGraphNode(agentNodeName, agentGraph, agentNodeOpts...)
} else {
_ = g.AddChatModelNode(agentNodeName, chatModel, agentNodeOpts...)
}
if nsg {
// Turn the agent's single output message into a list and append the saved
// user input before handing it to the suggestion sub-graph.
_ = g.AddLambdaNode(keyOfSuggestPreInputParse, compose.ToList[*schema.Message](),
compose.WithStatePostHandler(func(ctx context.Context, out []*schema.Message, state *AgentState) ([]*schema.Message, error) {
out = append(out, state.UserInput)
return out, nil
}),
)
_ = g.AddGraphNode(keyOfSuggestGraph, suggestGraph)
}
// Fan out from START to the four preparation nodes.
_ = g.AddEdge(compose.START, keyOfPersonRender)
_ = g.AddEdge(compose.START, keyOfPromptVariables)
_ = g.AddEdge(compose.START, keyOfKnowledgeRetriever)
_ = g.AddEdge(compose.START, keyOfToolsPreRetriever)
// Fan in to the prompt template; knowledge goes through the pack node first.
_ = g.AddEdge(keyOfPersonRender, keyOfPromptTemplate)
_ = g.AddEdge(keyOfPromptVariables, keyOfPromptTemplate)
_ = g.AddEdge(keyOfKnowledgeRetriever, keyOfKnowledgeRetrieverPack)
_ = g.AddEdge(keyOfKnowledgeRetrieverPack, keyOfPromptTemplate)
_ = g.AddEdge(keyOfToolsPreRetriever, keyOfPromptTemplate)
_ = g.AddEdge(keyOfPromptTemplate, agentNodeName)
// The agent output either ends the graph or passes through the suggestion stage.
if nsg {
_ = g.AddEdge(agentNodeName, keyOfSuggestPreInputParse)
_ = g.AddEdge(keyOfSuggestPreInputParse, keyOfSuggestGraph)
_ = g.AddEdge(keyOfSuggestGraph, compose.END)
} else {
_ = g.AddEdge(agentNodeName, compose.END)
}
var opts []compose.GraphCompileOption
// The checkpoint store is only attached for agents that have tools.
if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(conf.CPStore))
}
// Going by the option name, a node runs once all of its predecessors are done,
// which the fan-in at the prompt template relies on — confirm against eino docs.
opts = append(opts, compose.WithNodeTriggerMode(compose.AllPredecessor))
runner, err := g.Compile(ctx, opts...)
if err != nil {
return nil, err
}
return &AgentRunner{
runner: runner,
requireCheckpoint: requireCheckpoint,
modelInfo: modelInfo,
containWfTool: containWfTool,
}, nil
}
// jinja2PlaceholderRe matches a `{{ ... }}` placeholder and captures the text
// between the braces. It is compiled once at package initialization rather
// than on every call.
var jinja2PlaceholderRe = regexp.MustCompile(`{{([^}]*)}}`)

// extractJinja2Placeholder returns the names found inside the `{{ ... }}`
// placeholders of persona, in order of appearance.
//
// Whitespace around each name is stripped, so `{{ name }}` and `{{name}}` both
// yield "name". Placeholders that are empty or contain only whitespace are
// skipped. Duplicate names are kept. The returned slice is never nil.
func extractJinja2Placeholder(persona string) (variableNames []string) {
	matches := jinja2PlaceholderRe.FindAllStringSubmatch(persona, -1)
	variableNames = make([]string, 0, len(matches))
	for _, match := range matches {
		// match[1] is the capture group: the raw text between the braces.
		// Store the trimmed form so the result can be used directly as a
		// variable name.
		if name := strings.TrimSpace(match[1]); name != "" {
			variableNames = append(variableNames, name)
		}
	}
	return variableNames
}