feat: manually mirror opencoze's code from bytedance
Change-Id: I09a73aadda978ad9511264a756b2ce51f5761adf
This commit is contained in:
847
backend/domain/workflow/internal/nodes/llm/llm.go
Normal file
847
backend/domain/workflow/internal/nodes/llm/llm.go
Normal file
@@ -0,0 +1,847 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
"github.com/cloudwego/eino-ext/components/model/deepseek"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow"
|
||||
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/logs"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/safego"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
"github.com/coze-dev/coze-studio/backend/types/errno"
|
||||
)
|
||||
|
||||
// Format enumerates the output formats the LLM node can be asked to produce.
type Format int

const (
	// FormatText returns the raw model text.
	FormatText Format = iota
	// FormatMarkdown appends markdownPrompt to steer markdown output.
	FormatMarkdown
	// FormatJSON appends jsonPromptFormat (with a JSON schema) and parses the reply.
	FormatJSON
)

const (
	// jsonPromptFormat is appended to the user prompt for FormatJSON;
	// %s is filled with the JSON schema derived from the output fields.
	jsonPromptFormat = `
Strictly reply in valid JSON format.
- Ensure the output strictly conforms to the JSON schema below
- Do not include explanations, comments, or any text outside the JSON.

Here is the output JSON schema:
'''
%s
'''
`
	// markdownPrompt is appended to the user prompt for FormatMarkdown.
	markdownPrompt = `
Strictly reply in valid Markdown format.
- For headings, use number signs (#).
- For list items, start with dashes (-).
- To emphasize text, wrap it with asterisks (*).
- For code or commands, surround them with backticks (` + "`" + `).
- For quoted text, use greater than signs (>).
- For links, wrap the text in square brackets [], followed by the URL in parentheses ().
- For images, use square brackets [] for the alt text, followed by the image URL in parentheses ().

`
)

const (
	// ReasoningOutputKey is the output-field name carrying the model's
	// reasoning content (see getReasoningContent).
	ReasoningOutputKey = "reasoning_content"
)
|
||||
|
||||
// knowledgeUserPromptTemplate wraps recalled knowledge slices (%s) and is
// prepended to the user prompt when knowledge recall produced results.
// NOTE(review): the image-output examples appear to contain empty markdown
// image placeholders ("" / blank output) — possibly mangled; confirm against
// the upstream source.
const knowledgeUserPromptTemplate = `根据引用的内容回答问题:
1.如果引用的内容里面包含 <img src=""> 的标签, 标签里的 src 字段表示图片地址, 需要在回答问题的时候展示出去, 输出格式为"" 。
2.如果引用的内容不包含 <img src=""> 的标签, 你回答问题时不需要展示图片 。
例如:
如果内容为<img src="https://example.com/image.jpg">一只小猫,你的输出应为:。
如果内容为<img src="https://example.com/image1.jpg">一只小猫 和 <img src="https://example.com/image2.jpg">一只小狗 和 <img src="https://example.com/image3.jpg">一只小牛,你的输出应为: 和  和 
you can refer to the following content and do relevant searches to improve:
---
%s

question is:

`

// knowledgeIntentPrompt instructs an intent-recognition model to pick
// knowledge-base IDs. First %s receives the serialized knowledge list,
// second %s the system prompt. The model must answer with comma-separated
// knowledge IDs or 0 (parsed in injectKnowledgeTool).
const knowledgeIntentPrompt = `
# 角色:
你是一个知识库意图识别AI Agent。
## 目标:
- 按照「系统提示词」、用户需求、最新的聊天记录选择应该使用的知识库。
## 工作流程:
1. 分析「系统提示词」以确定用户的具体需求。
2. 如果「系统提示词」明确指明了要使用的知识库,则直接返回这些知识库,只输出它们的knowledge_id,不需要再判断用户的输入
3. 检查每个知识库的knowledge_name和knowledge_description,以了解它们各自的功能。
4. 根据用户需求,选择最符合的知识库。
5. 如果找到一个或多个合适的知识库,输出它们的knowledge_id。如果没有合适的知识库,输出0。
## 约束:
- 严格按照「系统提示词」和用户的需求选择知识库。「系统提示词」的优先级大于用户的需求
- 如果有多个合适的知识库,将它们的knowledge_id用英文逗号连接后输出。
- 输出必须仅为knowledge_id或0,不得包括任何其他内容或解释,不要在id后面输出知识库名称。

## 输出示例
123,456

## 输出格式:
输出应该是一个纯数字或者由英文逗号连接的数字序列,具体取决于选择的知识库数量。不应包含任何其他文本或格式。
## 知识库列表如下
%s
## 「系统提示词」如下
%s
`

// Graph node keys and the template variable used for knowledge recall.
const (
	knowledgeTemplateKey           = "knowledge_template"
	knowledgeChatModelKey          = "knowledge_chat_model"
	knowledgeLambdaKey             = "knowledge_lambda"
	knowledgeUserPromptTemplateKey = "knowledge_user_prompt_prefix"
	templateNodeKey                = "template"
	llmNodeKey                     = "llm"
	outputConvertNodeKey           = "output_convert"
)
|
||||
|
||||
// NoReCallReplyMode controls what happens when knowledge retrieval recalls
// no slices.
type NoReCallReplyMode int64

const (
	// NoReCallReplyModeOfDefault: no recall result is injected at all.
	NoReCallReplyModeOfDefault NoReCallReplyMode = 0
	// NoReCallReplyModeOfCustomize: a customized prompt is injected as a
	// synthetic recall slice (see injectKnowledgeTool).
	NoReCallReplyModeOfCustomize NoReCallReplyMode = 1
)

// RetrievalStrategy bundles the cross-domain retrieval strategy with the
// behavior to apply when nothing is recalled.
type RetrievalStrategy struct {
	RetrievalStrategy            *crossknowledge.RetrievalStrategy
	NoReCallReplyMode            NoReCallReplyMode
	NoReCallReplyCustomizePrompt string
}

// KnowledgeRecallConfig configures the optional knowledge-recall sub-graph.
type KnowledgeRecallConfig struct {
	// ChatModel is the intent-recognition model that picks knowledge IDs.
	ChatModel model.BaseChatModel
	// Retriever performs the actual knowledge retrieval.
	Retriever crossknowledge.KnowledgeOperator
	// RetrievalStrategy controls retrieval and the no-recall fallback.
	RetrievalStrategy *RetrievalStrategy
	// SelectedKnowledgeDetails is the candidate knowledge-base list the
	// intent model chooses from.
	SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
}
|
||||
|
||||
// Config assembles everything needed to build an LLM workflow node.
type Config struct {
	ChatModel    ModelWithInfo
	Tools        []tool.BaseTool // non-empty Tools switches New to the react-agent path
	SystemPrompt string
	UserPrompt   string
	OutputFormat Format
	InputFields  map[string]*vo.TypeInfo
	OutputFields map[string]*vo.TypeInfo
	// ToolsReturnDirectly marks tool names whose results are returned
	// directly without another model round-trip.
	ToolsReturnDirectly map[string]bool
	// KnowledgeRecallConfig, when non-nil, adds the knowledge-recall
	// sub-graph in front of the prompt template.
	KnowledgeRecallConfig *KnowledgeRecallConfig
	FullSources           map[string]*nodes.SourceInfo
}

// LLM is the compiled workflow LLM node.
type LLM struct {
	// r is the compiled graph: template -> model/agent -> output convert.
	r compose.Runnable[map[string]any, map[string]any]

	outputFormat Format
	// outputFields — NOTE(review): never assigned in the visible code
	// (New does not set it); confirm whether it is intentionally unused.
	outputFields map[string]*vo.TypeInfo
	// canStream is false when the output must be parsed as JSON.
	canStream bool
	// requireCheckpoint is true when tools are configured.
	requireCheckpoint bool
	fullSources       map[string]*nodes.SourceInfo
}

// ctxcache key formats, parameterized by node key via fmt.Sprintf.
const (
	rawOutputKey = "llm_raw_output_%s"
	warningKey   = "llm_warning_%s"
)
|
||||
|
||||
// jsonParse extracts the JSON payload from raw model output and converts it
// into a typed map according to schema_.
//
// On unmarshal failure inside a workflow execution context the error is
// downgraded to a warning: the raw output and a warning are cached in the
// context (keyed by node key) and an empty map is returned with a nil error,
// so the node does not fail hard. Outside an execution context the parse
// error is returned as-is.
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
	// Strip any surrounding non-JSON text (e.g. code fences) first.
	data = nodes.ExtractJSONString(data)

	var result map[string]any

	err := sonic.UnmarshalString(data, &result)
	if err != nil {
		c := execute.GetExeCtx(ctx)
		if c != nil {
			logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
			// Cache the raw output and the warning so ToCallbackOutput can
			// surface them later.
			rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
			warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
			ctxcache.Store(ctx, rawOutputK, data)
			ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
			return map[string]any{}, nil
		}

		return nil, err
	}

	// Coerce the untyped map into the declared output types.
	r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
	if err != nil {
		return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
	}

	if ws != nil {
		// Conversion warnings are logged but do not fail the node.
		logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
	}

	return r, nil
}
|
||||
|
||||
func getReasoningContent(message *schema.Message) string {
|
||||
c, ok := deepseek.GetReasoningContent(message)
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
c, ok = ark.GetReasoningContent(message)
|
||||
if ok {
|
||||
return c
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Options carries per-invocation options for the LLM node.
type Options struct {
	// nested holds options forwarded to nested workflow executions.
	nested []nodes.NestedWorkflowOption
	// toolWorkflowSW, when set, receives messages produced by tool
	// workflows during execution (see prepare).
	toolWorkflowSW *schema.StreamWriter[*entity.Message]
}

// Option mutates an Options value.
type Option func(o *Options)

// WithNestedWorkflowOptions appends options to forward to nested workflows.
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
	return func(o *Options) {
		o.nested = append(o.nested, nested...)
	}
}

// WithToolWorkflowMessageWriter sets the stream writer that receives
// messages emitted by tool workflows.
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
	return func(o *Options) {
		o.toolWorkflowSW = sw
	}
}
|
||||
|
||||
// llmState is the graph-local state used to carry input variables from the
// knowledge-recall nodes to the prompt-template node.
type llmState = map[string]any

// agentModelName names the chat-model node inside the react agent; prepare
// uses it to filter model callbacks to that node only.
const agentModelName = "agent_model"
|
||||
|
||||
// New builds and compiles the LLM node graph:
//
//	[knowledge recall ->] prompt template -> chat model / react agent -> output convert
//
// The compiled runnable maps an input variable map to an output map shaped
// by cfg.OutputFields. Returns an error if schema generation, agent
// construction, or graph compilation fails.
func New(ctx context.Context, cfg *Config) (*LLM, error) {
	g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
		return llmState{}
	}))

	var (
		hasReasoning bool
		canStream    = true
	)

	// Downgrade FormatJSON to FormatText when the declared output is just a
	// single string field (optionally plus a reasoning_content string field):
	// there is nothing structured to parse in that case.
	format := cfg.OutputFormat
	if format == FormatJSON {
		if len(cfg.OutputFields) == 1 {
			for _, v := range cfg.OutputFields {
				if v.Type == vo.DataTypeString {
					format = FormatText
					break
				}
			}
		} else if len(cfg.OutputFields) == 2 {
			if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
				for k, v := range cfg.OutputFields {
					if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
						format = FormatText
						break
					}
				}
			}
		}
	}

	// Append format-steering instructions to the user prompt.
	userPrompt := cfg.UserPrompt
	switch format {
	case FormatJSON:
		jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
		if err != nil {
			return nil, err
		}

		jsonPrompt := fmt.Sprintf(jsonPromptFormat, jsonSchema)
		userPrompt = userPrompt + jsonPrompt
	case FormatMarkdown:
		userPrompt = userPrompt + markdownPrompt
	case FormatText:
		// no extra instructions
	}

	if cfg.KnowledgeRecallConfig != nil {
		// Knowledge path: recall nodes run first; their result is injected
		// into the user prompt via the knowledgeUserPromptTemplateKey variable.
		err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
		if err != nil {
			return nil, err
		}
		userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)

		inputs := maps.Clone(cfg.InputFields)
		inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
			Type: vo.DataTypeString,
		}
		sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
		up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
		template := newPrompts(sp, up, cfg.ChatModel)

		_ = g.AddChatTemplateNode(templateNodeKey, template,
			// Merge the original graph input (stashed in state by the
			// knowledge template node) back into the template input.
			compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
				for k, v := range state {
					in[k] = v
				}
				return in, nil
			}))
		_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)

	} else {
		// Plain path: graph input feeds the template directly.
		sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
		up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
		template := newPrompts(sp, up, cfg.ChatModel)
		_ = g.AddChatTemplateNode(templateNodeKey, template)

		_ = g.AddEdge(compose.START, templateNodeKey)
	}

	if len(cfg.Tools) > 0 {
		// Tools configured: wrap the model in a react agent sub-graph.
		m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
		if !ok {
			return nil, errors.New("requires a ToolCallingChatModel to use with tools")
		}
		reactConfig := react.AgentConfig{
			ToolCallingModel: m,
			ToolsConfig:      compose.ToolsNodeConfig{Tools: cfg.Tools},
			ModelNodeName:    agentModelName,
		}

		if len(cfg.ToolsReturnDirectly) > 0 {
			reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
			for k := range cfg.ToolsReturnDirectly {
				reactConfig.ToolReturnDirectly[k] = struct{}{}
			}
		}

		reactAgent, err := react.NewAgent(ctx, &reactConfig)
		if err != nil {
			return nil, err
		}

		agentNode, opts := reactAgent.ExportGraph()
		opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
		_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
	} else {
		_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
	}

	_ = g.AddEdge(templateNodeKey, llmNodeKey)

	if format == FormatJSON {
		// JSON output: parse the full message content; streaming is disabled.
		iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
			return jsonParse(ctx, msg.Content, cfg.OutputFields)
		}

		convertNode := compose.InvokableLambda(iConvert)

		_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)

		canStream = false
	} else {
		// Text/markdown output: at this point the output fields can only be
		// one string field, optionally plus reasoning_content (enforced by
		// the format-downgrade logic above) — hence the "impossible" panics.
		var outputKey string
		if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
			panic("impossible")
		}

		for k, v := range cfg.OutputFields {
			if v.Type != vo.DataTypeString {
				panic("impossible")
			}

			if k == ReasoningOutputKey {
				hasReasoning = true
			} else {
				outputKey = k
			}
		}

		iConvert := func(_ context.Context, msg *schema.Message, _ ...struct{}) (map[string]any, error) {
			out := map[string]any{outputKey: msg.Content}
			if hasReasoning {
				out[ReasoningOutputKey] = getReasoningContent(msg)
			}
			return out, nil
		}

		// Streaming conversion: forward reasoning chunks first, then content
		// chunks, emitting KeyIsFinished markers at each transition.
		tConvert := func(_ context.Context, s *schema.StreamReader[*schema.Message], _ ...struct{}) (*schema.StreamReader[map[string]any], error) {
			sr, sw := schema.Pipe[map[string]any](0)

			safego.Go(ctx, func() {
				reasoningDone := false
				for {
					msg, err := s.Recv()
					if err != nil {
						if err == io.EOF {
							// Mark the content stream finished, then close.
							sw.Send(map[string]any{
								outputKey: nodes.KeyIsFinished,
							}, nil)
							sw.Close()
							return
						}

						sw.Send(nil, err)
						sw.Close()
						return
					}

					if hasReasoning {
						reasoning := getReasoningContent(msg)
						if len(reasoning) > 0 {
							sw.Send(map[string]any{ReasoningOutputKey: reasoning}, nil)
						}
					}

					if len(msg.Content) > 0 {
						// First content chunk ends the reasoning phase.
						if !reasoningDone && hasReasoning {
							reasoningDone = true
							sw.Send(map[string]any{
								ReasoningOutputKey: nodes.KeyIsFinished,
							}, nil)
						}
						sw.Send(map[string]any{outputKey: msg.Content}, nil)
					}
				}
			})

			return sr, nil
		}

		convertNode, err := compose.AnyLambda(iConvert, nil, nil, tConvert)
		if err != nil {
			return nil, err
		}

		_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
	}

	_ = g.AddEdge(llmNodeKey, outputConvertNodeKey)
	_ = g.AddEdge(outputConvertNodeKey, compose.END)

	// Tool calls can interrupt/resume, which requires checkpointing.
	requireCheckpoint := false
	if len(cfg.Tools) > 0 {
		requireCheckpoint = true
	}

	var opts []compose.GraphCompileOption
	if requireCheckpoint {
		opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
	}
	opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))

	r, err := g.Compile(ctx, opts...)
	if err != nil {
		return nil, err
	}

	llm := &LLM{
		r:                 r,
		outputFormat:      format,
		canStream:         canStream,
		requireCheckpoint: requireCheckpoint,
		fullSources:       cfg.FullSources,
	}

	return llm, nil
}
|
||||
|
||||
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c != nil {
|
||||
resumingEvent = c.NodeCtx.ResumingEvent
|
||||
}
|
||||
var previousToolES map[string]*entity.ToolInterruptEvent
|
||||
|
||||
if c != nil && c.RootCtx.ResumeEvent != nil {
|
||||
// check if we are not resuming, but previously interrupted. Interrupt immediately.
|
||||
if resumingEvent == nil {
|
||||
err := compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
var e error
|
||||
previousToolES, e = state.GetToolInterruptEvents(c.NodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(previousToolES) > 0 {
|
||||
return nil, nil, compose.InterruptAndRerun
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if l.requireCheckpoint && c != nil {
|
||||
checkpointID := fmt.Sprintf("%d_%s", c.RootCtx.RootExecuteID, c.NodeCtx.NodeKey)
|
||||
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
|
||||
}
|
||||
|
||||
llmOpts := &Options{}
|
||||
for _, opt := range opts {
|
||||
opt(llmOpts)
|
||||
}
|
||||
|
||||
nestedOpts := &nodes.NestedWorkflowOptions{}
|
||||
for _, opt := range llmOpts.nested {
|
||||
opt(nestedOpts)
|
||||
}
|
||||
|
||||
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
|
||||
|
||||
if resumingEvent != nil {
|
||||
var (
|
||||
resumeData string
|
||||
e error
|
||||
allIEs = make(map[string]*entity.ToolInterruptEvent)
|
||||
)
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, state ToolInterruptEventStore) error {
|
||||
allIEs, e = state.GetToolInterruptEvents(c.NodeKey)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
allIEs = maps.Clone(allIEs)
|
||||
|
||||
resumeData, e = state.ResumeToolInterruptEvent(c.NodeKey, resumingEvent.ToolInterruptEvent.ToolCallID)
|
||||
|
||||
return e
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(
|
||||
compose.WithToolOption(
|
||||
execute.WithResume(&entity.ResumeRequest{
|
||||
ExecuteID: resumingEvent.ToolInterruptEvent.ExecuteID,
|
||||
EventID: resumingEvent.ToolInterruptEvent.ID,
|
||||
ResumeData: resumeData,
|
||||
}, allIEs))))
|
||||
|
||||
chatModelHandler := callbacks2.NewHandlerHelper().ChatModel(&callbacks2.ModelCallbackHandler{
|
||||
OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context {
|
||||
if runInfo.Name != agentModelName {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// react agent loops back to chat model after resuming,
|
||||
// pop the previous interrupt event immediately
|
||||
ie, deleted, e := workflow.GetRepository().PopFirstInterruptEvent(ctx, c.RootExecuteID)
|
||||
if e != nil {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: %v", err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start: not deleted")
|
||||
return ctx
|
||||
}
|
||||
|
||||
if ie.ID != resumingEvent.ID {
|
||||
logs.CtxErrorf(ctx, "failed to pop first interrupt event on react agent chatmodel start, "+
|
||||
"deleted ID: %d, resumingEvent ID: %d", ie.ID, resumingEvent.ID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
return ctx
|
||||
},
|
||||
}).Handler()
|
||||
|
||||
composeOpts = append(composeOpts, compose.WithCallbacks(chatModelHandler))
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
exeCfg := c.ExeCfg
|
||||
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
|
||||
}
|
||||
|
||||
if llmOpts.toolWorkflowSW != nil {
|
||||
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
|
||||
composeOpts = append(composeOpts, toolMsgOpt)
|
||||
|
||||
safego.Go(ctx, func() {
|
||||
defer toolMsgSR.Close()
|
||||
for {
|
||||
msg, err := toolMsgSR.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
logs.CtxErrorf(ctx, "failed to receive message from tool workflow: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logs.Infof("received message from tool workflow: %+v", msg)
|
||||
|
||||
llmOpts.toolWorkflowSW.Send(msg, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
resolvedSources, err := nodes.ResolveStreamSources(ctx, l.fullSources)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var nodeKey vo.NodeKey
|
||||
if c != nil && c.NodeCtx != nil {
|
||||
nodeKey = c.NodeCtx.NodeKey
|
||||
}
|
||||
ctxcache.Store(ctx, fmt.Sprintf(sourceKey, nodeKey), resolvedSources)
|
||||
|
||||
return composeOpts, resumingEvent, nil
|
||||
}
|
||||
|
||||
func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.InterruptEvent) error {
|
||||
info, ok := compose.ExtractInterruptInfo(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
info = info.SubGraphs["llm"] // 'llm' is the node key of the react agent
|
||||
var extra any
|
||||
for i := range info.RerunNodesExtra {
|
||||
extra = info.RerunNodesExtra[i]
|
||||
break
|
||||
}
|
||||
|
||||
toolsNodeExtra, ok := extra.(*compose.ToolsInterruptAndRerunExtra)
|
||||
if !ok {
|
||||
return fmt.Errorf("llm rerun node extra type expected to be ToolsInterruptAndRerunExtra, actual: %T", extra)
|
||||
}
|
||||
id, err := workflow.GetRepository().GenID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
previousInterruptedCallID string
|
||||
highPriorityEvent *entity.ToolInterruptEvent
|
||||
)
|
||||
if resumingEvent != nil {
|
||||
previousInterruptedCallID = resumingEvent.ToolInterruptEvent.ToolCallID
|
||||
}
|
||||
|
||||
c := execute.GetExeCtx(ctx)
|
||||
|
||||
toolIEs := make([]*entity.ToolInterruptEvent, 0, len(toolsNodeExtra.RerunExtraMap))
|
||||
for callID := range toolsNodeExtra.RerunExtraMap {
|
||||
subIE, ok := toolsNodeExtra.RerunExtraMap[callID].(*entity.ToolInterruptEvent)
|
||||
if !ok {
|
||||
return fmt.Errorf("llm rerun node extra type expected to be ToolInterruptEvent, actual: %T", extra)
|
||||
}
|
||||
|
||||
if subIE.ExecuteID == 0 {
|
||||
subIE.ExecuteID = c.RootExecuteID
|
||||
}
|
||||
|
||||
toolIEs = append(toolIEs, subIE)
|
||||
if subIE.ToolCallID == previousInterruptedCallID {
|
||||
highPriorityEvent = subIE
|
||||
}
|
||||
}
|
||||
|
||||
ie := &entity.InterruptEvent{
|
||||
ID: id,
|
||||
NodeKey: c.NodeKey,
|
||||
NodeType: entity.NodeTypeLLM,
|
||||
NodeTitle: c.NodeName,
|
||||
NodeIcon: entity.NodeMetaByNodeType(entity.NodeTypeLLM).IconURL,
|
||||
EventType: entity.InterruptEventLLM,
|
||||
}
|
||||
|
||||
if highPriorityEvent != nil {
|
||||
ie.ToolInterruptEvent = highPriorityEvent
|
||||
} else {
|
||||
ie.ToolInterruptEvent = toolIEs[0]
|
||||
}
|
||||
|
||||
err = compose.ProcessState(ctx, func(ctx context.Context, ieStore ToolInterruptEventStore) error {
|
||||
for i := range toolIEs {
|
||||
e := ieStore.SetToolInterruptEvent(c.NodeKey, toolIEs[i].ToolCallID, toolIEs[i])
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return compose.NewInterruptAndRerunErr(ie)
|
||||
}
|
||||
|
||||
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err = l.r.Invoke(ctx, in, composeOpts...)
|
||||
if err != nil {
|
||||
err = handleInterrupt(ctx, err, resumingEvent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
|
||||
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out, err = l.r.Stream(ctx, in, composeOpts...)
|
||||
if err != nil {
|
||||
err = handleInterrupt(ctx, err, resumingEvent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// injectKnowledgeTool adds the knowledge-recall sub-graph to g:
//
//	START -> knowledge_template -> knowledge_chat_model -> knowledge_lambda
//
// The template asks the intent model to select knowledge-base IDs; the
// lambda validates those IDs against the configured candidates, retrieves
// matching slices, and produces the knowledgeUserPromptTemplateKey variable
// consumed by the main prompt template (edge added by New).
func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map[string]any], userPrompt string, cfg *KnowledgeRecallConfig) error {
	selectedKwDetails, err := sonic.MarshalString(cfg.SelectedKnowledgeDetails)
	if err != nil {
		return err
	}
	_ = g.AddChatTemplateNode(knowledgeTemplateKey,
		prompt.FromMessages(schema.Jinja2,
			schema.SystemMessage(fmt.Sprintf(knowledgeIntentPrompt, selectedKwDetails, userPrompt)),
		), // Stash the full graph input in state so the main template node
		// can restore it later (its input comes from knowledge_lambda).
		compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
			for k, v := range in {
				state[k] = v
			}
			return in, nil
		}))
	_ = g.AddChatModelNode(knowledgeChatModelKey, cfg.ChatModel)

	_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
		// The intent model replies with comma-separated knowledge IDs.
		modelPredictionIDs := strings.Split(input.Content, ",")
		selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
			return strconv.Itoa(int(e.ID)), e.ID
		})
		// Keep only predicted IDs that belong to the configured candidates.
		recallKnowledgeIDs := make([]int64, 0)
		for _, id := range modelPredictionIDs {
			if kid, ok := selectKwIDs[id]; ok {
				recallKnowledgeIDs = append(recallKnowledgeIDs, kid)
			}
		}

		if len(recallKnowledgeIDs) == 0 {
			return make(map[string]any), nil
		}

		docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
			Query:             userPrompt,
			KnowledgeIDs:      recallKnowledgeIDs,
			RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
		})
		if err != nil {
			return nil, err
		}

		// Nothing recalled + default mode: inject no recall content at all.
		if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfDefault {
			return make(map[string]any), nil
		}

		sb := strings.Builder{}
		// Nothing recalled + customize mode: inject the customized prompt
		// as a synthetic first slice.
		if len(docs.Slices) == 0 && cfg.RetrievalStrategy.NoReCallReplyMode == NoReCallReplyModeOfCustomize {
			sb.WriteString("recall slice 1: \n")
			sb.WriteString(cfg.RetrievalStrategy.NoReCallReplyCustomizePrompt + "\n")
		}

		for idx, msg := range docs.Slices {
			sb.WriteString(fmt.Sprintf("recall slice %d:\n", idx+1))
			sb.WriteString(fmt.Sprintf("%s\n", msg.Output))
		}

		output = map[string]any{
			knowledgeUserPromptTemplateKey: fmt.Sprintf(knowledgeUserPromptTemplate, sb.String()),
		}

		return output, nil
	}))
	_ = g.AddEdge(compose.START, knowledgeTemplateKey)
	_ = g.AddEdge(knowledgeTemplateKey, knowledgeChatModelKey)
	_ = g.AddEdge(knowledgeChatModelKey, knowledgeLambdaKey)
	return nil
}
|
||||
|
||||
// ToolInterruptEventStore persists tool interrupt events per LLM node so
// that an interrupted tool call can be resumed later.
type ToolInterruptEventStore interface {
	// SetToolInterruptEvent stores ie under (llmNodeKey, toolCallID).
	SetToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string, ie *entity.ToolInterruptEvent) error
	// GetToolInterruptEvents returns all stored events for the node,
	// keyed by tool call ID.
	GetToolInterruptEvents(llmNodeKey vo.NodeKey) (map[string]*entity.ToolInterruptEvent, error)
	// ResumeToolInterruptEvent consumes the event for the given tool call
	// and returns its resume data.
	ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error)
}
|
||||
|
||||
func (l *LLM) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
|
||||
c := execute.GetExeCtx(ctx)
|
||||
if c == nil {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
}, nil
|
||||
}
|
||||
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeKey)
|
||||
warningK := fmt.Sprintf(warningKey, c.NodeKey)
|
||||
rawOutput, found := ctxcache.Get[string](ctx, rawOutputK)
|
||||
if !found {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: output,
|
||||
}, nil
|
||||
}
|
||||
|
||||
warning, found := ctxcache.Get[vo.WorkflowError](ctx, warningK)
|
||||
if !found {
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: map[string]any{"output": rawOutput},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &nodes.StructuredCallbackOutput{
|
||||
Output: output,
|
||||
RawOutput: map[string]any{"output": rawOutput},
|
||||
Error: warning,
|
||||
}, nil
|
||||
}
|
||||
184
backend/domain/workflow/internal/nodes/llm/model_with_info.go
Normal file
184
backend/domain/workflow/internal/nodes/llm/model_with_info.go
Normal file
@@ -0,0 +1,184 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
crossmodelmgr "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
)
|
||||
|
||||
// ModelWithInfo is a chat model that can also report its model-manager
// metadata.
type ModelWithInfo interface {
	model.BaseChatModel
	// Info returns metadata describing the underlying model.
	Info(ctx context.Context) *crossmodelmgr.Model
}
|
||||
|
||||
// ModelForLLM wraps a primary chat model with an optional fallback model.
// UseFallback decides per call which one handles the request; when the
// chosen model does not emit its own callbacks, the wrapper fires
// OnStart/OnEnd/OnError manually (see Generate/Stream).
type ModelForLLM struct {
	Model model.BaseChatModel   // primary model
	MInfo *crossmodelmgr.Model  // primary model metadata

	FallbackModel model.BaseChatModel  // optional fallback model
	FallbackInfo  *crossmodelmgr.Model // fallback model metadata

	// UseFallback reports whether this call should use FallbackModel.
	UseFallback func(ctx context.Context) bool

	// Whether each model already emits eino component callbacks itself.
	modelEnableCallback    bool
	fallbackEnableCallback bool
}
|
||||
|
||||
// NewModel wraps a single chat model (no fallback); UseFallback always
// reports false.
func NewModel(m model.BaseChatModel, info *crossmodelmgr.Model) *ModelForLLM {
	return &ModelForLLM{
		Model: m,
		MInfo: info,
		UseFallback: func(ctx context.Context) bool {
			return false
		},

		modelEnableCallback: components.IsCallbacksEnabled(m),
	}
}
|
||||
|
||||
// NewModelWithFallback wraps a primary model m with fallback f. The
// fallback is used on node retries (CurrentRetryCount > 0 in the workflow
// execution context); outside an execution context the primary is used.
func NewModelWithFallback(m, f model.BaseChatModel, info, fInfo *crossmodelmgr.Model) *ModelForLLM {
	return &ModelForLLM{
		Model:         m,
		MInfo:         info,
		FallbackModel: f,
		FallbackInfo:  fInfo,
		UseFallback: func(ctx context.Context) bool {
			exeCtx := execute.GetExeCtx(ctx)
			if exeCtx == nil || exeCtx.NodeCtx == nil {
				return false
			}

			// Any retry switches to the fallback model.
			return exeCtx.CurrentRetryCount > 0
		},

		modelEnableCallback:    components.IsCallbacksEnabled(m),
		fallbackEnableCallback: components.IsCallbacksEnabled(f),
	}
}
|
||||
|
||||
func (m *ModelForLLM) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (
|
||||
output *schema.Message, err error) {
|
||||
if m.UseFallback(ctx) {
|
||||
if !m.fallbackEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_ = callbacks.OnEnd(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.FallbackModel.Generate(ctx, input, opts...)
|
||||
}
|
||||
|
||||
if !m.modelEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_ = callbacks.OnEnd(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.Model.Generate(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *ModelForLLM) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (
|
||||
output *schema.StreamReader[*schema.Message], err error) {
|
||||
if m.UseFallback(ctx) {
|
||||
if !m.fallbackEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.FallbackModel.Stream(ctx, input, opts...)
|
||||
}
|
||||
|
||||
if !m.modelEnableCallback {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
} else {
|
||||
_, output = callbacks.OnEndWithStreamOutput(ctx, output)
|
||||
}
|
||||
}()
|
||||
ctx = callbacks.OnStart(ctx, input)
|
||||
}
|
||||
return m.Model.Stream(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *ModelForLLM) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
toolModel, ok := m.Model.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
|
||||
var err error
|
||||
toolModel, err = toolModel.WithTools(tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fallbackToolModel model.ToolCallingChatModel
|
||||
if m.FallbackModel != nil {
|
||||
fallbackToolModel, ok = m.FallbackModel.(model.ToolCallingChatModel)
|
||||
if !ok {
|
||||
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
|
||||
}
|
||||
|
||||
fallbackToolModel, err = fallbackToolModel.WithTools(tools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ModelForLLM{
|
||||
Model: toolModel,
|
||||
MInfo: m.MInfo,
|
||||
FallbackModel: fallbackToolModel,
|
||||
FallbackInfo: m.FallbackInfo,
|
||||
UseFallback: m.UseFallback,
|
||||
modelEnableCallback: m.modelEnableCallback,
|
||||
fallbackEnableCallback: m.fallbackEnableCallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsCallbacksEnabled reports that callback handling is enabled for this
// wrapper. It always returns true because Generate and Stream invoke the
// eino callback hooks (OnStart/OnEnd/OnError) manually whenever the
// underlying model does not emit them itself.
func (m *ModelForLLM) IsCallbacksEnabled() bool {
	return true
}
|
||||
|
||||
func (m *ModelForLLM) Info(ctx context.Context) *crossmodelmgr.Model {
|
||||
if m.UseFallback(ctx) {
|
||||
return m.FallbackInfo
|
||||
}
|
||||
|
||||
return m.MInfo
|
||||
}
|
||||
287
backend/domain/workflow/internal/nodes/llm/prompt.go
Normal file
287
backend/domain/workflow/internal/nodes/llm/prompt.go
Normal file
@@ -0,0 +1,287 @@
|
||||
/*
|
||||
* Copyright 2025 coze-dev Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/coze-dev/coze-studio/backend/api/model/crossdomain/modelmgr"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
||||
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
||||
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
||||
)
|
||||
|
||||
// prompts renders the system and user prompt templates of an LLM node into
// eino chat messages, taking the model's supported input modalities into
// account (see Format).
type prompts struct {
	sp  *promptTpl    // system prompt template; may be nil
	up  *promptTpl    // user prompt template; may be nil
	mwi ModelWithInfo // provides model metadata used for modality capability lookup
}
|
||||
|
||||
// promptTpl is a parsed prompt template bound to a single chat role.
type promptTpl struct {
	role          schema.RoleType // chat role the rendered message carries
	tpl           string          // raw template text, kept for plain-text rendering
	parts         []promptPart    // parsed template segments, in source order
	hasMultiModal bool            // true when any variable part has a file type
	reservedKeys  []string        // variable roots kept even when their source is invalid
}
|
||||
|
||||
// promptPart is one parsed segment of a prompt template: either literal
// text or a variable reference.
type promptPart struct {
	part     nodes.TemplatePart
	fileType *vo.FileSubType // non-nil when the variable's type info is a file type
}
|
||||
|
||||
func newPromptTpl(role schema.RoleType,
|
||||
tpl string,
|
||||
inputTypes map[string]*vo.TypeInfo,
|
||||
reservedKeys []string,
|
||||
) *promptTpl {
|
||||
if len(tpl) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := nodes.ParseTemplate(tpl)
|
||||
promptParts := make([]promptPart, 0, len(parts))
|
||||
hasMultiModal := false
|
||||
for _, part := range parts {
|
||||
if !part.IsVariable {
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
tInfo := part.TypeInfo(inputTypes)
|
||||
if tInfo == nil || tInfo.Type != vo.DataTypeFile {
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
promptParts = append(promptParts, promptPart{
|
||||
part: part,
|
||||
fileType: tInfo.FileType,
|
||||
})
|
||||
|
||||
hasMultiModal = true
|
||||
}
|
||||
|
||||
return &promptTpl{
|
||||
role: role,
|
||||
tpl: tpl,
|
||||
parts: promptParts,
|
||||
hasMultiModal: hasMultiModal,
|
||||
reservedKeys: reservedKeys,
|
||||
}
|
||||
}
|
||||
|
||||
// sourceKey is the ctxcache key format under which a node's resolved
// source info is stored; it is formatted with the node key (see Format).
const sourceKey = "sources_%s"
|
||||
|
||||
func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
|
||||
return &prompts{
|
||||
sp: sp,
|
||||
up: up,
|
||||
mwi: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
|
||||
sources map[string]*nodes.SourceInfo,
|
||||
supportedModals map[modelmgr.Modal]bool) (*schema.Message, error) {
|
||||
if !pl.hasMultiModal || len(supportedModals) == 0 {
|
||||
var opts []nodes.RenderOption
|
||||
if len(pl.reservedKeys) > 0 {
|
||||
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
|
||||
}
|
||||
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &schema.Message{
|
||||
Role: pl.role,
|
||||
Content: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
|
||||
m, err := sonic.Marshal(vs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, part := range pl.parts {
|
||||
if !part.part.IsVariable {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: part.part.Value,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipped, invalid := part.part.Skipped(sources)
|
||||
if invalid {
|
||||
var reserved bool
|
||||
for _, k := range pl.reservedKeys {
|
||||
if k == part.part.Root {
|
||||
reserved = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !reserved {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if skipped {
|
||||
continue
|
||||
}
|
||||
|
||||
r, err := part.part.Render(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if part.fileType == nil {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
switch *part.fileType {
|
||||
case vo.FileTypeImage, vo.FileTypeSVG:
|
||||
if _, ok := supportedModals[modelmgr.ModalImage]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &schema.ChatMessageImageURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
case vo.FileTypeAudio, vo.FileTypeVoice:
|
||||
if _, ok := supportedModals[modelmgr.ModalAudio]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeAudioURL,
|
||||
AudioURL: &schema.ChatMessageAudioURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
case vo.FileTypeVideo:
|
||||
if _, ok := supportedModals[modelmgr.ModalVideo]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeVideoURL,
|
||||
VideoURL: &schema.ChatMessageVideoURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
if _, ok := supportedModals[modelmgr.ModalFile]; !ok {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: r,
|
||||
})
|
||||
} else {
|
||||
multiParts = append(multiParts, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeFileURL,
|
||||
FileURL: &schema.ChatMessageFileURL{
|
||||
URL: r,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &schema.Message{
|
||||
Role: pl.role,
|
||||
MultiContent: multiParts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
|
||||
_ []*schema.Message, err error) {
|
||||
exeCtx := execute.GetExeCtx(ctx)
|
||||
var nodeKey vo.NodeKey
|
||||
if exeCtx != nil && exeCtx.NodeCtx != nil {
|
||||
nodeKey = exeCtx.NodeCtx.NodeKey
|
||||
}
|
||||
sk := fmt.Sprintf(sourceKey, nodeKey)
|
||||
|
||||
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
|
||||
}
|
||||
|
||||
supportedModal := map[modelmgr.Modal]bool{}
|
||||
mInfo := p.mwi.Info(ctx)
|
||||
if mInfo != nil {
|
||||
for i := range mInfo.Meta.Capability.InputModal {
|
||||
supportedModal[mInfo.Meta.Capability.InputModal[i]] = true
|
||||
}
|
||||
}
|
||||
|
||||
var systemMsg, userMsg *schema.Message
|
||||
if p.sp != nil {
|
||||
systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if p.up != nil {
|
||||
userMsg, err = p.up.render(ctx, vs, sources, supportedModal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if userMsg == nil {
|
||||
// give it a default empty message.
|
||||
// Some model may fail on empty message such as this one.
|
||||
userMsg = schema.UserMessage("")
|
||||
}
|
||||
|
||||
if systemMsg == nil {
|
||||
return []*schema.Message{userMsg}, nil
|
||||
}
|
||||
|
||||
return []*schema.Message{systemMsg, userMsg}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user