feat: support chat flow, and support binding a single chat flow to an agent (#765)

Co-authored-by: Yu Yang <72337138+tomasyu985@users.noreply.github.com>
Co-authored-by: zengxiaohui <csu.zengxiaohui@gmail.com>
Co-authored-by: lijunwen.gigoo <lijunwen.gigoo@bytedance.com>
Co-authored-by: lvxinyu.1117 <lvxinyu.1117@bytedance.com>
Co-authored-by: liuyunchao.0510 <liuyunchao.0510@bytedance.com>
Co-authored-by: haozhenfei <37089575+haozhenfei@users.noreply.github.com>
Co-authored-by: July <jiangxujin@bytedance.com>
Co-authored-by: tecvan-fe <fanwenjie.fe@bytedance.com>
This commit is contained in:
Zhj
2025-08-28 21:53:32 +08:00
committed by GitHub
parent bbc615a18e
commit d70101c979
503 changed files with 48036 additions and 3427 deletions

View File

@@ -40,6 +40,7 @@ import (
workflowModel "github.com/coze-dev/coze-studio/backend/api/model/crossdomain/workflow"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/crossdomain/contract/knowledge"
crossmessage "github.com/coze-dev/coze-studio/backend/crossdomain/contract/message"
crossmodelmgr "github.com/coze-dev/coze-studio/backend/crossdomain/contract/modelmgr"
crossplugin "github.com/coze-dev/coze-studio/backend/crossdomain/contract/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
@@ -59,6 +60,10 @@ import (
"github.com/coze-dev/coze-studio/backend/types/errno"
)
// contextKey is a dedicated string type for context-cache keys, so this
// package's keys cannot collide with plain-string keys from other packages.
type contextKey string

// chatHistoryKey is the ctxcache key under which the chat-history schema
// messages are stored for later consumption by the prompt formatter.
const chatHistoryKey contextKey = "chatHistory"
type Format int
const (
@@ -167,12 +172,14 @@ type KnowledgeRecallConfig struct {
}
type Config struct {
SystemPrompt string
UserPrompt string
OutputFormat Format
LLMParams *crossmodel.LLMParams
FCParam *vo.FCParam
BackupLLMParams *crossmodel.LLMParams
SystemPrompt string
UserPrompt string
OutputFormat Format
LLMParams *crossmodel.LLMParams
FCParam *vo.FCParam
BackupLLMParams *crossmodel.LLMParams
ChatHistorySetting *vo.ChatHistorySetting
AssociateStartNodeUserInputFields map[string]struct{}
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
@@ -202,6 +209,13 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
c.SystemPrompt = convertedLLMParam.SystemPrompt
c.UserPrompt = convertedLLMParam.Prompt
if convertedLLMParam.EnableChatHistory {
c.ChatHistorySetting = &vo.ChatHistorySetting{
EnableChatHistory: true,
ChatHistoryRound: convertedLLMParam.ChatHistoryRound,
}
}
var resFormat Format
switch convertedLLMParam.ResponseFormat {
case crossmodel.ResponseFormatText:
@@ -273,6 +287,15 @@ func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*
}
}
c.AssociateStartNodeUserInputFields = make(map[string]struct{})
for _, info := range ns.InputSources {
if len(info.Path) == 1 && info.Source.Ref != nil && info.Source.Ref.FromNodeKey == entity.EntryNodeKey {
if compose.FromFieldPath(info.Source.Ref.FromPath).Equals(compose.FromField("USER_INPUT")) {
c.AssociateStartNodeUserInputFields[info.Path[0]] = struct{}{}
}
}
}
return ns, nil
}
@@ -320,7 +343,14 @@ func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
case "systemPrompt":
strVal := param.Input.Value.Content.(string)
p.SystemPrompt = strVal
case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
case "chatHistoryRound":
strVal := param.Input.Value.Content.(string)
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return nil, err
}
p.ChatHistoryRound = int64Val
case "generationDiversity", "frequencyPenalty", "presencePenalty":
// do nothing
case "topP":
strVal := param.Input.Value.Content.(string)
@@ -590,11 +620,12 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString,
}
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs)
up := newPromptTpl(schema.User, userPrompt, inputs, withReservedKeys([]string{knowledgeUserPromptTemplateKey}), withAssociateUserInputFields(c.AssociateStartNodeUserInputFields))
template := newPrompts(sp, up, modelWithInfo)
templateWithChatHistory := newPromptsWithChatHistory(template, c.ChatHistorySetting)
_ = g.AddChatTemplateNode(templateNodeKey, template,
_ = g.AddChatTemplateNode(templateNodeKey, templateWithChatHistory,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
for k, v := range state {
in[k] = v
@@ -604,10 +635,12 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else {
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes)
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, withAssociateUserInputFields(c.AssociateStartNodeUserInputFields))
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template)
templateWithChatHistory := newPromptsWithChatHistory(template, c.ChatHistorySetting)
_ = g.AddChatTemplateNode(templateNodeKey, templateWithChatHistory)
_ = g.AddEdge(compose.START, templateNodeKey)
}
@@ -747,10 +780,11 @@ func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2
}
llm := &LLM{
r: r,
outputFormat: format,
requireCheckpoint: requireCheckpoint,
fullSources: ns.FullSources,
r: r,
outputFormat: format,
requireCheckpoint: requireCheckpoint,
fullSources: ns.FullSources,
chatHistorySetting: c.ChatHistorySetting,
}
return llm, nil
@@ -825,10 +859,11 @@ func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
requireCheckpoint bool
fullSources map[string]*schema2.SourceInfo
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
requireCheckpoint bool
fullSources map[string]*schema2.SourceInfo
chatHistorySetting *vo.ChatHistorySetting
}
const (
@@ -1193,6 +1228,68 @@ type ToolInterruptEventStore interface {
ResumeToolInterruptEvent(llmNodeKey vo.NodeKey, toolCallID string) (string, error)
}
// ToCallbackInput decorates the node input reported to callbacks with the
// conversation history under the "chatHistory" key when chat history is
// enabled for this LLM node. It also caches the corresponding schema
// messages in ctxcache (chatHistoryKey) so the prompt formatter can splice
// them into the final message list.
//
// Returns the input unchanged when chat history is disabled; otherwise a
// copy of the input with "chatHistory" added ([]any of role/content maps).
func (l *LLM) ToCallbackInput(ctx context.Context, input map[string]any) (map[string]any, error) {
	if l.chatHistorySetting == nil || !l.chatHistorySetting.EnableChatHistory {
		return input, nil
	}

	var (
		messages   []*crossmessage.WfMessage
		scMessages []*schema.Message
		sectionID  *int64
	)
	execCtx := execute.GetExeCtx(ctx)
	if execCtx != nil {
		messages = execCtx.ExeCfg.ConversationHistory
		scMessages = execCtx.ExeCfg.ConversationHistorySchemaMessages
		sectionID = execCtx.ExeCfg.SectionID
	}

	ret := map[string]any{
		"chatHistory": []any{},
	}
	maps.Copy(ret, input)

	if len(messages) == 0 {
		return ret, nil
	}
	// History belonging to a different conversation section is not applicable.
	if sectionID != nil && messages[0].SectionID != *sectionID {
		return ret, nil
	}

	maxRounds := int(l.chatHistorySetting.ChatHistoryRound)
	if execCtx != nil && execCtx.ExeCfg.MaxHistoryRounds != nil {
		maxRounds = min(int(*execCtx.ExeCfg.MaxHistoryRounds), maxRounds)
	}

	// Walk backwards counting user turns; startIdx marks the oldest message
	// kept so that at most maxRounds user rounds are included.
	// NOTE(review): when maxRounds <= 0 this still keeps the final message
	// (count >= maxRounds is true immediately) — confirm that is intended.
	count := 0
	startIdx := 0
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role == schema.User {
			count++
		}
		if count >= maxRounds {
			startIdx = i
			break
		}
	}

	var historyMessages []any
	for _, msg := range messages[startIdx:] {
		content, err := nodes.ConvertMessageToString(ctx, msg)
		if err != nil {
			logs.CtxWarnf(ctx, "failed to convert message to string: %v", err)
			continue
		}
		historyMessages = append(historyMessages, map[string]any{
			"role":    string(msg.Role),
			"content": content,
		})
	}

	// fix: messages and scMessages are assumed to be parallel slices, but
	// slicing scMessages with an index derived from messages panics if
	// scMessages is shorter — clamp the index before slicing.
	scStart := min(startIdx, len(scMessages))
	ctxcache.Store(ctx, chatHistoryKey, scMessages[scStart:])
	ret["chatHistory"] = historyMessages
	return ret, nil
}
func (l *LLM) ToCallbackOutput(ctx context.Context, output map[string]any) (*nodes.StructuredCallbackOutput, error) {
c := execute.GetExeCtx(ctx)
if c == nil {

View File

@@ -23,12 +23,14 @@ import (
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/api/model/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -38,12 +40,30 @@ type prompts struct {
mwi ModelWithInfo
}
// promptsWithChatHistory decorates prompts: when chat history is enabled in
// cfg, its Format splices the cached conversation history into the rendered
// message list; otherwise it behaves exactly like the wrapped prompts.
type promptsWithChatHistory struct {
	prompts *prompts
	cfg     *vo.ChatHistorySetting
}
// withReservedKeys returns a promptTpl option that records the template
// keys which must be preserved verbatim during rendering.
func withReservedKeys(keys []string) func(tpl *promptTpl) {
	return func(t *promptTpl) {
		t.reservedKeys = keys
	}
}
// withAssociateUserInputFields returns a promptTpl option that records which
// input fields are associated with the start node's user input.
func withAssociateUserInputFields(fs map[string]struct{}) func(tpl *promptTpl) {
	return func(t *promptTpl) {
		t.associateUserInputFields = fs
	}
}
type promptTpl struct {
role schema.RoleType
tpl string
parts []promptPart
hasMultiModal bool
reservedKeys []string
role schema.RoleType
tpl string
parts []promptPart
hasMultiModal bool
reservedKeys []string
associateUserInputFields map[string]struct{}
}
type promptPart struct {
@@ -54,12 +74,20 @@ type promptPart struct {
func newPromptTpl(role schema.RoleType,
tpl string,
inputTypes map[string]*vo.TypeInfo,
reservedKeys []string,
opts ...func(*promptTpl),
) *promptTpl {
if len(tpl) == 0 {
return nil
}
pTpl := &promptTpl{
role: role,
tpl: tpl,
}
for _, opt := range opts {
opt(pTpl)
}
parts := nodes.ParseTemplate(tpl)
promptParts := make([]promptPart, 0, len(parts))
hasMultiModal := false
@@ -87,14 +115,10 @@ func newPromptTpl(role schema.RoleType,
hasMultiModal = true
}
pTpl.parts = promptParts
pTpl.hasMultiModal = hasMultiModal
return &promptTpl{
role: role,
tpl: tpl,
parts: promptParts,
hasMultiModal: hasMultiModal,
reservedKeys: reservedKeys,
}
return pTpl
}
const sourceKey = "sources_%s"
@@ -107,23 +131,53 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
}
}
// newPromptsWithChatHistory wraps the given prompts with chat-history
// splicing behavior controlled by cfg.
func newPromptsWithChatHistory(base *prompts, setting *vo.ChatHistorySetting) *promptsWithChatHistory {
	p := &promptsWithChatHistory{}
	p.prompts = base
	p.cfg = setting
	return p
}
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*schema2.SourceInfo,
supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 {
var opts []nodes.RenderOption
if len(pl.reservedKeys) > 0 {
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
isChatFlow := execute.GetExeCtx(ctx).ExeCfg.WorkflowMode == workflow.WorkflowMode_ChatFlow
userMessage := execute.GetExeCtx(ctx).ExeCfg.UserMessage
if !isChatFlow {
if !pl.hasMultiModal || len(supportedModals) == 0 {
var opts []nodes.RenderOption
if len(pl.reservedKeys) > 0 {
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
}
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
if err != nil {
return nil, err
}
return &schema.Message{
Role: pl.role,
Content: r,
}, nil
}
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
if err != nil {
return nil, err
} else {
if (!pl.hasMultiModal || len(supportedModals) == 0) &&
(len(pl.associateUserInputFields) == 0 ||
(len(pl.associateUserInputFields) > 0 && userMessage != nil && userMessage.MultiContent == nil)) {
var opts []nodes.RenderOption
if len(pl.reservedKeys) > 0 {
opts = append(opts, nodes.WithReservedKey(pl.reservedKeys...))
}
r, err := nodes.Render(ctx, pl.tpl, vs, sources, opts...)
if err != nil {
return nil, err
}
return &schema.Message{
Role: pl.role,
Content: r,
}, nil
}
return &schema.Message{
Role: pl.role,
Content: r,
}, nil
}
multiParts := make([]schema.ChatMessagePart, 0, len(pl.parts))
@@ -141,6 +195,13 @@ func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
continue
}
if _, ok := pl.associateUserInputFields[part.part.Value]; ok && userMessage != nil && isChatFlow {
for _, p := range userMessage.MultiContent {
multiParts = append(multiParts, transformMessagePart(p, supportedModals))
}
continue
}
skipped, invalid := part.part.Skipped(sources)
if invalid {
var reserved bool
@@ -164,6 +225,7 @@ func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
if err != nil {
return nil, err
}
if part.fileType == nil {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
@@ -172,64 +234,38 @@ func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
continue
}
var originalPart schema.ChatMessagePart
switch *part.fileType {
case vo.FileTypeImage, vo.FileTypeSVG:
if _, ok := supportedModals[modelmgr.ModalImage]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: r,
},
})
originalPart = schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeImageURL,
ImageURL: &schema.ChatMessageImageURL{
URL: r,
},
}
case vo.FileTypeAudio, vo.FileTypeVoice:
if _, ok := supportedModals[modelmgr.ModalAudio]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeAudioURL,
AudioURL: &schema.ChatMessageAudioURL{
URL: r,
},
})
originalPart = schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeAudioURL,
AudioURL: &schema.ChatMessageAudioURL{
URL: r,
},
}
case vo.FileTypeVideo:
if _, ok := supportedModals[modelmgr.ModalVideo]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeVideoURL,
VideoURL: &schema.ChatMessageVideoURL{
URL: r,
},
})
originalPart = schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeVideoURL,
VideoURL: &schema.ChatMessageVideoURL{
URL: r,
},
}
default:
if _, ok := supportedModals[modelmgr.ModalFile]; !ok {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: r,
})
} else {
multiParts = append(multiParts, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeFileURL,
FileURL: &schema.ChatMessageFileURL{
URL: r,
},
})
originalPart = schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeFileURL,
FileURL: &schema.ChatMessageFileURL{
URL: r,
},
}
}
multiParts = append(multiParts, transformMessagePart(originalPart, supportedModals))
}
return &schema.Message{
@@ -238,6 +274,40 @@ func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
}, nil
}
// transformMessagePart downgrades a multi-modal message part to a plain-text
// part carrying the raw URL when the model does not support the part's
// modality (membership in supportedModals is the support signal). Parts of
// unrecognized type, and parts whose modality is supported, pass through
// unchanged.
func transformMessagePart(part schema.ChatMessagePart, supportedModals map[modelmgr.Modal]bool) schema.ChatMessagePart {
	switch part.Type {
	case schema.ChatMessagePartTypeImageURL:
		// fix: guard against a typed part with a nil URL struct, which would
		// otherwise panic on the dereference below.
		if _, ok := supportedModals[modelmgr.ModalImage]; !ok && part.ImageURL != nil {
			return schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: part.ImageURL.URL,
			}
		}
	case schema.ChatMessagePartTypeAudioURL:
		if _, ok := supportedModals[modelmgr.ModalAudio]; !ok && part.AudioURL != nil {
			return schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: part.AudioURL.URL,
			}
		}
	case schema.ChatMessagePartTypeVideoURL:
		if _, ok := supportedModals[modelmgr.ModalVideo]; !ok && part.VideoURL != nil {
			return schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: part.VideoURL.URL,
			}
		}
	case schema.ChatMessagePartTypeFileURL:
		if _, ok := supportedModals[modelmgr.ModalFile]; !ok && part.FileURL != nil {
			return schema.ChatMessagePart{
				Type: schema.ChatMessagePartTypeText,
				Text: part.FileURL.URL,
			}
		}
	}
	return part
}
func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
_ []*schema.Message, err error,
) {
@@ -288,3 +358,45 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
return []*schema.Message{systemMsg, userMsg}, nil
}
// Format renders the wrapped prompts and, for chat-flow executions with chat
// history enabled, splices the cached history messages between the system
// message (if any) and the remaining rendered messages. In every other case
// the base messages are returned untouched.
func (p *promptsWithChatHistory) Format(ctx context.Context, vs map[string]any, _ ...prompt.Option) (
	[]*schema.Message, error) {
	baseMessages, err := p.prompts.Format(ctx, vs)
	if err != nil {
		return nil, err
	}

	if p.cfg == nil || !p.cfg.EnableChatHistory {
		return baseMessages, nil
	}

	exeCtx := execute.GetExeCtx(ctx)
	if exeCtx == nil {
		logs.CtxWarnf(ctx, "execute context is nil, skipping chat history")
		return baseMessages, nil
	}
	if exeCtx.ExeCfg.WorkflowMode != workflow.WorkflowMode_ChatFlow {
		return baseMessages, nil
	}

	// The history is cached under chatHistoryKey earlier in the node's run.
	// fix: the original repeated the len == 0 check a second time right
	// after this guard; that branch was unreachable and has been removed.
	historyMessages, ok := ctxcache.Get[[]*schema.Message](ctx, chatHistoryKey)
	if !ok || len(historyMessages) == 0 {
		logs.CtxWarnf(ctx, "conversation history is empty")
		return baseMessages, nil
	}

	finalMessages := make([]*schema.Message, 0, len(baseMessages)+len(historyMessages))
	// Keep the system message (when present) ahead of the history.
	if len(baseMessages) > 0 && baseMessages[0].Role == schema.System {
		finalMessages = append(finalMessages, baseMessages[0])
		baseMessages = baseMessages[1:]
	}
	finalMessages = append(finalMessages, historyMessages...)
	finalMessages = append(finalMessages, baseMessages...)
	return finalMessages, nil
}