295 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			295 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			Go
		
	
	
	
| /*
 | |
|  * Copyright 2025 coze-dev Authors
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| package agentflow
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/cloudwego/eino/components/tool"
 | |
| 	"github.com/cloudwego/eino/compose"
 | |
| 	"github.com/cloudwego/eino/flow/agent/react"
 | |
| 	"github.com/cloudwego/eino/schema"
 | |
| 
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/agent/singleagent/entity"
 | |
| 	"github.com/coze-dev/coze-studio/backend/domain/workflow"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
 | |
| 	"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
 | |
| 	"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
 | |
| )
 | |
| 
 | |
| type Config struct {
 | |
| 	Agent        *entity.SingleAgent
 | |
| 	UserID       string
 | |
| 	Identity     *entity.AgentIdentity
 | |
| 	ModelMgr     modelmgr.Manager
 | |
| 	ModelFactory chatmodel.Factory
 | |
| 	CPStore      compose.CheckPointStore
 | |
| }
 | |
| 
 | |
// Node keys used when wiring the agent graph in BuildAgent.
const (
	// Preparation nodes fed directly from START.

	// keyOfPersonRender renders the persona text with agent variables.
	keyOfPersonRender = "persona_render"
	// keyOfPromptVariables assembles the prompt variable map.
	keyOfPromptVariables = "prompt_variables"
	// keyOfKnowledgeRetriever fetches knowledge documents for the request.
	keyOfKnowledgeRetriever = "knowledge_retriever"
	// keyOfToolsPreRetriever runs tool pre-retrieval; also used as its output key.
	keyOfToolsPreRetriever = "tools_pre_retriever"

	// Nodes that join or post-process the preparation results.

	// keyOfKnowledgeRetrieverPack packs retrieved documents into prompt text.
	keyOfKnowledgeRetrieverPack = "knowledge_retriever_pack"
	// keyOfPromptTemplate is the chat prompt template join point.
	keyOfPromptTemplate = "prompt_template"
)

const (
	// Agent execution nodes; exactly one of the two is added to the graph.

	// keyOfReActAgent names the embedded ReAct agent sub-graph.
	keyOfReActAgent = "react_agent"
	// keyOfLLM names the plain chat-model node used when there are no tools.
	keyOfLLM = "llm"

	// Node names inside the ReAct agent sub-graph.

	// keyOfReActAgentToolsNode is passed as the ReAct ToolsNodeName.
	keyOfReActAgentToolsNode = "agent_tool"
	// keyOfReActAgentChatModel is passed as the ReAct ModelNodeName.
	keyOfReActAgentChatModel = "re_act_chat_model"
)
 | |
| 
 | |
| func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
 | |
| 	persona := conf.Agent.Prompt.GetPrompt()
 | |
| 
 | |
| 	avConf := &variableConf{
 | |
| 		Agent:       conf.Agent,
 | |
| 		UserID:      conf.UserID,
 | |
| 		ConnectorID: conf.Identity.ConnectorID,
 | |
| 	}
 | |
| 	avs, err := loadAgentVariables(ctx, avConf)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	promptVars := &promptVariables{
 | |
| 		Agent: conf.Agent,
 | |
| 		avs:   avs,
 | |
| 	}
 | |
| 
 | |
| 	personaVars := &personaRender{
 | |
| 		personaVariableNames: extractJinja2Placeholder(persona),
 | |
| 		persona:              persona,
 | |
| 		variables:            avs,
 | |
| 	}
 | |
| 
 | |
| 	kr, err := newKnowledgeRetriever(ctx, &retrieverConfig{
 | |
| 		knowledgeConfig: conf.Agent.Knowledge,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	modelInfo, err := loadModelInfo(ctx, conf.ModelMgr, ptr.From(conf.Agent.ModelInfo.ModelId))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	chatModel, err := newChatModel(ctx, &config{
 | |
| 		modelFactory: conf.ModelFactory,
 | |
| 		modelInfo:    modelInfo,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	requireCheckpoint := false
 | |
| 	pluginTools, err := newPluginTools(ctx, &toolConfig{
 | |
| 		spaceID:       conf.Agent.SpaceID,
 | |
| 		userID:        conf.UserID,
 | |
| 		agentIdentity: conf.Identity,
 | |
| 		toolConf:      conf.Agent.Plugin,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	tr := newPreToolRetriever(&toolPreCallConf{})
 | |
| 
 | |
| 	wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
 | |
| 		wfInfos: conf.Agent.Workflow,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	var dbTools []tool.InvokableTool
 | |
| 	if len(conf.Agent.Database) > 0 {
 | |
| 		dbTools, err = newDatabaseTools(ctx, &databaseConfig{
 | |
| 			spaceID:       conf.Agent.SpaceID,
 | |
| 			userID:        conf.UserID,
 | |
| 			agentIdentity: conf.Identity,
 | |
| 			databaseConf:  conf.Agent.Database,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var avTools []tool.InvokableTool
 | |
| 	if len(avs) > 0 {
 | |
| 		avTools, err = newAgentVariableTools(ctx, avConf)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 	containWfTool := false
 | |
| 
 | |
| 	if len(wfTools) > 0 {
 | |
| 		containWfTool = true
 | |
| 	}
 | |
| 	agentTools := make([]tool.BaseTool, 0, len(pluginTools)+len(wfTools)+len(dbTools)+len(avTools))
 | |
| 	agentTools = append(agentTools, slices.Transform(pluginTools, func(a tool.InvokableTool) tool.BaseTool {
 | |
| 		return a
 | |
| 	})...)
 | |
| 	agentTools = append(agentTools, slices.Transform(wfTools, func(a workflow.ToolFromWorkflow) tool.BaseTool { return a.(tool.BaseTool) })...)
 | |
| 	agentTools = append(agentTools, slices.Transform(dbTools, func(a tool.InvokableTool) tool.BaseTool {
 | |
| 		return a
 | |
| 	})...)
 | |
| 
 | |
| 	agentTools = append(agentTools, slices.Transform(avTools, func(a tool.InvokableTool) tool.BaseTool {
 | |
| 		return a
 | |
| 	})...)
 | |
| 
 | |
| 	var isReActAgent bool
 | |
| 	if len(agentTools) > 0 {
 | |
| 		isReActAgent = true
 | |
| 		requireCheckpoint = true
 | |
| 		if modelInfo.Meta.Capability != nil && !modelInfo.Meta.Capability.FunctionCall {
 | |
| 			return nil, fmt.Errorf("model %v does not support function call", modelInfo.Meta.Name)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var agentGraph compose.AnyGraph
 | |
| 	var agentNodeOpts []compose.GraphAddNodeOpt
 | |
| 	var agentNodeName string
 | |
| 	if isReActAgent {
 | |
| 		agent, err := react.NewAgent(ctx, &react.AgentConfig{
 | |
| 			ToolCallingModel: chatModel,
 | |
| 			ToolsConfig: compose.ToolsNodeConfig{
 | |
| 				Tools: agentTools,
 | |
| 			},
 | |
| 			ToolReturnDirectly: toolsReturnDirectly,
 | |
| 			ModelNodeName:      keyOfReActAgentChatModel,
 | |
| 			ToolsNodeName:      keyOfReActAgentToolsNode,
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		agentGraph, agentNodeOpts = agent.ExportGraph()
 | |
| 
 | |
| 		agentNodeName = keyOfReActAgent
 | |
| 	} else {
 | |
| 		agentNodeName = keyOfLLM
 | |
| 	}
 | |
| 
 | |
| 	suggestGraph, nsg := newSuggestGraph(ctx, conf, chatModel)
 | |
| 
 | |
| 	g := compose.NewGraph[*AgentRequest, *schema.Message](
 | |
| 		compose.WithGenLocalState(func(ctx context.Context) (state *AgentState) {
 | |
| 			return &AgentState{}
 | |
| 		}))
 | |
| 
 | |
| 	_ = g.AddLambdaNode(keyOfPersonRender,
 | |
| 		compose.InvokableLambda[*AgentRequest, string](personaVars.RenderPersona),
 | |
| 		compose.WithStatePreHandler(func(ctx context.Context, ar *AgentRequest, state *AgentState) (*AgentRequest, error) {
 | |
| 			state.UserInput = ar.Input
 | |
| 			return ar, nil
 | |
| 		}),
 | |
| 		compose.WithOutputKey(placeholderOfPersona))
 | |
| 
 | |
| 	_ = g.AddLambdaNode(keyOfPromptVariables,
 | |
| 		compose.InvokableLambda[*AgentRequest, map[string]any](promptVars.AssemblePromptVariables))
 | |
| 
 | |
| 	_ = g.AddLambdaNode(keyOfKnowledgeRetriever,
 | |
| 		compose.InvokableLambda[*AgentRequest, []*schema.Document](kr.Retrieve),
 | |
| 		compose.WithNodeName(keyOfKnowledgeRetriever))
 | |
| 
 | |
| 	_ = g.AddLambdaNode(keyOfToolsPreRetriever,
 | |
| 		compose.InvokableLambda[*AgentRequest, []*schema.Message](tr.toolPreRetrieve),
 | |
| 		compose.WithOutputKey(keyOfToolsPreRetriever),
 | |
| 		compose.WithNodeName(keyOfToolsPreRetriever),
 | |
| 	)
 | |
| 	_ = g.AddLambdaNode(keyOfKnowledgeRetrieverPack,
 | |
| 		compose.InvokableLambda[[]*schema.Document, string](kr.PackRetrieveResultInfo),
 | |
| 		compose.WithOutputKey(placeholderOfKnowledge),
 | |
| 	)
 | |
| 	_ = g.AddChatTemplateNode(keyOfPromptTemplate, chatPrompt)
 | |
| 
 | |
| 	agentNodeOpts = append(agentNodeOpts, compose.WithNodeName(agentNodeName))
 | |
| 
 | |
| 	if isReActAgent {
 | |
| 		_ = g.AddGraphNode(agentNodeName, agentGraph, agentNodeOpts...)
 | |
| 	} else {
 | |
| 		_ = g.AddChatModelNode(agentNodeName, chatModel, agentNodeOpts...)
 | |
| 	}
 | |
| 
 | |
| 	if nsg {
 | |
| 		_ = g.AddLambdaNode(keyOfSuggestPreInputParse, compose.ToList[*schema.Message](),
 | |
| 			compose.WithStatePostHandler(func(ctx context.Context, out []*schema.Message, state *AgentState) ([]*schema.Message, error) {
 | |
| 				out = append(out, state.UserInput)
 | |
| 				return out, nil
 | |
| 			}),
 | |
| 		)
 | |
| 		_ = g.AddGraphNode(keyOfSuggestGraph, suggestGraph)
 | |
| 	}
 | |
| 
 | |
| 	_ = g.AddEdge(compose.START, keyOfPersonRender)
 | |
| 	_ = g.AddEdge(compose.START, keyOfPromptVariables)
 | |
| 	_ = g.AddEdge(compose.START, keyOfKnowledgeRetriever)
 | |
| 	_ = g.AddEdge(compose.START, keyOfToolsPreRetriever)
 | |
| 
 | |
| 	_ = g.AddEdge(keyOfPersonRender, keyOfPromptTemplate)
 | |
| 	_ = g.AddEdge(keyOfPromptVariables, keyOfPromptTemplate)
 | |
| 	_ = g.AddEdge(keyOfKnowledgeRetriever, keyOfKnowledgeRetrieverPack)
 | |
| 	_ = g.AddEdge(keyOfKnowledgeRetrieverPack, keyOfPromptTemplate)
 | |
| 	_ = g.AddEdge(keyOfToolsPreRetriever, keyOfPromptTemplate)
 | |
| 
 | |
| 	_ = g.AddEdge(keyOfPromptTemplate, agentNodeName)
 | |
| 
 | |
| 	if nsg {
 | |
| 		_ = g.AddEdge(agentNodeName, keyOfSuggestPreInputParse)
 | |
| 		_ = g.AddEdge(keyOfSuggestPreInputParse, keyOfSuggestGraph)
 | |
| 		_ = g.AddEdge(keyOfSuggestGraph, compose.END)
 | |
| 	} else {
 | |
| 		_ = g.AddEdge(agentNodeName, compose.END)
 | |
| 	}
 | |
| 
 | |
| 	var opts []compose.GraphCompileOption
 | |
| 	if requireCheckpoint {
 | |
| 		opts = append(opts, compose.WithCheckPointStore(conf.CPStore))
 | |
| 	}
 | |
| 	opts = append(opts, compose.WithNodeTriggerMode(compose.AllPredecessor))
 | |
| 	runner, err := g.Compile(ctx, opts...)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &AgentRunner{
 | |
| 		runner:            runner,
 | |
| 		requireCheckpoint: requireCheckpoint,
 | |
| 		modelInfo:         modelInfo,
 | |
| 		containWfTool:     containWfTool,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
// jinja2PlaceholderRe matches a {{...}} placeholder and captures everything
// between the braces, including any surrounding whitespace. It is compiled
// once at package initialisation rather than on every call.
var jinja2PlaceholderRe = regexp.MustCompile(`{{([^}]*)}}`)

// extractJinja2Placeholder returns the contents of every {{...}} placeholder
// in persona, in order of appearance. Placeholders that are empty or contain
// only whitespace are skipped; duplicates are kept. The result is a non-nil,
// possibly empty slice.
//
// NOTE(review): emptiness is tested on the trimmed text, but the untrimmed
// capture is returned, so "{{ name }}" yields " name ". Confirm whether
// personaRender relies on the raw form before normalising it.
func extractJinja2Placeholder(persona string) (variableNames []string) {
	matches := jinja2PlaceholderRe.FindAllStringSubmatch(persona, -1)
	variableNames = make([]string, 0, len(matches))
	for _, match := range matches {
		if strings.TrimSpace(match[1]) == "" {
			continue
		}
		variableNames = append(variableNames, match[1])
	}
	return variableNames
}
 |