coze-studio/backend/domain/workflow/internal/nodes/llm/prompt.go

403 lines
9.9 KiB
Go

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package llm
import (
"context"
"fmt"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
// prompts bundles the optional system and user prompt templates of an LLM
// node together with the model they are formatted for.
type prompts struct {
	// sp is rendered into the system message and up into the user message.
	// Either may be nil, in which case Format skips it.
	sp, up *promptTpl
	// mwi is asked by Format which input modalities the model accepts.
	mwi ModelWithInfo
}
// promptsWithChatHistory decorates a prompts value so that its Format method
// can splice cached conversation history into the formatted messages.
type promptsWithChatHistory struct {
	// prompts produces the base system/user messages.
	prompts *prompts
	// cfg controls whether chat history is added; a nil cfg disables it.
	cfg *vo.ChatHistorySetting
}
// withReservedKeys returns a newPromptTpl option that stores keys as the
// template's reserved keys. The slice is stored as-is, not copied.
func withReservedKeys(keys []string) func(tpl *promptTpl) {
	apply := func(target *promptTpl) {
		target.reservedKeys = keys
	}
	return apply
}
// withAssociateUserInputFields returns a newPromptTpl option that records fs
// as the set of template variables associated with the user's input message.
// The map is stored as-is, not copied.
func withAssociateUserInputFields(fs map[string]struct{}) func(tpl *promptTpl) {
	apply := func(target *promptTpl) {
		target.associateUserInputFields = fs
	}
	return apply
}
// promptTpl is a parsed prompt template bound to a message role.
type promptTpl struct {
	// role is copied onto every message render produces.
	role schema.RoleType
	// tpl is the raw template text, used by the text-only rendering path.
	tpl string
	// parts is tpl split by nodes.ParseTemplate, used by the multi-part path.
	parts []promptPart
	// hasMultiModal is true when at least one variable in tpl has file type.
	hasMultiModal bool
	// reservedKeys are handed to nodes.WithReservedKey on the text path; on the
	// multi-part path a variable reported invalid is still rendered when its
	// root is one of these keys.
	reservedKeys []string
	// associateUserInputFields is the set of variable names that, in chat-flow
	// mode, are replaced by the parts of the user's message.
	associateUserInputFields map[string]struct{}
}
// promptPart is one segment of a parsed template plus the file sub-type of
// the variable it refers to, if any.
type promptPart struct {
	// part is the literal or variable segment produced by nodes.ParseTemplate.
	part nodes.TemplatePart
	// fileType is only populated for variables whose declared type is file;
	// it stays nil for every other segment.
	fileType *vo.FileSubType
}
// newPromptTpl parses tpl and builds a promptTpl for the given role.
//
// It returns nil when tpl is empty. Options are applied before parsing, so the
// parts and hasMultiModal fields are always derived from tpl and inputTypes.
// A variable segment whose type info in inputTypes is vo.DataTypeFile keeps
// its file sub-type and marks the template as multi-modal.
func newPromptTpl(role schema.RoleType,
	tpl string,
	inputTypes map[string]*vo.TypeInfo,
	opts ...func(*promptTpl),
) *promptTpl {
	if tpl == "" {
		return nil
	}

	result := &promptTpl{role: role, tpl: tpl}
	for i := range opts {
		opts[i](result)
	}

	segments := nodes.ParseTemplate(tpl)
	collected := make([]promptPart, 0, len(segments))
	sawFile := false
	for _, seg := range segments {
		entry := promptPart{part: seg}
		if seg.IsVariable {
			if info := seg.TypeInfo(inputTypes); info != nil && info.Type == vo.DataTypeFile {
				entry.fileType = info.FileType
				sawFile = true
			}
		}
		collected = append(collected, entry)
	}

	result.parts = collected
	result.hasMultiModal = sawFile
	return result
}
// sourceKey is the fmt pattern for the ctxcache key under which the resolved
// sources are looked up; Format fills %s with the current node key.
const sourceKey = "sources_%s"
// newPrompts assembles a prompts value from the system template sp, the user
// template up and the model the prompts will be formatted for.
func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
	p := new(prompts)
	p.sp, p.up = sp, up
	p.mwi = model
	return p
}
// newPromptsWithChatHistory wraps prompts with the chat-history setting cfg.
func newPromptsWithChatHistory(prompts *prompts, cfg *vo.ChatHistorySetting) *promptsWithChatHistory {
	wrapped := new(promptsWithChatHistory)
	wrapped.prompts = prompts
	wrapped.cfg = cfg
	return wrapped
}
// render turns the template into a single message carrying pl.role.
//
// Two output shapes are possible:
//
//   - Text-only: the whole template is rendered by nodes.Render into
//     Message.Content. This path is taken when the template has no file
//     variables or the model reports no supported input modalities. In
//     chat-flow mode it additionally requires that either no variable is
//     associated with the user input, or that a user message exists and its
//     MultiContent is nil.
//   - Multi-part: every template segment becomes one schema.ChatMessagePart in
//     Message.MultiContent. File variables become image/audio/video/file URL
//     parts (passed through transformMessagePart), and in chat-flow mode a
//     variable associated with the user input is replaced by the parts of the
//     user message.
//
// vs holds the variable values, sources describes which variables were
// resolved, and supportedModals lists the input modalities of the model.
// The first error from rendering or from marshalling vs is returned as-is.
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
	sources map[string]*schema2.SourceInfo,
	supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) {
	// GetExeCtx may return nil (the Format methods in this file guard for it),
	// so a missing execution context is treated as "not a chat flow" instead of
	// being dereferenced.
	exeCtx := execute.GetExeCtx(ctx)
	isChatFlow := exeCtx != nil && exeCtx.ExeCfg.WorkflowMode == workflow.WorkflowMode_ChatFlow

	textOnly := !pl.hasMultiModal || len(supportedModals) == 0
	if textOnly && isChatFlow && len(pl.associateUserInputFields) > 0 {
		// Variables tied to the user input only allow the text path when the
		// user message exists and carries no multi-part content.
		userMessage := exeCtx.ExeCfg.UserMessage
		textOnly = userMessage != nil && userMessage.MultiContent == nil
	}

	if textOnly {
		var opts []nodes.RenderOption
		if len(pl.reservedKeys) > 0 {
			opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
		}
		content, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
		if err != nil {
			return nil, err
		}
		return &schema.Message{
			Role:    pl.role,
			Content: content,
		}, nil
	}

	// Variable segments are rendered against the serialized form of vs.
	m, err := sonic.Marshal(vs)
	if err != nil {
		return nil, err
	}

	multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
	for _, part := range pl.parts {
		if !part.part.IsVariable {
			multiParts = append(multiParts, schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: part.part.Value,
			})
			continue
		}

		if isChatFlow {
			if _, associated := pl.associateUserInputFields[part.part.Value]; associated {
				// Splice the user's own message parts in place of the variable.
				// Without a user message, fall through to normal rendering.
				if userMessage := exeCtx.ExeCfg.UserMessage; userMessage != nil {
					for _, p := range userMessage.MultiContent {
						multiParts = append(multiParts, transformMessagePart(p, supportedModals))
					}
					continue
				}
			}
		}

		skipped, invalid := part.part.Skipped(sources)
		if invalid {
			// An invalid variable is dropped unless its root is a reserved key.
			reserved := false
			for _, k := range pl.reservedKeys {
				if k == part.part.Root {
					reserved = true
					break
				}
			}
			if !reserved {
				continue
			}
		}
		if skipped {
			continue
		}

		r, err := part.part.Render(m)
		if err != nil {
			return nil, err
		}

		if part.fileType == nil {
			multiParts = append(multiParts, schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: r,
			})
			continue
		}

		// File variables: the rendered value is used as the URL of a typed part.
		var originalPart schema.ChatMessagePart
		switch *part.fileType {
		case vo.FileTypeImage, vo.FileTypeSVG:
			originalPart = schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeImageURL,
				ImageURL: &schema.ChatMessageImageURL{
					URL: r,
				},
			}
		case vo.FileTypeAudio, vo.FileTypeVoice:
			originalPart = schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeAudioURL,
				AudioURL: &schema.ChatMessageAudioURL{
					URL: r,
				},
			}
		case vo.FileTypeVideo:
			originalPart = schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeVideoURL,
				VideoURL: &schema.ChatMessageVideoURL{
					URL: r,
				},
			}
		default:
			originalPart = schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeFileURL,
				FileURL: &schema.ChatMessageFileURL{
					URL: r,
				},
			}
		}
		multiParts = append(multiParts, transformMessagePart(originalPart, supportedModals))
	}

	return &schema.Message{
		Role:         pl.role,
		MultiContent: multiParts,
	}, nil
}
// transformMessagePart downgrades a media part to plain text when the model
// does not support its modality.
//
// Image, audio, video and file URL parts are returned unchanged when the
// matching modality is marked true in supportedModals. Otherwise a text part
// carrying the URL is returned instead. A modality that is absent from the map
// or explicitly mapped to false counts as unsupported. If the part's URL
// payload pointer is nil, the text of the downgraded part is empty rather than
// causing a nil dereference. Parts of any other type are returned unchanged.
func transformMessagePart(part schema.ChatMessagePart, supportedModals map[modelmgr.Modal]bool) schema.ChatMessagePart {
	var (
		modal modelmgr.Modal
		url   string
	)

	switch part.Type {
	case schema.ChatMessagePartTypeImageURL:
		modal = modelmgr.ModalImage
		if part.ImageURL != nil {
			url = part.ImageURL.URL
		}
	case schema.ChatMessagePartTypeAudioURL:
		modal = modelmgr.ModalAudio
		if part.AudioURL != nil {
			url = part.AudioURL.URL
		}
	case schema.ChatMessagePartTypeVideoURL:
		modal = modelmgr.ModalVideo
		if part.VideoURL != nil {
			url = part.VideoURL.URL
		}
	case schema.ChatMessagePartTypeFileURL:
		modal = modelmgr.ModalFile
		if part.FileURL != nil {
			url = part.FileURL.URL
		}
	default:
		return part
	}

	// Indexing a nil or incomplete map yields false, i.e. "unsupported".
	if supportedModals[modal] {
		return part
	}
	return schema.ChatMessagePart{
		Type: schema.ChatMessagePartTypeText,
		Text: url,
	}
}
// Format renders the system and user templates with the values in vs and
// returns them as chat messages, system message first when there is one.
//
// The resolved sources are read from ctxcache under a key derived from the
// current node key (left at its zero value when the execution context or its
// node context is missing); an error is returned if they are not cached. The
// model's input modalities are collected from p.mwi.Info and passed to the
// templates. Prompt options are ignored.
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
	_ []*schema.Message, err error,
) {
	var nodeKey vo.NodeKey
	if ec := execute.GetExeCtx(ctx); ec != nil && ec.NodeCtx != nil {
		nodeKey = ec.NodeCtx.NodeKey
	}

	cacheKey := fmt.Sprintf(sourceKey, nodeKey)
	sources, found := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, cacheKey)
	if !found {
		return nil, fmt.Errorf("resolved sources not found llm node, key: %s", cacheKey)
	}

	supportedModal := make(map[modelmgr.Modal]bool)
	if info := p.mwi.Info(ctx); info != nil {
		for _, modal := range info.Meta.Capability.InputModal {
			supportedModal[modal] = true
		}
	}

	var systemMsg *schema.Message
	if p.sp != nil {
		if systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal); err != nil {
			return nil, err
		}
	}

	var userMsg *schema.Message
	if p.up != nil {
		if userMsg, err = p.up.render(ctx, vs, sources, supportedModal); err != nil {
			return nil, err
		}
	}
	if userMsg == nil {
		// No user template: fall back to an empty user message so that a user
		// turn is always present. Some models may still reject an empty one.
		userMsg = schema.UserMessage("")
	}

	if systemMsg != nil {
		return []*schema.Message{systemMsg, userMsg}, nil
	}
	return []*schema.Message{userMsg}, nil
}
// Format renders the wrapped prompts and, when chat history applies, inserts
// the cached conversation history between the system message and the rest.
//
// The base messages are returned untouched when history is disabled (nil cfg
// or EnableChatHistory false), when there is no execution context, when the
// workflow is not running in chat-flow mode, or when no history is cached
// under chatHistoryKey. Otherwise the result is: the leading system message
// (if the first base message has the system role), then the history, then the
// remaining base messages. Prompt options are ignored.
func (p *promptsWithChatHistory) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
	[]*schema.Message, error) {
	baseMessages, err := p.prompts.Format(ctx, vs)
	if err != nil {
		return nil, err
	}
	if p.cfg == nil || !p.cfg.EnableChatHistory {
		return baseMessages, nil
	}

	exeCtx := execute.GetExeCtx(ctx)
	if exeCtx == nil {
		logs.CtxWarnf(ctx, "execute context is nil, skipping chat history")
		return baseMessages, nil
	}
	if exeCtx.ExeCfg.WorkflowMode != workflow.WorkflowMode_ChatFlow {
		return baseMessages, nil
	}

	// This single check covers both a missing cache entry and an empty history.
	historyMessages, ok := ctxcache.Get[[]*schema.Message](ctx, chatHistoryKey)
	if !ok || len(historyMessages) == 0 {
		logs.CtxWarnf(ctx, "conversation history is empty")
		return baseMessages, nil
	}

	finalMessages := make([]*schema.Message, 0, len(baseMessages)+len(historyMessages))
	// Keep the system message in front of the history.
	if len(baseMessages) > 0 && baseMessages[0].Role == schema.System {
		finalMessages = append(finalMessages, baseMessages[0])
		baseMessages = baseMessages[1:]
	}
	finalMessages = append(finalMessages, historyMessages...)
	finalMessages = append(finalMessages, baseMessages...)
	return finalMessages, nil
}