coze-studio/backend/domain/workflow/internal/nodes/llm/llm.go

1232 lines
36 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package llm
import (
"context"
"errors"
"fmt"
"io"
"strconv"
"strings"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent/react"
"github.com/cloudwego/eino/schema"
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps"
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// Format enumerates the output formats the LLM node can be asked to produce.
type Format int

const (
	// FormatText leaves the model output untouched.
	FormatText Format = iota
	// FormatMarkdown appends Markdown-formatting instructions to the user prompt.
	FormatMarkdown
	// FormatJSON forces structured output matching the node's output schema.
	FormatJSON
)

const (
	// jsonPromptFormat is appended to the user prompt when FormatJSON is in
	// effect; %s is replaced with the JSON schema derived from the node's
	// declared output types.
	jsonPromptFormat = `
Strictly reply in valid JSON format.
- Ensure the output strictly conforms to the JSON schema below
- Do not include explanations, comments, or any text outside the JSON.
Here is the output JSON schema:
'''
%s
'''
`
	// markdownPrompt is appended to the user prompt when FormatMarkdown is in effect.
	markdownPrompt = `
Strictly reply in valid Markdown format.
- For headings, use number signs (#).
- For list items, start with dashes (-).
- To emphasize text, wrap it with asterisks (*).
- For code or commands, surround them with backticks (` + "`" + `).
- For quoted text, use greater than signs (>).
- For links, wrap the text in square brackets [], followed by the URL in parentheses ().
- For images, use square brackets [] for the alt text, followed by the image URL in parentheses ().
`
)

const (
	// ReasoningOutputKey is the reserved output field name carrying the
	// model's reasoning ("chain of thought") content.
	ReasoningOutputKey = "reasoning_content"
)
// knowledgeUserPromptTemplate wraps recalled knowledge slices before they are
// injected ahead of the user prompt. The Chinese-language instructions tell
// the model how to surface <img src=""> tags found in recalled content as
// Markdown images; %s is replaced with the concatenated recall slices.
// (Runtime prompt text — must not be translated or reworded.)
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"![图片名称](图片地址)" 。
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
例如:
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:![一只小猫](https://example.com/image.jpg)。
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为:![一只小猫](https://example.com/image1.jpg) 和 ![一只小狗](https://example.com/image2.jpg) 和 ![一只小牛](https://example.com/image3.jpg)
you can refer to the following content and do relevant searches to improve:
---
%s
question is:
`

// knowledgeIntentPrompt is the system prompt for the knowledge-base intent
// recognition step: given the list of candidate knowledge bases (first %s,
// serialized as JSON) and the node's configured user prompt (second %s), the
// intent model must answer with comma-separated knowledge_ids or "0".
// (Runtime prompt text — must not be translated or reworded.)
const knowledgeIntentPrompt = `
# 角色:
你是一个知识库意图识别AI Agent。
## 目标:
- 按照「系统提示词」、用户需求、最新的聊天记录选择应该使用的知识库。
## 工作流程:
1. 分析「系统提示词」以确定用户的具体需求。
2. 如果「系统提示词」明确指明了要使用的知识库则直接返回这些知识库只输出它们的knowledge_id不需要再判断用户的输入
3. 检查每个知识库的knowledge_name和knowledge_description以了解它们各自的功能。
4. 根据用户需求,选择最符合的知识库。
5. 如果找到一个或多个合适的知识库输出它们的knowledge_id。如果没有合适的知识库输出0。
## 约束:
- 严格按照「系统提示词」和用户的需求选择知识库。「系统提示词」的优先级大于用户的需求
- 如果有多个合适的知识库将它们的knowledge_id用英文逗号连接后输出。
- 输出必须仅为knowledge_id或0不得包括任何其他内容或解释不要在id后面输出知识库名称。
## 输出示例
123,456
## 输出格式:
输出应该是一个纯数字或者由英文逗号连接的数字序列,具体取决于选择的知识库数量。不应包含任何其他文本或格式。
## 知识库列表如下
%s
## 「系统提示词」如下
%s
`

// Node keys and template variable names used inside the inner eino graph.
const (
	knowledgeTemplateKey  = "knowledge_template"   // intent-recognition prompt node
	knowledgeChatModelKey = "knowledge_chat_model" // intent-recognition model node
	knowledgeLambdaKey    = "knowledge_lambda"     // knowledge retrieval lambda node
	// knowledgeUserPromptTemplateKey is the template variable through which the
	// recalled content is injected into the user prompt.
	knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
	templateNodeKey                = "template"       // chat prompt template node
	llmNodeKey                     = "llm"            // chat model / react agent node
	outputConvertNodeKey           = "output_convert" // message -> output map conversion node
)
// NoReCallReplyMode controls what happens when knowledge recall returns no slices.
type NoReCallReplyMode int64

const (
	// NoReCallReplyModeOfDefault injects nothing when recall is empty.
	NoReCallReplyModeOfDefault NoReCallReplyMode = 0
	// NoReCallReplyModeOfCustomize injects a user-configured fallback prompt
	// as a pseudo recall slice when recall is empty.
	NoReCallReplyModeOfCustomize NoReCallReplyMode = 1
)

// RetrievalStrategy bundles the low-level knowledge retrieval settings with
// the node-level behavior for the "nothing recalled" case.
type RetrievalStrategy struct {
	// RetrievalStrategy holds the raw retrieval knobs (topK, minScore, search type, ...).
	RetrievalStrategy *knowledge.RetrievalStrategy
	// NoReCallReplyMode selects the behavior when no slices are recalled.
	NoReCallReplyMode NoReCallReplyMode
	// NoReCallReplyCustomizePrompt is the fallback text used in customize mode.
	NoReCallReplyCustomizePrompt string
}

// KnowledgeRecallConfig is the resolved knowledge-recall setup for one LLM node.
type KnowledgeRecallConfig struct {
	// ChatModel is the builtin model used for knowledge-base intent recognition.
	ChatModel         model.BaseChatModel
	RetrievalStrategy *RetrievalStrategy
	// SelectedKnowledgeDetails are the candidate knowledge bases configured on the node.
	SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
}

// Config is the LLM node configuration, populated by Adapt and consumed by Build.
type Config struct {
	SystemPrompt string
	UserPrompt   string
	// OutputFormat is the configured format; Build may downgrade JSON to text
	// when the output shape is a single string.
	OutputFormat Format
	LLMParams    *crossmodel.LLMParams
	// FCParam carries the function-call (tool) configuration: workflows,
	// plugins and knowledge bases.
	FCParam *vo.FCParam
	// BackupLLMParams configures the fallback model used when retries are enabled.
	BackupLLMParams *crossmodel.LLMParams
}
// Adapt converts the canvas-level node definition into a NodeSchema for the
// LLM node: it decodes the llmParam list, derives the output format, wires
// input/output types, and captures function-call and backup-model settings.
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
	ns := &schema2.NodeSchema{
		Key:     vo.NodeKey(n.ID),
		Type:    entity.NodeTypeLLM,
		Name:    n.Data.Meta.Title,
		Configs: c,
	}
	param := n.Data.Inputs.LLMParam
	if param == nil {
		return nil, fmt.Errorf("llm node's llmParam is nil")
	}
	// Round-trip through JSON to coerce the loosely-typed canvas payload into
	// vo.LLMParam. Marshal of an already-decoded value cannot fail here, so
	// the error is deliberately ignored.
	bs, _ := sonic.Marshal(param)
	llmParam := make(vo.LLMParam, 0)
	if err := sonic.Unmarshal(bs, &llmParam); err != nil {
		return nil, err
	}
	convertedLLMParam, err := llmParamsToLLMParam(llmParam)
	if err != nil {
		return nil, err
	}
	c.LLMParams = convertedLLMParam
	c.SystemPrompt = convertedLLMParam.SystemPrompt
	c.UserPrompt = convertedLLMParam.Prompt
	var resFormat Format
	switch convertedLLMParam.ResponseFormat {
	case crossmodel.ResponseFormatText:
		resFormat = FormatText
	case crossmodel.ResponseFormatMarkdown:
		resFormat = FormatMarkdown
	case crossmodel.ResponseFormatJSON:
		resFormat = FormatJSON
	default:
		return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
	}
	c.OutputFormat = resFormat
	if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
		return nil, err
	}
	// JSON degrades to plain text when the outputs are a single string field,
	// optionally accompanied by the reserved reasoning_content string field —
	// there is nothing structured to parse. NOTE: only the local resFormat is
	// downgraded (it drives StreamConfigs below); c.OutputFormat keeps the
	// configured value and Build re-derives the effective format with the
	// same check.
	if resFormat == FormatJSON {
		if len(ns.OutputTypes) == 1 {
			for _, v := range ns.OutputTypes {
				if v.Type == vo.DataTypeString {
					resFormat = FormatText
					break
				}
			}
		} else if len(ns.OutputTypes) == 2 {
			if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
				for k, v := range ns.OutputTypes {
					if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
						resFormat = FormatText
						break
					}
				}
			}
		}
	}
	// Only non-JSON output can be streamed to downstream nodes.
	if resFormat == FormatJSON {
		ns.StreamConfigs = &schema2.StreamConfig{
			CanGeneratesStream: false,
		}
	} else {
		ns.StreamConfigs = &schema2.StreamConfig{
			CanGeneratesStream: true,
		}
	}
	if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
		c.FCParam = n.Data.Inputs.FCParam
	}
	// A backup model may be configured as part of the on-error settings.
	if se := n.Data.Inputs.SettingOnError; se != nil {
		if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
			var backupLLMParam vo.SimpleLLMParam
			if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
				return nil, err
			}
			backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
			if err != nil {
				return nil, err
			}
			c.BackupLLMParams = backupModel
		}
	}
	return ns, nil
}
// llmParamsToLLMParam converts the canvas-level list of name/value LLM
// parameters into the cross-domain *crossmodel.LLMParams.
//
// Parameter contents arrive as `any`. The previous implementation used
// unchecked type assertions, so a malformed canvas payload would panic;
// assertions are now validated and surfaced as errors instead.
//
// Note: "modleName" (sic) is the key actually emitted in the canvas payload,
// so the misspelling must be preserved.
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
	// asString validates that a parameter's content is a string.
	asString := func(name string, v any) (string, error) {
		s, ok := v.(string)
		if !ok {
			return "", fmt.Errorf("LLMParam %q expects string content, got %T", name, v)
		}
		return s, nil
	}
	p := &crossmodel.LLMParams{}
	for _, param := range params {
		switch param.Name {
		case "temperature":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			floatVal, err := strconv.ParseFloat(strVal, 64)
			if err != nil {
				return nil, err
			}
			p.Temperature = &floatVal
		case "maxTokens":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			intVal, err := strconv.Atoi(strVal)
			if err != nil {
				return nil, err
			}
			p.MaxTokens = intVal
		case "responseFormat":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			int64Val, err := strconv.ParseInt(strVal, 10, 64)
			if err != nil {
				return nil, err
			}
			p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
		case "modleName": // (sic) key as produced by the frontend canvas
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.ModelName = strVal
		case "modelType":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			int64Val, err := strconv.ParseInt(strVal, 10, 64)
			if err != nil {
				return nil, err
			}
			p.ModelType = int64Val
		case "prompt":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.Prompt = strVal
		case "enableChatHistory":
			boolVal, ok := param.Input.Value.Content.(bool)
			if !ok {
				return nil, fmt.Errorf("LLMParam %q expects bool content, got %T", param.Name, param.Input.Value.Content)
			}
			p.EnableChatHistory = boolVal
		case "systemPrompt":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			p.SystemPrompt = strVal
		case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
			// recognized but intentionally ignored by the workflow LLM node
		case "topP":
			strVal, err := asString(param.Name, param.Input.Value.Content)
			if err != nil {
				return nil, err
			}
			floatVal, err := strconv.ParseFloat(strVal, 64)
			if err != nil {
				return nil, err
			}
			p.TopP = &floatVal
		default:
			return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
		}
	}
	return p, nil
}
// simpleLLMParamsToLLMParams maps the backup-model configuration (a flat
// SimpleLLMParam) onto the cross-domain LLMParams representation.
// Temperature and TopP are pointers on the target, so addresses of the
// (copied) parameter fields are taken.
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
	converted := &crossmodel.LLMParams{
		ModelName:      params.ModelName,
		ModelType:      params.ModelType,
		Temperature:    &params.Temperature,
		MaxTokens:      params.MaxTokens,
		TopP:           &params.TopP,
		ResponseFormat: params.ResponseFormat,
		SystemPrompt:   params.SystemPrompt,
	}
	return converted, nil
}
// getReasoningContent returns the reasoning ("chain of thought") text the
// model attached to the message, if any.
func getReasoningContent(message *schema.Message) string {
	return message.ReasoningContent
}
// Build compiles the LLM node configuration into a runnable *LLM. It
// assembles an inner eino graph:
//
//	START -> [knowledge intent template -> intent model -> recall lambda ->]
//	         chat template -> chat model (or react agent when tools exist) ->
//	         output convert -> END
//
// Workflow/plugin tools turn the model node into a react agent sub-graph;
// knowledge recall prepends an intent-recognition branch that injects
// recalled slices into the user prompt.
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
	var (
		err                   error
		chatModel, fallbackM  model.BaseChatModel
		info, fallbackI       *modelmgr.Model
		modelWithInfo         ModelWithInfo
		tools                 []tool.BaseTool
		toolsReturnDirectly   map[string]bool
		knowledgeRecallConfig *KnowledgeRecallConfig
	)
	// Resolve the primary chat model from the configured LLM params.
	chatModel, info, err = crossmodelmgr.DefaultSVC().GetModel(ctx, c.LLMParams)
	if err != nil {
		return nil, err
	}
	// The backup model is only resolved when retry-on-error is configured.
	exceptionConf := ns.ExceptionConfigs
	if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
		backupModelParams := c.BackupLLMParams
		if backupModelParams != nil {
			fallbackM, fallbackI, err = crossmodelmgr.DefaultSVC().GetModel(ctx, backupModelParams)
			if err != nil {
				return nil, err
			}
		}
	}
	if fallbackM == nil {
		modelWithInfo = NewModel(chatModel, info)
	} else {
		modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
	}
	fcParams := c.FCParam
	if fcParams != nil {
		// Workflows exposed as tools.
		if fcParams.WorkflowFCParam != nil {
			for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
				wfIDStr := wf.WorkflowID
				wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
				}
				workflowToolConfig := vo.WorkflowToolConfig{}
				if wf.FCSetting != nil {
					workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
					workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
				}
				// Pin the tool to a specific version when one is given,
				// otherwise use the draft.
				locator := vo.FromDraft
				if wf.WorkflowVersion != "" {
					locator = vo.FromSpecificVersion
				}
				wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
					ID:      wfID,
					QType:   locator,
					Version: wf.WorkflowVersion,
				}, workflowToolConfig)
				if err != nil {
					return nil, err
				}
				tools = append(tools, wfTool)
				// Tools terminating with "use answer content" short-circuit
				// the react agent: their output is returned directly.
				if wfTool.TerminatePlan() == vo.UseAnswerContent {
					if toolsReturnDirectly == nil {
						toolsReturnDirectly = make(map[string]bool)
					}
					toolInfo, err := wfTool.Info(ctx)
					if err != nil {
						return nil, err
					}
					toolsReturnDirectly[toolInfo.Name] = true
				}
			}
		}
		// Plugin tools, grouped per plugin ID before resolution.
		if fcParams.PluginFCParam != nil {
			pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
			for _, p := range fcParams.PluginFCParam.PluginList {
				pid, err := strconv.ParseInt(p.PluginID, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
				}
				toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
				if err != nil {
					return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
				}
				var (
					requestParameters  []*workflow3.APIParameter
					responseParameters []*workflow3.APIParameter
				)
				if p.FCSetting != nil {
					requestParameters = p.FCSetting.RequestParameters
					responseParameters = p.FCSetting.ResponseParameters
				}
				if req, ok := pluginToolsInvokableReq[pid]; ok {
					// Plugin already seen: add this tool to its request.
					req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
						ToolID:                      toolID,
						RequestAPIParametersConfig:  requestParameters,
						ResponseAPIParametersConfig: responseParameters,
					}
				} else {
					pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
						PluginEntity: plugin.Entity{
							PluginID:      pid,
							PluginVersion: ptr.Of(p.PluginVersion),
						},
						ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
							toolID: {
								ToolID:                      toolID,
								RequestAPIParametersConfig:  requestParameters,
								ResponseAPIParametersConfig: responseParameters,
							},
						},
						IsDraft: p.IsDraft,
					}
					pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
				}
			}
			inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
			for _, req := range pluginToolsInvokableReq {
				toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
				if err != nil {
					return nil, err
				}
				for _, t := range toolMap {
					inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
				}
			}
			if len(inInvokableTools) > 0 {
				tools = append(tools, inInvokableTools...)
			}
		}
		// Knowledge recall configuration.
		if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
			kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
			if kwChatModel == nil {
				return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
			}
			setting := fcParams.KnowledgeFCParam.GlobalSetting
			knowledgeRecallConfig = &KnowledgeRecallConfig{
				ChatModel: kwChatModel,
			}
			searchType, err := toRetrievalSearchType(setting.SearchMode)
			if err != nil {
				return nil, err
			}
			knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
				RetrievalStrategy: &knowledge.RetrievalStrategy{
					TopK:               ptr.Of(setting.TopK),
					MinScore:           ptr.Of(setting.MinScore),
					SearchType:         searchType,
					EnableNL2SQL:       setting.UseNL2SQL,
					EnableQueryRewrite: setting.UseRewrite,
					EnableRerank:       setting.UseRerank,
				},
				NoReCallReplyMode:            NoReCallReplyMode(setting.NoRecallReplyMode),
				NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
			}
			knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
			for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
				kid, err := strconv.ParseInt(kw.ID, 10, 64)
				if err != nil {
					return nil, err
				}
				knowledgeIDs = append(knowledgeIDs, kid)
			}
			detailResp, err := crossknowledge.DefaultSVC().ListKnowledgeDetail(ctx,
				&knowledge.ListKnowledgeDetailRequest{
					KnowledgeIDs: knowledgeIDs,
				})
			if err != nil {
				return nil, err
			}
			knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
		}
	}
	// Inner graph with a local state used to replay node inputs into the
	// chat template after the knowledge branch runs.
	g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
		return llmState{}
	}))
	var hasReasoning bool
	format := c.OutputFormat
	// JSON degrades to plain text when the outputs are a single string field,
	// optionally plus reasoning_content — nothing structured to parse.
	// (Mirrors the identical check in Adapt.)
	if format == FormatJSON {
		if len(ns.OutputTypes) == 1 {
			for _, v := range ns.OutputTypes {
				if v.Type == vo.DataTypeString {
					format = FormatText
					break
				}
			}
		} else if len(ns.OutputTypes) == 2 {
			if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
				for k, v := range ns.OutputTypes {
					if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
						format = FormatText
						break
					}
				}
			}
		}
	}
	userPrompt := c.UserPrompt
	// Append format-specific instructions to the user prompt.
	switch format {
	case FormatJSON:
		jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
		if err != nil {
			return nil, err
		}
		jsonPrompt := fmt.Sprintf(jsonPromptFormat, jsonSchema)
		userPrompt = userPrompt + jsonPrompt
	case FormatMarkdown:
		userPrompt = userPrompt + markdownPrompt
	case FormatText:
		// no extra instruction for plain text
	}
	if knowledgeRecallConfig != nil {
		// Knowledge branch: START -> intent template -> intent model ->
		// recall lambda -> chat template.
		err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
		if err != nil {
			return nil, err
		}
		// Prepend the recalled-content placeholder to the user prompt and
		// declare it as an extra string input of the template.
		userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
		inputs := maps.Clone(ns.InputTypes)
		inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
			Type: vo.DataTypeString,
		}
		sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
		up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
		template := newPrompts(sp, up, modelWithInfo)
		_ = g.AddChatTemplateNode(templateNodeKey, template,
			compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
				// Merge the original node inputs (buffered into state by the
				// knowledge branch) back into the template input.
				for k, v := range state {
					in[k] = v
				}
				return in, nil
			}))
		_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
	} else {
		sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
		up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
		template := newPrompts(sp, up, modelWithInfo)
		_ = g.AddChatTemplateNode(templateNodeKey, template)
		_ = g.AddEdge(compose.START, templateNodeKey)
	}
	if len(tools) > 0 {
		// With tools attached, the model node becomes a react agent sub-graph.
		m, ok := modelWithInfo.(model.ToolCallingChatModel)
		if !ok {
			return nil, errors.New("requires a ToolCallingChatModel to use with tools")
		}
		reactConfig := react.AgentConfig{
			ToolCallingModel: m,
			ToolsConfig:      compose.ToolsNodeConfig{Tools: tools},
			ModelNodeName:    agentModelName,
		}
		if len(toolsReturnDirectly) > 0 {
			reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
			for k := range toolsReturnDirectly {
				reactConfig.ToolReturnDirectly[k] = struct{}{}
			}
		}
		reactAgent, err := react.NewAgent(ctx, &reactConfig)
		if err != nil {
			return nil, err
		}
		agentNode, opts := reactAgent.ExportGraph()
		opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
		_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
	} else {
		_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
	}
	_ = g.AddEdge(templateNodeKey, llmNodeKey)
	if format == FormatJSON {
		// Structured output: parse the final message as JSON (invoke only;
		// JSON output is never streamed).
		iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
			return jsonParse(ctx, msg.Content, ns.OutputTypes)
		}
		convertNode := compose.InvokableLambda(iConvert)
		_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
	} else {
		var outputKey string
		// The format-downgrade logic above guarantees these shapes.
		if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
			panic("impossible")
		}
		for k, v := range ns.OutputTypes {
			if v.Type != vo.DataTypeString {
				panic("impossible")
			}
			if k == ReasoningOutputKey {
				hasReasoning = true
			} else {
				outputKey = k
			}
		}
		// Invoke path: the single message maps straight onto the output fields.
		iConvert := func(_ context.Context, msg *schema.Message, _ ...struct{}) (map[string]any, error) {
			out := map[string]any{outputKey: msg.Content}
			if hasReasoning {
				out[ReasoningOutputKey] = getReasoningContent(msg)
			}
			return out, nil
		}
		// Stream path: forward reasoning and content chunks, emitting
		// KeyIsFinished markers as each section of the stream completes.
		tConvert := func(_ context.Context, s *schema.StreamReader[*schema.Message], _ ...struct{}) (*schema.StreamReader[map[string]any], error) {
			sr, sw := schema.Pipe[map[string]any](0)
			safego.Go(ctx, func() {
				reasoningDone := false
				for {
					msg, err := s.Recv()
					if err != nil {
						if err == io.EOF {
							// End of model stream: mark the content field finished.
							sw.Send(map[string]any{
								outputKey: nodes.KeyIsFinished,
							}, nil)
							sw.Close()
							return
						}
						sw.Send(nil, err)
						sw.Close()
						return
					}
					if hasReasoning {
						reasoning := getReasoningContent(msg)
						if len(reasoning) > 0 {
							sw.Send(map[string]any{ReasoningOutputKey: reasoning}, nil)
						}
					}
					if len(msg.Content) > 0 {
						// First content chunk marks the end of reasoning.
						if !reasoningDone && hasReasoning {
							reasoningDone = true
							sw.Send(map[string]any{
								ReasoningOutputKey: nodes.KeyIsFinished,
							}, nil)
						}
						sw.Send(map[string]any{outputKey: msg.Content}, nil)
					}
				}
			})
			return sr, nil
		}
		convertNode, err := compose.AnyLambda(iConvert, nil, nil, tConvert)
		if err != nil {
			return nil, err
		}
		_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
	}
	_ = g.AddEdge(llmNodeKey, outputConvertNodeKey)
	_ = g.AddEdge(outputConvertNodeKey, compose.END)
	// Tool execution may interrupt and resume, which requires a checkpoint store.
	requireCheckpoint := false
	if len(tools) > 0 {
		requireCheckpoint = true
	}
	var compileOpts []compose.GraphCompileOption
	if requireCheckpoint {
		compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
	}
	compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
	r, err := g.Compile(ctx, compileOpts...)
	if err != nil {
		return nil, err
	}
	llm := &LLM{
		r:                 r,
		outputFormat:      format,
		requireCheckpoint: requireCheckpoint,
		fullSources:       ns.FullSources,
	}
	return llm, nil
}
// RequireCheckpoint reports whether this LLM node needs a checkpoint store.
// That is the case exactly when it can call tools (workflow tools or plugin
// tools), since tool execution may interrupt and later resume.
func (c *Config) RequireCheckpoint() bool {
	fc := c.FCParam
	return fc != nil && (fc.WorkflowFCParam != nil || fc.PluginFCParam != nil)
}
// FieldStreamType reports whether the addressed output field of this node can
// be produced as a stream. Streaming applies only when the workflow requires
// streaming, the path names a single top-level field, and the node's outputs
// form a streamable shape: one string field, optionally accompanied by the
// reserved reasoning_content string field.
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
	sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
	if !sc.RequireStreaming() || len(path) != 1 {
		return schema2.FieldNotStream, nil
	}
	outputs := ns.OutputTypes
	if len(outputs) != 1 && len(outputs) != 2 {
		return schema2.FieldNotStream, nil
	}
	var contentKey string
	for key, output := range outputs {
		// Every output must be a string for streaming to apply.
		if output.Type != vo.DataTypeString {
			return schema2.FieldNotStream, nil
		}
		if key == ReasoningOutputKey {
			continue
		}
		if len(contentKey) > 0 {
			// More than one non-reasoning output: not a streamable shape.
			return schema2.FieldNotStream, nil
		}
		contentKey = key
	}
	if target := path[0]; target == ReasoningOutputKey || target == contentKey {
		return schema2.FieldIsStream, nil
	}
	return schema2.FieldNotStream, nil
}
// toRetrievalSearchType maps the canvas-level search-mode code onto the
// knowledge domain's SearchType. Recognized codes: 0 (semantic), 1 (hybrid)
// and 20 (full-text); anything else is rejected.
func toRetrievalSearchType(mode int64) (knowledge.SearchType, error) {
	switch mode {
	case 0:
		return knowledge.SearchTypeSemantic, nil
	case 1:
		return knowledge.SearchTypeHybrid, nil
	case 20:
		return knowledge.SearchTypeFullText, nil
	}
	return 0, fmt.Errorf("invalid retrieval search type %v", mode)
}
// LLM is the runnable workflow node produced by Config.Build.
type LLM struct {
	// r is the compiled inner graph: prompt -> model (or react agent) -> output conversion.
	r compose.Runnable[map[string]any, map[string]any]
	// outputFormat is the effective format (JSON may have been downgraded to text at build time).
	outputFormat Format
	// requireCheckpoint is true when tools are attached and runs need a checkpoint ID.
	requireCheckpoint bool
	// fullSources holds stream-source metadata used to resolve incoming streams before each run.
	fullSources map[string]*schema2.SourceInfo
}

const (
	// rawOutputKey and warningKey are ctxcache key patterns (formatted with
	// the node key) used to hand the raw model output and a parse warning
	// from jsonParse over to ToCallbackOutput.
	rawOutputKey = "llm_raw_output_%s"
	warningKey   = "llm_warning_%s"
)
// jsonParse extracts the JSON object embedded in the model's raw text output
// and converts it to the node's declared output types. When parsing fails
// inside a workflow execution, the failure is downgraded to a warning: the
// raw text and a wrapped warning are cached for ToCallbackOutput and an empty
// map is returned instead of an error.
func jsonParse(ctx context.Context, data string, outTypes map[string]*vo.TypeInfo) (map[string]any, error) {
	extracted := nodes.ExtractJSONString(data)
	var decoded map[string]any
	if unmarshalErr := sonic.UnmarshalString(extracted, &decoded); unmarshalErr != nil {
		exeCtx := execute.GetExeCtx(ctx)
		if exeCtx == nil {
			// No execution context to attach a warning to: surface the error.
			return nil, unmarshalErr
		}
		logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", unmarshalErr, extracted)
		ctxcache.Store(ctx, fmt.Sprintf(rawOutputKey, exeCtx.NodeCtx.NodeKey), extracted)
		ctxcache.Store(ctx, fmt.Sprintf(warningKey, exeCtx.NodeCtx.NodeKey),
			vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, unmarshalErr))
		return map[string]any{}, nil
	}
	converted, warnings, convErr := nodes.ConvertInputs(ctx, decoded, outTypes)
	if convErr != nil {
		return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, convErr)
	}
	if warnings != nil {
		logs.CtxWarnf(ctx, "convert inputs warnings: %v", *warnings)
	}
	return converted, nil
}
// llmOptions carries LLM-node-specific invocation options.
type llmOptions struct {
	// toolWorkflowSW, when set, receives messages emitted by workflows executed as tools.
	toolWorkflowSW *schema.StreamWriter[*entity.Message]
}

// WithToolWorkflowMessageWriter routes messages produced by tool workflows
// into the provided stream writer during a run.
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
	return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
		o.toolWorkflowSW = sw
	})
}

// llmState is the inner graph's local state: it buffers the node input so the
// knowledge-recall branch can re-inject it into the chat template.
type llmState = map[string]any

// agentModelName identifies the react agent's chat model node in callbacks.
const agentModelName = "agent_model"
// prepare assembles the compose options for one run of the inner graph:
// checkpoint ID, nested-node options, resume handling for previously
// interrupted tools, tool-workflow message piping and stream-source
// resolution. It returns the options together with the resuming interrupt
// event (nil when this run is not a resume).
//
// BUGFIX: the OnStart callback previously logged the enclosing `err` variable
// (nil at that point) instead of the local `e` returned by
// PopFirstInterruptEvent, hiding the actual failure; it now logs `e`.
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
	c := execute.GetExeCtx(ctx)
	if c != nil {
		resumingEvent = c.NodeCtx.ResumingEvent
	}
	var previousToolES map[string]*entity.ToolInterruptEvent
	if c != nil && c.RootCtx.ResumeEvent != nil {
		// check if we are not resuming, but previously interrupted. Interrupt immediately.
		if resumingEvent == nil {
			err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
				var e error
				previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
				if e != nil {
					return e
				}
				return nil
			})
			if err != nil {
				return nil, nil, err
			}
			if len(previousToolES) > 0 {
				return nil, nil, compose.InterruptAndRerun
			}
		}
	}
	if l.requireCheckpoint && c != nil {
		// checkpoint ID is stable across interrupt/resume for this node in this execution
		checkpointID := fmt.Sprintf("%d_%s", c.RootCtx.RootExecuteID, c.NodeCtx.NodeKey)
		composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
	}
	options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
	composeOpts = append(composeOpts, options.GetOptsForNested()...)
	if resumingEvent != nil {
		// Resuming: fetch all stored tool interrupt events, consume the one
		// being resumed and feed its resume data to the interrupted tool.
		var (
			resumeData string
			e          error
			allIEs     = make(map[string]*entity.ToolInterruptEvent)
		)
		err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
			allIEs, e = state.GetToolInterruptEvents(c.NodeKey)
			if e != nil {
				return e
			}
			allIEs = maps.Clone(allIEs)
			resumeData, e = state.ResumeToolInterruptEvent(c.NodeKey, resumingEvent.ToolInterruptEvent.ToolCallID)
			return e
		})
		if err != nil {
			return nil, nil, err
		}
		composeOpts = append(composeOpts, compose.WithToolsNodeOption(
			compose.WithToolOption(
				execute.WithResume(&entity.ResumeRequest{
					ExecuteID:  resumingEvent.ToolInterruptEvent.ExecuteID,
					EventID:    resumingEvent.ToolInterruptEvent.ID,
					ResumeData: resumeData,
				}, allIEs))))
		chatModelHandler := callbacks2.NewHandlerHelper().ChatModel(&callbacks2.ModelCallbackHandler{
			OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
				if runInfo.Name != agentModelName {
					return ctx
				}
				// react agent loops back to chat model after resuming,
				// pop the previous interrupt event immediately
				ie, deleted, e := workflow.GetRepository().PopFirstInterruptEvent(ctx, c.RootExecuteID)
				if e != nil {
					// BUGFIX: log the local error e (was the enclosing err, which is nil here)
					logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: %v", e)
					return ctx
				}
				if !deleted {
					logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: not deleted")
					return ctx
				}
				if ie.ID != resumingEvent.ID {
					logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start, "+
						"deleted ID: %d, resumingEvent ID: %d", ie.ID, resumingEvent.ID)
					return ctx
				}
				return ctx
			},
		}).Handler()
		composeOpts = append(composeOpts, compose.WithCallbacks(chatModelHandler))
	}
	if c != nil {
		// propagate the execution config to tools run by the agent
		exeCfg := c.ExeCfg
		composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
	}
	llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
	if llmOpts.toolWorkflowSW != nil {
		// Pipe messages emitted by tool workflows to the caller-provided writer.
		toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
		composeOpts = append(composeOpts, toolMsgOpt)
		safego.Go(ctx, func() {
			defer toolMsgSR.Close()
			for {
				msg, err := toolMsgSR.Recv()
				if err != nil {
					if err == io.EOF {
						return
					}
					logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
					return
				}
				logs.Infof("received message from tool workflow: %+v", msg)
				llmOpts.toolWorkflowSW.Send(msg, nil)
			}
		})
	}
	// Resolve incoming stream sources and cache them for this node's run.
	resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)
	if err != nil {
		return nil, nil, err
	}
	var nodeKey vo.NodeKey
	if c != nil && c.NodeCtx != nil {
		nodeKey = c.NodeCtx.NodeKey
	}
	ctxcache.Store(ctx, fmt.Sprintf(sourceKey, nodeKey), resolvedSources)
	return composeOpts, resumingEvent, nil
}
// handleInterrupt converts an interrupt raised inside the react agent
// sub-graph into a workflow-level InterruptEvent, persisting every tool
// interrupt event into state so any of them can be resumed later.
// If err is not an interrupt error, it is returned unchanged.
//
// Fixes over the previous version: the type-mismatch error for a rerun-map
// element reported the outer extra instead of the element itself; the agent
// sub-graph key now uses the llmNodeKey constant; nil sub-graph info and an
// empty event set no longer cause panics.
func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.InterruptEvent) error {
	info, ok := compose.ExtractInterruptInfo(err)
	if !ok {
		return err
	}
	info = info.SubGraphs[llmNodeKey] // llmNodeKey is the node key of the react agent
	if info == nil {
		// the interrupt did not originate from the react agent sub-graph;
		// nothing for this handler to translate
		return err
	}
	// There is at most one rerun node (the agent's tools node); take its extra.
	var extra any
	for i := range info.RerunNodesExtra {
		extra = info.RerunNodesExtra[i]
		break
	}
	toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
	if !ok {
		return fmt.Errorf("llm rerun node extra type expected to be ToolsInterruptAndRerunExtra, actual: %T", extra)
	}
	id, err := workflow.GetRepository().GenID(ctx)
	if err != nil {
		return err
	}
	var (
		previousInterruptedCallID string
		highPriorityEvent         *entity.ToolInterruptEvent
	)
	if resumingEvent != nil {
		previousInterruptedCallID = resumingEvent.ToolInterruptEvent.ToolCallID
	}
	c := execute.GetExeCtx(ctx)
	toolIEs := make([]*entity.ToolInterruptEvent, 0, len(toolsNodeExtra.RerunExtraMap))
	for callID := range toolsNodeExtra.RerunExtraMap {
		subIE, ok := toolsNodeExtra.RerunExtraMap[callID].(*entity.ToolInterruptEvent)
		if !ok {
			// BUGFIX: report the offending element's type (previously formatted
			// the outer extra, which is always ToolsInterruptAndRerunExtra here)
			return fmt.Errorf("llm rerun node extra type expected to be ToolInterruptEvent, actual: %T",
				toolsNodeExtra.RerunExtraMap[callID])
		}
		if subIE.ExecuteID == 0 {
			subIE.ExecuteID = c.RootExecuteID
		}
		toolIEs = append(toolIEs, subIE)
		// Prefer the event for the tool call that previously interrupted.
		if subIE.ToolCallID == previousInterruptedCallID {
			highPriorityEvent = subIE
		}
	}
	if len(toolIEs) == 0 {
		// Defensive: an interrupt carrying no tool events would otherwise
		// panic on toolIEs[0] below.
		return fmt.Errorf("llm interrupt carries no tool interrupt events")
	}
	ie := &entity.InterruptEvent{
		ID:        id,
		NodeKey:   c.NodeKey,
		NodeType:  entity.NodeTypeLLM,
		NodeTitle: c.NodeName,
		NodeIcon:  entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURL,
		EventType: entity.InterruptEventLLM,
	}
	if highPriorityEvent != nil {
		ie.ToolInterruptEvent = highPriorityEvent
	} else {
		ie.ToolInterruptEvent = toolIEs[0]
	}
	// Persist every tool interrupt event so any of them can be resumed later.
	err = compose.ProcessState(ctx, func(ctx context.Context, ieStore ToolInterruptEventStore) error {
		for i := range toolIEs {
			e := ieStore.SetToolInterruptEvent(c.NodeKey, toolIEs[i].ToolCallID, toolIEs[i])
			if e != nil {
				return e
			}
		}
		return nil
	})
	if err != nil {
		return err
	}
	return compose.NewInterruptAndRerunErr(ie)
}
// Invoke runs the LLM node once and returns the complete output map.
// Interrupt errors raised by tool execution are translated into
// workflow-level interrupt events before being returned.
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
	composeOpts, resumingEvent, prepErr := l.prepare(ctx, in, opts...)
	if prepErr != nil {
		return nil, prepErr
	}
	result, runErr := l.r.Invoke(ctx, in, composeOpts...)
	if runErr != nil {
		return nil, handleInterrupt(ctx, runErr, resumingEvent)
	}
	return result, nil
}
// Stream runs the LLM node once, returning the output as a stream of partial
// maps. Interrupt errors raised by tool execution are translated into
// workflow-level interrupt events before being returned.
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
	composeOpts, resumingEvent, prepErr := l.prepare(ctx, in, opts...)
	if prepErr != nil {
		return nil, prepErr
	}
	stream, runErr := l.r.Stream(ctx, in, composeOpts...)
	if runErr != nil {
		return nil, handleInterrupt(ctx, runErr, resumingEvent)
	}
	return stream, nil
}
// injectKnowledgeTool adds the knowledge-recall branch to the inner graph:
// an intent-recognition template+model picks the relevant knowledge bases,
// then a lambda retrieves matching slices and renders them into the
// knowledgeUserPromptTemplateKey placeholder of the user prompt.
// Edges added: START -> intent template -> intent model -> recall lambda.
func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map[string]any], userPrompt string, cfg *KnowledgeRecallConfig) error {
	// Serialize the candidate knowledge details for the intent prompt.
	selectedKwDetails, err := sonic.MarshalString(cfg.SelectedKnowledgeDetails)
	if err != nil {
		return err
	}
	_ = g.AddChatTemplateNode(knowledgeTemplateKey,
		prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage(fmt.Sprintf(knowledgeIntentPrompt, selectedKwDetails, userPrompt)),
		), compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
			// Buffer the node inputs into state so the downstream chat
			// template can re-read them after this branch runs.
			for k, v := range in {
				state[k] = v
			}
			return in, nil
		}))
	_ = g.AddChatModelNode(knowledgeChatModelKey, cfg.ChatModel)
	_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
		// The intent model answers with comma-separated knowledge ids (or "0").
		modelPredictionIDs := strings.Split(input.Content, ",")
		selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
			return strconv.Itoa(int(e.ID)), e.ID
		})
		// Keep only predictions that match configured knowledge bases.
		recallKnowledgeIDs := make([]int64, 0)
		for _, id := range modelPredictionIDs {
			if kid, ok := selectKwIDs[id]; ok {
				recallKnowledgeIDs = append(recallKnowledgeIDs, kid)
			}
		}
		if len(recallKnowledgeIDs) == 0 {
			// nothing selected: the prompt placeholder stays empty
			return make(map[string]any), nil
		}
		docs, err := crossknowledge.DefaultSVC().Retrieve(ctx, &knowledge.RetrieveRequest{
			Query:        userPrompt,
			KnowledgeIDs: recallKnowledgeIDs,
			Strategy:     cfg.RetrievalStrategy.RetrievalStrategy,
		})
		if err != nil {
			return nil, err
		}
		if len(docs.RetrieveSlices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
			return make(map[string]any), nil
		}
		sb := strings.Builder{}
		// Customize mode: present the configured fallback prompt as a slice.
		if len(docs.RetrieveSlices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
			sb.WriteString("recall slice 1: \n")
			sb.WriteString(cfg.RetrievalStrategy.NoReCallReplyCustomizePrompt + "\n")
		}
		for idx, msg := range docs.RetrieveSlices {
			sb.WriteString(fmt.Sprintf("recall slice %d:\n", idx+1))
			sb.WriteString(fmt.Sprintf("%s\n", msg.Slice.GetSliceContent()))
		}
		output = map[string]any{
			knowledgeUserPromptTemplateKey: fmt.Sprintf(knowledgeUserPromptTemplate, sb.String()),
		}
		return output, nil
	}))
	_ = g.AddEdge(compose.START, knowledgeTemplateKey)
	_ = g.AddEdge(knowledgeTemplateKey, knowledgeChatModelKey)
	_ = g.AddEdge(knowledgeChatModelKey, knowledgeLambdaKey)
	return nil
}
// ToolInterruptEventStore persists tool interrupt events in workflow state,
// keyed by the LLM node and the interrupting tool call.
type ToolInterruptEventStore interface {
	// SetToolInterruptEvent records an interrupt event for a tool call of the given LLM node.
	SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error
	// GetToolInterruptEvents returns all pending interrupt events for the given LLM node.
	GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error)
	// ResumeToolInterruptEvent consumes the event for the tool call and returns its resume data.
	ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error)
}
// ToCallbackOutput builds the structured callback output for this node run.
// When jsonParse cached a raw (unparseable) model output for this node, the
// raw text replaces the RawOutput, and a cached parse warning, if any, is
// attached as the Error.
func (l *LLM) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
	exeCtx := execute.GetExeCtx(ctx)
	if exeCtx == nil {
		// No execution context: nothing could have been cached.
		return &nodes.StructuredCallbackOutput{
			Output:    output,
			RawOutput: output,
		}, nil
	}
	rawOutput, hasRaw := ctxcache.Get[string](ctx, fmt.Sprintf(rawOutputKey, exeCtx.NodeKey))
	if !hasRaw {
		return &nodes.StructuredCallbackOutput{
			Output:    output,
			RawOutput: output,
		}, nil
	}
	sco := &nodes.StructuredCallbackOutput{
		Output:    output,
		RawOutput: map[string]any{"output": rawOutput},
	}
	if warning, hasWarn := ctxcache.Get[vo.WorkflowError](ctx, fmt.Sprintf(warningKey, exeCtx.NodeKey)); hasWarn {
		sco.Error = warning
	}
	return sco, nil
}