refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -34,13 +34,20 @@ import (
callbacks2 "github.com/cloudwego/eino/utils/callbacks"
"golang.org/x/exp/maps"
workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
"github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -143,126 +150,408 @@ const (
)
type RetrievalStrategy struct {
RetrievalStrategy *crossknowledge.RetrievalStrategy
RetrievalStrategy *knowledge.RetrievalStrategy
NoReCallReplyMode NoReCallReplyMode
NoReCallReplyCustomizePrompt string
}
type KnowledgeRecallConfig struct {
ChatModel model.BaseChatModel
Retriever crossknowledge.KnowledgeOperator
Retriever knowledge.KnowledgeOperator
RetrievalStrategy *RetrievalStrategy
SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
}
type Config struct {
ChatModel ModelWithInfo
Tools []tool.BaseTool
SystemPrompt string
UserPrompt string
OutputFormat Format
InputFields map[string]*vo.TypeInfo
OutputFields map[string]*vo.TypeInfo
ToolsReturnDirectly map[string]bool
KnowledgeRecallConfig *KnowledgeRecallConfig
FullSources map[string]*nodes.SourceInfo
SystemPrompt string
UserPrompt string
OutputFormat Format
LLMParams *crossmodel.LLMParams
FCParam *vo.FCParam
BackupLLMParams *crossmodel.LLMParams
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
outputFields map[string]*vo.TypeInfo
canStream bool
requireCheckpoint bool
fullSources map[string]*nodes.SourceInfo
}
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeLLM,
Name: n.Data.Meta.Title,
Configs: c,
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
param := n.Data.Inputs.LLMParam
if param == nil {
return nil, fmt.Errorf("llm node's llmParam is nil")
}
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
bs, _ := sonic.Marshal(param)
llmParam := make(vo.LLMParam, 0)
if err := sonic.Unmarshal(bs, &llmParam); err != nil {
return nil, err
}
convertedLLMParam, err := llmParamsToLLMParam(llmParam)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
c.LLMParams = convertedLLMParam
c.SystemPrompt = convertedLLMParam.SystemPrompt
c.UserPrompt = convertedLLMParam.Prompt
var resFormat Format
switch convertedLLMParam.ResponseFormat {
case crossmodel.ResponseFormatText:
resFormat = FormatText
case crossmodel.ResponseFormatMarkdown:
resFormat = FormatMarkdown
case crossmodel.ResponseFormatJSON:
resFormat = FormatJSON
default:
return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
c.OutputFormat = resFormat
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return r, nil
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
if resFormat == FormatJSON {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
resFormat = FormatText
break
}
}
}
}
}
if resFormat == FormatJSON {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: false,
}
} else {
ns.StreamConfigs = &schema2.StreamConfig{
CanGeneratesStream: true,
}
}
if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
c.FCParam = n.Data.Inputs.FCParam
}
if se := n.Data.Inputs.SettingOnError; se != nil {
if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
var backupLLMParam vo.SimpleLLMParam
if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
return nil, err
}
backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
if err != nil {
return nil, err
}
c.BackupLLMParams = backupModel
}
}
return ns, nil
}
func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
for _, param := range params {
switch param.Name {
case "temperature":
strVal := param.Input.Value.Content.(string)
floatVal, err := strconv.ParseFloat(strVal, 64)
if err != nil {
return nil, err
}
p.Temperature = &floatVal
case "maxTokens":
strVal := param.Input.Value.Content.(string)
intVal, err := strconv.Atoi(strVal)
if err != nil {
return nil, err
}
p.MaxTokens = intVal
case "responseFormat":
strVal := param.Input.Value.Content.(string)
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return nil, err
}
p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
case "modleName":
strVal := param.Input.Value.Content.(string)
p.ModelName = strVal
case "modelType":
strVal := param.Input.Value.Content.(string)
int64Val, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return nil, err
}
p.ModelType = int64Val
case "prompt":
strVal := param.Input.Value.Content.(string)
p.Prompt = strVal
case "enableChatHistory":
boolVar := param.Input.Value.Content.(bool)
p.EnableChatHistory = boolVar
case "systemPrompt":
strVal := param.Input.Value.Content.(string)
p.SystemPrompt = strVal
case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
// do nothing
case "topP":
strVal := param.Input.Value.Content.(string)
floatVar, err := strconv.ParseFloat(strVal, 64)
if err != nil {
return nil, err
}
p.TopP = &floatVar
default:
return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
}
}
return p, nil
}
func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
p.ModelName = params.ModelName
p.ModelType = params.ModelType
p.Temperature = &params.Temperature
p.MaxTokens = params.MaxTokens
p.TopP = &params.TopP
p.ResponseFormat = params.ResponseFormat
p.SystemPrompt = params.SystemPrompt
return p, nil
}
func getReasoningContent(message *schema.Message) string {
return message.ReasoningContent
}
type Options struct {
nested []nodes.NestedWorkflowOption
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
var (
err error
chatModel, fallbackM model.BaseChatModel
info, fallbackI *modelmgr.Model
modelWithInfo ModelWithInfo
tools []tool.BaseTool
toolsReturnDirectly map[string]bool
knowledgeRecallConfig *KnowledgeRecallConfig
)
type Option func(o *Options)
func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
return func(o *Options) {
o.nested = append(o.nested, nested...)
chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
return func(o *Options) {
o.toolWorkflowSW = sw
exceptionConf := ns.ExceptionConfigs
if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
backupModelParams := c.BackupLLMParams
if backupModelParams != nil {
fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
if err != nil {
return nil, err
}
}
}
}
type llmState = map[string]any
if fallbackM == nil {
modelWithInfo = NewModel(chatModel, info)
} else {
modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
}
const agentModelName = "agent_model"
fcParams := c.FCParam
if fcParams != nil {
if fcParams.WorkflowFCParam != nil {
for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
wfIDStr := wf.WorkflowID
wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
}
workflowToolConfig := vo.WorkflowToolConfig{}
if wf.FCSetting != nil {
workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
}
locator := vo.FromDraft
if wf.WorkflowVersion != "" {
locator = vo.FromSpecificVersion
}
wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
ID: wfID,
QType: locator,
Version: wf.WorkflowVersion,
}, workflowToolConfig)
if err != nil {
return nil, err
}
tools = append(tools, wfTool)
if wfTool.TerminatePlan() == vo.UseAnswerContent {
if toolsReturnDirectly == nil {
toolsReturnDirectly = make(map[string]bool)
}
toolInfo, err := wfTool.Info(ctx)
if err != nil {
return nil, err
}
toolsReturnDirectly[toolInfo.Name] = true
}
}
}
if fcParams.PluginFCParam != nil {
pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
for _, p := range fcParams.PluginFCParam.PluginList {
pid, err := strconv.ParseInt(p.PluginID, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
}
var (
requestParameters []*workflow3.APIParameter
responseParameters []*workflow3.APIParameter
)
if p.FCSetting != nil {
requestParameters = p.FCSetting.RequestParameters
responseParameters = p.FCSetting.ResponseParameters
}
if req, ok := pluginToolsInvokableReq[pid]; ok {
req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
}
} else {
pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
PluginEntity: plugin.Entity{
PluginID: pid,
PluginVersion: ptr.Of(p.PluginVersion),
},
ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
toolID: {
ToolID: toolID,
RequestAPIParametersConfig: requestParameters,
ResponseAPIParametersConfig: responseParameters,
},
},
IsDraft: p.IsDraft,
}
pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
}
}
inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
for _, req := range pluginToolsInvokableReq {
toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
if err != nil {
return nil, err
}
for _, t := range toolMap {
inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
}
}
if len(inInvokableTools) > 0 {
tools = append(tools, inInvokableTools...)
}
}
if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
if kwChatModel == nil {
return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
}
knowledgeOperator := knowledge.GetKnowledgeOperator()
setting := fcParams.KnowledgeFCParam.GlobalSetting
knowledgeRecallConfig = &KnowledgeRecallConfig{
ChatModel: kwChatModel,
Retriever: knowledgeOperator,
}
searchType, err := toRetrievalSearchType(setting.SearchMode)
if err != nil {
return nil, err
}
knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
RetrievalStrategy: &knowledge.RetrievalStrategy{
TopK: ptr.Of(setting.TopK),
MinScore: ptr.Of(setting.MinScore),
SearchType: searchType,
EnableNL2SQL: setting.UseNL2SQL,
EnableQueryRewrite: setting.UseRewrite,
EnableRerank: setting.UseRerank,
},
NoReCallReplyMode: NoReCallReplyMode(setting.NoRecallReplyMode),
NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
}
knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
kid, err := strconv.ParseInt(kw.ID, 10, 64)
if err != nil {
return nil, err
}
knowledgeIDs = append(knowledgeIDs, kid)
}
detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
&knowledge.ListKnowledgeDetailRequest{
KnowledgeIDs: knowledgeIDs,
})
if err != nil {
return nil, err
}
knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
}
}
func New(ctx context.Context, cfg *Config) (*LLM, error) {
g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
return llmState{}
}))
var (
hasReasoning bool
canStream = true
)
var hasReasoning bool
format := cfg.OutputFormat
format := c.OutputFormat
if format == FormatJSON {
if len(cfg.OutputFields) == 1 {
for _, v := range cfg.OutputFields {
if len(ns.OutputTypes) == 1 {
for _, v := range ns.OutputTypes {
if v.Type == vo.DataTypeString {
format = FormatText
break
}
}
} else if len(cfg.OutputFields) == 2 {
if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
for k, v := range cfg.OutputFields {
} else if len(ns.OutputTypes) == 2 {
if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
for k, v := range ns.OutputTypes {
if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
format = FormatText
break
@@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
}
}
userPrompt := cfg.UserPrompt
userPrompt := c.UserPrompt
switch format {
case FormatJSON:
jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
if err != nil {
return nil, err
}
@@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
case FormatText:
}
if cfg.KnowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
if knowledgeRecallConfig != nil {
err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
if err != nil {
return nil, err
}
userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)
inputs := maps.Clone(cfg.InputFields)
inputs := maps.Clone(ns.InputTypes)
inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
Type: vo.DataTypeString,
}
sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
template := newPrompts(sp, up, cfg.ChatModel)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template,
compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
@@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)
} else {
sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
template := newPrompts(sp, up, cfg.ChatModel)
sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
template := newPrompts(sp, up, modelWithInfo)
_ = g.AddChatTemplateNode(templateNodeKey, template)
_ = g.AddEdge(compose.START, templateNodeKey)
}
if len(cfg.Tools) > 0 {
m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
if len(tools) > 0 {
m, ok := modelWithInfo.(model.ToolCallingChatModel)
if !ok {
return nil, errors.New("requires a ToolCallingChatModel to use with tools")
}
reactConfig := react.AgentConfig{
ToolCallingModel: m,
ToolsConfig: compose.ToolsNodeConfig{Tools: cfg.Tools},
ToolsConfig: compose.ToolsNodeConfig{Tools: tools},
ModelNodeName: agentModelName,
}
if len(cfg.ToolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
for k := range cfg.ToolsReturnDirectly {
if len(toolsReturnDirectly) > 0 {
reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
for k := range toolsReturnDirectly {
reactConfig.ToolReturnDirectly[k] = struct{}{}
}
}
@@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
_ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
} else {
_ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
_ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
}
_ = g.AddEdge(templateNodeKey, llmNodeKey)
if format == FormatJSON {
iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
return jsonParse(ctx, msg.Content, cfg.OutputFields)
return jsonParse(ctx, msg.Content, ns.OutputTypes)
}
convertNode := compose.InvokableLambda(iConvert)
_ = g.AddLambdaNode(outputConvertNodeKey, convertNode)
canStream = false
} else {
var outputKey string
if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
panic("impossible")
}
for k, v := range cfg.OutputFields {
for k, v := range ns.OutputTypes {
if v.Type != vo.DataTypeString {
panic("impossible")
}
@@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
_ = g.AddEdge(outputConvertNodeKey, compose.END)
requireCheckpoint := false
if len(cfg.Tools) > 0 {
if len(tools) > 0 {
requireCheckpoint = true
}
var opts []compose.GraphCompileOption
var compileOpts []compose.GraphCompileOption
if requireCheckpoint {
opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
}
opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))
r, err := g.Compile(ctx, opts...)
r, err := g.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}
@@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
llm := &LLM{
r: r,
outputFormat: format,
canStream: canStream,
requireCheckpoint: requireCheckpoint,
fullSources: cfg.FullSources,
fullSources: ns.FullSources,
}
return llm, nil
}
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
func (c *Config) RequireCheckpoint() bool {
if c.FCParam != nil {
if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil {
return true
}
}
return false
}
func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
if !sc.RequireStreaming() {
return schema2.FieldNotStream, nil
}
if len(path) != 1 {
return schema2.FieldNotStream, nil
}
outputs := ns.OutputTypes
if len(outputs) != 1 && len(outputs) != 2 {
return schema2.FieldNotStream, nil
}
var outputKey string
for key, output := range outputs {
if output.Type != vo.DataTypeString {
return schema2.FieldNotStream, nil
}
if key != ReasoningOutputKey {
if len(outputKey) > 0 {
return schema2.FieldNotStream, nil
}
outputKey = key
}
}
field := path[0]
if field == ReasoningOutputKey || field == outputKey {
return schema2.FieldIsStream, nil
}
return schema2.FieldNotStream, nil
}
func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
switch s {
case 0:
return knowledge.SearchTypeSemantic, nil
case 1:
return knowledge.SearchTypeHybrid, nil
case 20:
return knowledge.SearchTypeFullText, nil
default:
return "", fmt.Errorf("invalid retrieval search type %v", s)
}
}
type LLM struct {
r compose.Runnable[map[string]any, map[string]any]
outputFormat Format
requireCheckpoint bool
fullSources map[string]*schema2.SourceInfo
}
const (
rawOutputKey = "llm_raw_output_%s"
warningKey = "llm_warning_%s"
)
func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
data = nodes.ExtractJSONString(data)
var result map[string]any
err := sonic.UnmarshalString(data, &result)
if err != nil {
c := execute.GetExeCtx(ctx)
if c != nil {
logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
ctxcache.Store(ctx, rawOutputK, data)
ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
return map[string]any{}, nil
}
return nil, err
}
r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
if err != nil {
return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
}
if ws != nil {
logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
}
return r, nil
}
type llmOptions struct {
toolWorkflowSW *schema.StreamWriter[*entity.Message]
}
func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
o.toolWorkflowSW = sw
})
}
type llmState = map[string]any
const agentModelName = "agent_model"
func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
c := execute.GetExeCtx(ctx)
if c != nil {
resumingEvent = c.NodeCtx.ResumingEvent
@@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
}
llmOpts := &Options{}
for _, opt := range opts {
opt(llmOpts)
}
options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)
nestedOpts := &nodes.NestedWorkflowOptions{}
for _, opt := range llmOpts.nested {
opt(nestedOpts)
}
composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
composeOpts = append(composeOpts, options.GetOptsForNested()...)
if resumingEvent != nil {
var (
@@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
}
llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
if llmOpts.toolWorkflowSW != nil {
toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
composeOpts = append(composeOpts, toolMsgOpt)
@@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
return compose.NewInterruptAndRerunErr(ie)
}
func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
return out, nil
}
func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
if err != nil {
return nil, err
@@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
_ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
modelPredictionIDs := strings.Split(input.Content, ",")
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
return strconv.Itoa(int(e.ID)), e.ID
})
recallKnowledgeIDs := make([]int64, 0)
@@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
return make(map[string]any), nil
}
docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
Query: userPrompt,
KnowledgeIDs: recallKnowledgeIDs,
RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,

View File

@@ -26,6 +26,7 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
}
func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
sources map[string]*nodes.SourceInfo,
sources map[string]*schema2.SourceInfo,
supportedModals map[modelmgr.Modal]bool,
) (*schema.Message, error) {
if !pl.hasMultiModal || len(supportedModals) == 0 {
@@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
}
sk := fmt.Sprintf(sourceKey, nodeKey)
sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
if !ok {
return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
}