291 lines
6.8 KiB
Go
291 lines
6.8 KiB
Go
/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
package llm
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/cloudwego/eino/components/prompt"
|
|
"github.com/cloudwego/eino/schema"
|
|
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
|
|
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
|
|
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
|
|
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
|
|
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
|
|
)
|
|
|
|
type prompts struct {
|
|
sp *promptTpl
|
|
up *promptTpl
|
|
mwi ModelWithInfo
|
|
}
|
|
|
|
type promptTpl struct {
|
|
role schema.RoleType
|
|
tpl string
|
|
parts []promptPart
|
|
hasMultiModal bool
|
|
reservedKeys []string
|
|
}
|
|
|
|
type promptPart struct {
|
|
part nodes.TemplatePart
|
|
fileType *vo.FileSubType
|
|
}
|
|
|
|
func newPromptTpl(role schema.RoleType,
|
|
tpl string,
|
|
inputTypes map[string]*vo.TypeInfo,
|
|
reservedKeys []string,
|
|
) *promptTpl {
|
|
if len(tpl) == 0 {
|
|
return nil
|
|
}
|
|
|
|
parts := nodes.ParseTemplate(tpl)
|
|
promptParts := make([]promptPart, 0, len(parts))
|
|
hasMultiModal := false
|
|
for _, part := range parts {
|
|
if !part.IsVariable {
|
|
promptParts = append(promptParts, promptPart{
|
|
part: part,
|
|
})
|
|
|
|
continue
|
|
}
|
|
|
|
tInfo := part.TypeInfo(inputTypes)
|
|
if tInfo == nil || tInfo.Type != vo.DataTypeFile {
|
|
promptParts = append(promptParts, promptPart{
|
|
part: part,
|
|
})
|
|
continue
|
|
}
|
|
|
|
promptParts = append(promptParts, promptPart{
|
|
part: part,
|
|
fileType: tInfo.FileType,
|
|
})
|
|
|
|
hasMultiModal = true
|
|
}
|
|
|
|
return &promptTpl{
|
|
role: role,
|
|
tpl: tpl,
|
|
parts: promptParts,
|
|
hasMultiModal: hasMultiModal,
|
|
reservedKeys: reservedKeys,
|
|
}
|
|
}
|
|
|
|
// sourceKey is the format string for the ctxcache key under which the
// resolved sources of an LLM node are looked up; %s is the node key.
const sourceKey = "sources_%s"
|
|
|
|
func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
|
|
return &prompts{
|
|
sp: sp,
|
|
up: up,
|
|
mwi: model,
|
|
}
|
|
}
|
|
|
|
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
|
|
sources map[string]*schema2.SourceInfo,
|
|
supportedModals map[modelmgr.Modal]bool,
|
|
) (*schema.Message, error) {
|
|
if !pl.hasMultiModal || len(supportedModals) == 0 {
|
|
var opts []nodes.RenderOption
|
|
if len(pl.reservedKeys) > 0 {
|
|
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
|
|
}
|
|
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &schema.Message{
|
|
Role: pl.role,
|
|
Content: r,
|
|
}, nil
|
|
}
|
|
|
|
multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
|
|
m, err := sonic.Marshal(vs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, part := range pl.parts {
|
|
if !part.part.IsVariable {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: part.part.Value,
|
|
})
|
|
continue
|
|
}
|
|
|
|
skipped, invalid := part.part.Skipped(sources)
|
|
if invalid {
|
|
var reserved bool
|
|
for _, k := range pl.reservedKeys {
|
|
if k == part.part.Root {
|
|
reserved = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !reserved {
|
|
continue
|
|
}
|
|
}
|
|
|
|
if skipped {
|
|
continue
|
|
}
|
|
|
|
r, err := part.part.Render(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if part.fileType == nil {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: r,
|
|
})
|
|
continue
|
|
}
|
|
|
|
switch *part.fileType {
|
|
case vo.FileTypeImage, vo.FileTypeSVG:
|
|
if _, ok := supportedModals[modelmgr.ModalImage]; !ok {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: r,
|
|
})
|
|
} else {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeImageURL,
|
|
ImageURL: &schema.ChatMessageImageURL{
|
|
URL: r,
|
|
},
|
|
})
|
|
}
|
|
case vo.FileTypeAudio, vo.FileTypeVoice:
|
|
if _, ok := supportedModals[modelmgr.ModalAudio]; !ok {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: r,
|
|
})
|
|
} else {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeAudioURL,
|
|
AudioURL: &schema.ChatMessageAudioURL{
|
|
URL: r,
|
|
},
|
|
})
|
|
}
|
|
case vo.FileTypeVideo:
|
|
if _, ok := supportedModals[modelmgr.ModalVideo]; !ok {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: r,
|
|
})
|
|
} else {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeVideoURL,
|
|
VideoURL: &schema.ChatMessageVideoURL{
|
|
URL: r,
|
|
},
|
|
})
|
|
}
|
|
default:
|
|
if _, ok := supportedModals[modelmgr.ModalFile]; !ok {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeText,
|
|
Text: r,
|
|
})
|
|
} else {
|
|
multiParts = append(multiParts, schema.ChatMessagePart{
|
|
Type: schema.ChatMessagePartTypeFileURL,
|
|
FileURL: &schema.ChatMessageFileURL{
|
|
URL: r,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return &schema.Message{
|
|
Role: pl.role,
|
|
MultiContent: multiParts,
|
|
}, nil
|
|
}
|
|
|
|
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
|
|
_ []*schema.Message, err error,
|
|
) {
|
|
exeCtx := execute.GetExeCtx(ctx)
|
|
var nodeKey vo.NodeKey
|
|
if exeCtx != nil && exeCtx.NodeCtx != nil {
|
|
nodeKey = exeCtx.NodeCtx.NodeKey
|
|
}
|
|
sk := fmt.Sprintf(sourceKey, nodeKey)
|
|
|
|
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
|
|
if !ok {
|
|
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
|
|
}
|
|
|
|
supportedModal := map[modelmgr.Modal]bool{}
|
|
mInfo := p.mwi.Info(ctx)
|
|
if mInfo != nil {
|
|
for i := range mInfo.Meta.Capability.InputModal {
|
|
supportedModal[mInfo.Meta.Capability.InputModal[i]] = true
|
|
}
|
|
}
|
|
|
|
var systemMsg, userMsg *schema.Message
|
|
if p.sp != nil {
|
|
systemMsg, err = p.sp.render(ctx, vs, sources, supportedModal)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if p.up != nil {
|
|
userMsg, err = p.up.render(ctx, vs, sources, supportedModal)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if userMsg == nil {
|
|
// give it a default empty message.
|
|
// Some model may fail on empty message such as this one.
|
|
userMsg = schema.UserMessage("")
|
|
}
|
|
|
|
if systemMsg == nil {
|
|
return []*schema.Message{userMsg}, nil
|
|
}
|
|
|
|
return []*schema.Message{systemMsg, userMsg}, nil
|
|
}
|