/* * Copyright 2025 coze-dev Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package llm import ( "context" "fmt" "github.com/cloudwego/eino/components/prompt" "github.com/cloudwego/eino/schema" "github.com/coze-dev/coze-studio/backend/api/model/workflow" "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" "github.com/coze-dev/coze-studio/backend/pkg/ctxcache" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/pkg/sonic" ) type prompts struct { sp *promptTpl up *promptTpl mwi ModelWithInfo } type promptsWithChatHistory struct { prompts *prompts cfg *vo.ChatHistorySetting } func withReservedKeys(keys []string) func(tpl *promptTpl) { return func(tpl *promptTpl) { tpl.reservedKeys = keys } } func withAssociateUserInputFields(fs map[string]struct{}) func(tpl *promptTpl) { return func(tpl *promptTpl) { tpl.associateUserInputFields = fs } } type promptTpl struct { role schema.RoleType tpl string parts []promptPart hasMultiModal bool reservedKeys []string associateUserInputFields map[string]struct{} } type promptPart struct { part nodes.TemplatePart fileType *vo.FileSubType } func newPromptTpl(role schema.RoleType, tpl string, inputTypes map[string]*vo.TypeInfo, opts ...func(*promptTpl), ) *promptTpl { if len(tpl) == 0 { return nil } pTpl := &promptTpl{ role: role, tpl: tpl, } for _, opt := range opts { opt(pTpl) } parts := nodes.ParseTemplate(tpl) promptParts := make([]promptPart, 0, len(parts)) hasMultiModal := false for _, part := range parts { if !part.IsVariable { promptParts = append(promptParts, promptPart{ part: part, }) continue } tInfo := part.TypeInfo(inputTypes) if tInfo == nil || tInfo.Type != vo.DataTypeFile { promptParts = append(promptParts, promptPart{ part: part, }) continue } promptParts = append(promptParts, promptPart{ part: part, fileType: tInfo.FileType, }) hasMultiModal = true } pTpl.parts = promptParts pTpl.hasMultiModal = hasMultiModal return pTpl } const sourceKey = "sources_%s" func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts { return &prompts{ sp: sp, up: up, mwi: model, } } func newPromptsWithChatHistory(prompts *prompts, cfg *vo.ChatHistorySetting) *promptsWithChatHistory { return &promptsWithChatHistory{ prompts: prompts, cfg: cfg, } } func (pl *promptTpl) render(ctx context.Context, vs map[string]any, sources map[string]*schema2.SourceInfo, supportedModals map[modelmgr.Modal]bool, ) (*schema.Message, error) { isChatFlow := execute.GetExeCtx(ctx).ExeCfg.WorkflowMode == workflow.WorkflowMode_ChatFlow userMessage := execute.GetExeCtx(ctx).ExeCfg.UserMessage if !isChatFlow { if !pl.hasMultiModal || len(supportedModals) == 0 { var opts []nodes.RenderOption if len(pl.reservedKeys) > 0 { opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...)) } r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...) if err != nil { return nil, err } return &schema.Message{ Role: pl.role, Content: r, }, nil } } else { if (!pl.hasMultiModal || len(supportedModals) == 0) && (len(pl.associateUserInputFields) == 0 || (len(pl.associateUserInputFields) > 0 && userMessage != nil && userMessage.MultiContent == nil)) { var opts []nodes.RenderOption if len(pl.reservedKeys) > 0 { opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...)) } r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...) if err != nil { return nil, err } return &schema.Message{ Role: pl.role, Content: r, }, nil } } multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts)) m, err := sonic.Marshal(vs) if err != nil { return nil, err } for _, part := range pl.parts { if !part.part.IsVariable { multiParts = append(multiParts, schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: part.part.Value, }) continue } if _, ok := pl.associateUserInputFields[part.part.Value]; ok && userMessage != nil && isChatFlow { for _, p := range userMessage.MultiContent { multiParts = append(multiParts, transformMessagePart(p, supportedModals)) } continue } skipped, invalid := part.part.Skipped(sources) if invalid { var reserved bool for _, k := range pl.reservedKeys { if k == part.part.Root { reserved = true break } } if !reserved { continue } } if skipped { continue } r, err := part.part.Render(m) if err != nil { return nil, err } if part.fileType == nil { multiParts = append(multiParts, schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: r, }) continue } var originalPart schema.ChatMessagePart switch *part.fileType { case vo.FileTypeImage, vo.FileTypeSVG: originalPart = schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{ URL: r, }, } case vo.FileTypeAudio, vo.FileTypeVoice: originalPart = schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeAudioURL, AudioURL: &schema.ChatMessageAudioURL{ URL: r, }, } case vo.FileTypeVideo: originalPart = schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeVideoURL, VideoURL: &schema.ChatMessageVideoURL{ URL: r, }, } default: originalPart = schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeFileURL, FileURL: &schema.ChatMessageFileURL{ URL: r, }, } } multiParts = append(multiParts, transformMessagePart(originalPart, supportedModals)) } return &schema.Message{ Role: pl.role, MultiContent: multiParts, }, nil } func transformMessagePart(part schema.ChatMessagePart, supportedModals map[modelmgr.Modal]bool) schema.ChatMessagePart { switch part.Type { case schema.ChatMessagePartTypeImageURL: if _, ok := supportedModals[modelmgr.ModalImage]; !ok { return schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: part.ImageURL.URL, } } case schema.ChatMessagePartTypeAudioURL: if _, ok := supportedModals[modelmgr.ModalAudio]; !ok { return schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: part.AudioURL.URL, } } case schema.ChatMessagePartTypeVideoURL: if _, ok := supportedModals[modelmgr.ModalVideo]; !ok { return schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: part.VideoURL.URL, } } case schema.ChatMessagePartTypeFileURL: if _, ok := supportedModals[modelmgr.ModalFile]; !ok { return schema.ChatMessagePart{ Type: schema.ChatMessagePartTypeText, Text: part.FileURL.URL, } } } return part } func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) ( _ []*schema.Message, err error, ) { exeCtx := execute.GetExeCtx(ctx) var nodeKey vo.NodeKey if exeCtx != nil && exeCtx.NodeCtx != nil { nodeKey = exeCtx.NodeCtx.NodeKey } sk := fmt.Sprintf(sourceKey, nodeKey) sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk) if !ok { return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk) } supportedModal := map[modelmgr.Modal]bool{} mInfo := p.mwi.Info(ctx) if mInfo != nil { for i := range mInfo.Meta.Capability.InputModal { supportedModal[mInfo.Meta.Capability.InputModal[i]] = true } } var systemMsg, userMsg *schema.Message if p.sp != nil { systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal) if err != nil { return nil, err } } if p.up != nil { userMsg, err = p.up.render(ctx, vs, sources, supportedModal) if err != nil { return nil, err } } if userMsg == nil { // give it a default empty message. // Some model may fail on empty message such as this one. userMsg = schema.UserMessage("") } if systemMsg == nil { return []*schema.Message{userMsg}, nil } return []*schema.Message{systemMsg, userMsg}, nil } func (p *promptsWithChatHistory) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) ( []*schema.Message, error) { baseMessages, err := p.prompts.Format(ctx, vs) if err != nil { return nil, err } if p.cfg == nil || !p.cfg.EnableChatHistory { return baseMessages, nil } exeCtx := execute.GetExeCtx(ctx) if exeCtx == nil { logs.CtxWarnf(ctx, "execute context is nil, skipping chat history") return baseMessages, nil } if exeCtx.ExeCfg.WorkflowMode != workflow.WorkflowMode_ChatFlow { return baseMessages, nil } historyMessages, ok := ctxcache.Get[[]*schema.Message](ctx, chatHistoryKey) if !ok || len(historyMessages) == 0 { logs.CtxWarnf(ctx, "conversation history is empty") return baseMessages, nil } if len(historyMessages) == 0 { return baseMessages, nil } finalMessages := make([]*schema.Message, 0, len(baseMessages)+len(historyMessages)) if len(baseMessages) > 0 && baseMessages[0].Role == schema.System { finalMessages = append(finalMessages, baseMessages[0]) baseMessages = baseMessages[1:] } finalMessages = append(finalMessages, historyMessages...) finalMessages = append(finalMessages, baseMessages...) return finalMessages, nil }