refactor: how to add a node type in workflow (#558)
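A note on what this diff establishes, before the code: adding a node type now means giving it a Config that can (1) Adapt the frontend canvas node (*vo.Node) into a backend *schema2.NodeSchema, (2) Build an executable node from that schema, and (3) answer engine queries such as RequireCheckpoint and FieldStreamType. The method signatures below are taken from the LLM node in this diff; the interface names in this sketch are invented for illustration only.

    // Hypothetical contract, reconstructed from the method set this diff
    // gives the LLM node's Config. The real interface lives elsewhere in
    // the workflow packages and may be named differently.
    type NodeAdaptor interface {
        Adapt(ctx context.Context, n *vo.Node, opts ...nodes.AdaptOption) (*schema2.NodeSchema, error)
    }

    type NodeBuilder interface {
        Build(ctx context.Context, ns *schema2.NodeSchema, opts ...schema2.BuildOption) (any, error)
    }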
@@ -34,13 +34,20 @@ import (
     callbacks2 "github.com/cloudwego/eino/utils/callbacks"

     "golang.org/x/exp/maps"

+    workflow3 "github.com/coze-dev/coze-studio/backend/api/model/ocean/cloud/workflow"
     "github.com/coze-dev/coze-studio/backend/domain/workflow"
-    crossknowledge "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
+    "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/knowledge"
     crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
+    "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/plugin"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
+    "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
+    schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
     "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
     "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
     "github.com/coze-dev/coze-studio/backend/pkg/lang/ptr"
     "github.com/coze-dev/coze-studio/backend/pkg/lang/slices"
     "github.com/coze-dev/coze-studio/backend/pkg/logs"
     "github.com/coze-dev/coze-studio/backend/pkg/safego"
@@ -143,126 +150,408 @@ const (
 )

 type RetrievalStrategy struct {
-    RetrievalStrategy            *crossknowledge.RetrievalStrategy
+    RetrievalStrategy            *knowledge.RetrievalStrategy
     NoReCallReplyMode            NoReCallReplyMode
     NoReCallReplyCustomizePrompt string
 }

 type KnowledgeRecallConfig struct {
     ChatModel         model.BaseChatModel
-    Retriever         crossknowledge.KnowledgeOperator
+    Retriever         knowledge.KnowledgeOperator
     RetrievalStrategy *RetrievalStrategy
-    SelectedKnowledgeDetails []*crossknowledge.KnowledgeDetail
+    SelectedKnowledgeDetails []*knowledge.KnowledgeDetail
 }

 type Config struct {
-    ChatModel             ModelWithInfo
-    Tools                 []tool.BaseTool
-    SystemPrompt          string
-    UserPrompt            string
-    OutputFormat          Format
-    InputFields           map[string]*vo.TypeInfo
-    OutputFields          map[string]*vo.TypeInfo
-    ToolsReturnDirectly   map[string]bool
-    KnowledgeRecallConfig *KnowledgeRecallConfig
-    FullSources           map[string]*nodes.SourceInfo
+    SystemPrompt    string
+    UserPrompt      string
+    OutputFormat    Format
+    LLMParams       *crossmodel.LLMParams
+    FCParam         *vo.FCParam
+    BackupLLMParams *crossmodel.LLMParams
 }

-type LLM struct {
-    r                 compose.Runnable[map[string]any, map[string]any]
-    outputFormat      Format
-    outputFields      map[string]*vo.TypeInfo
-    canStream         bool
-    requireCheckpoint bool
-    fullSources       map[string]*nodes.SourceInfo
-}
-
-const (
-    rawOutputKey = "llm_raw_output_%s"
-    warningKey   = "llm_warning_%s"
-)
-
-func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
-    data = nodes.ExtractJSONString(data)
-
-    var result map[string]any
-
-    err := sonic.UnmarshalString(data, &result)
-    if err != nil {
-        c := execute.GetExeCtx(ctx)
-        if c != nil {
-            logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
-            rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
-            warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
-            ctxcache.Store(ctx, rawOutputK, data)
-            ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
-            return map[string]any{}, nil
-        }
-
-        return nil, err
-    }
-
-    r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
-    if err != nil {
-        return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
-    }
-
-    if ws != nil {
-        logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
-    }
-
-    return r, nil
-}
+func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
+    ns := &schema2.NodeSchema{
+        Key:     vo.NodeKey(n.ID),
+        Type:    entity.NodeTypeLLM,
+        Name:    n.Data.Meta.Title,
+        Configs: c,
+    }
+
+    param := n.Data.Inputs.LLMParam
+    if param == nil {
+        return nil, fmt.Errorf("llm node's llmParam is nil")
+    }
+
+    bs, _ := sonic.Marshal(param)
+    llmParam := make(vo.LLMParam, 0)
+    if err := sonic.Unmarshal(bs, &llmParam); err != nil {
+        return nil, err
+    }
+    convertedLLMParam, err := llmParamsToLLMParam(llmParam)
+    if err != nil {
+        return nil, err
+    }
+
+    c.LLMParams = convertedLLMParam
+    c.SystemPrompt = convertedLLMParam.SystemPrompt
+    c.UserPrompt = convertedLLMParam.Prompt
+
+    var resFormat Format
+    switch convertedLLMParam.ResponseFormat {
+    case crossmodel.ResponseFormatText:
+        resFormat = FormatText
+    case crossmodel.ResponseFormatMarkdown:
+        resFormat = FormatMarkdown
+    case crossmodel.ResponseFormatJSON:
+        resFormat = FormatJSON
+    default:
+        return nil, fmt.Errorf("unsupported response format: %d", convertedLLMParam.ResponseFormat)
+    }
+
+    c.OutputFormat = resFormat
+
+    if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
+        return nil, err
+    }
+
+    if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
+        return nil, err
+    }
+
+    if resFormat == FormatJSON {
+        if len(ns.OutputTypes) == 1 {
+            for _, v := range ns.OutputTypes {
+                if v.Type == vo.DataTypeString {
+                    resFormat = FormatText
+                    break
+                }
+            }
+        } else if len(ns.OutputTypes) == 2 {
+            if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
+                for k, v := range ns.OutputTypes {
+                    if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
+                        resFormat = FormatText
+                        break
+                    }
+                }
+            }
+        }
+    }
+
+    if resFormat == FormatJSON {
+        ns.StreamConfigs = &schema2.StreamConfig{
+            CanGeneratesStream: false,
+        }
+    } else {
+        ns.StreamConfigs = &schema2.StreamConfig{
+            CanGeneratesStream: true,
+        }
+    }
+
+    if n.Data.Inputs.LLM != nil && n.Data.Inputs.FCParam != nil {
+        c.FCParam = n.Data.Inputs.FCParam
+    }
+
+    if se := n.Data.Inputs.SettingOnError; se != nil {
+        if se.Ext != nil && len(se.Ext.BackupLLMParam) > 0 {
+            var backupLLMParam vo.SimpleLLMParam
+            if err = sonic.UnmarshalString(se.Ext.BackupLLMParam, &backupLLMParam); err != nil {
+                return nil, err
+            }
+
+            backupModel, err := simpleLLMParamsToLLMParams(backupLLMParam)
+            if err != nil {
+                return nil, err
+            }
+            c.BackupLLMParams = backupModel
+        }
+    }
+
+    return ns, nil
+}
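The Adapt method above is the template a new node type would follow: translate the canvas node's raw inputs into typed fields on the Config, fill in the NodeSchema, and derive any format/stream flags. A minimal sketch for an imaginary node type (every name here is invented for illustration and is not part of this commit):

    // Sketch of Adapt for a hypothetical "Echo" node, mirroring the steps
    // the LLM node takes above.
    func (c *EchoConfig) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
        ns := &schema2.NodeSchema{
            Key:     vo.NodeKey(n.ID),
            Type:    entity.NodeTypeEcho, // hypothetical node-type constant
            Name:    n.Data.Meta.Title,
            Configs: c,
        }
        if err := convert.SetInputsForNodeSchema(n, ns); err != nil {
            return nil, err
        }
        if err := convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
            return nil, err
        }
        return ns, nil
    }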
+func llmParamsToLLMParam(params vo.LLMParam) (*crossmodel.LLMParams, error) {
+    p := &crossmodel.LLMParams{}
+    for _, param := range params {
+        switch param.Name {
+        case "temperature":
+            strVal := param.Input.Value.Content.(string)
+            floatVal, err := strconv.ParseFloat(strVal, 64)
+            if err != nil {
+                return nil, err
+            }
+            p.Temperature = &floatVal
+        case "maxTokens":
+            strVal := param.Input.Value.Content.(string)
+            intVal, err := strconv.Atoi(strVal)
+            if err != nil {
+                return nil, err
+            }
+            p.MaxTokens = intVal
+        case "responseFormat":
+            strVal := param.Input.Value.Content.(string)
+            int64Val, err := strconv.ParseInt(strVal, 10, 64)
+            if err != nil {
+                return nil, err
+            }
+            p.ResponseFormat = crossmodel.ResponseFormat(int64Val)
+        case "modleName":
+            strVal := param.Input.Value.Content.(string)
+            p.ModelName = strVal
+        case "modelType":
+            strVal := param.Input.Value.Content.(string)
+            int64Val, err := strconv.ParseInt(strVal, 10, 64)
+            if err != nil {
+                return nil, err
+            }
+            p.ModelType = int64Val
+        case "prompt":
+            strVal := param.Input.Value.Content.(string)
+            p.Prompt = strVal
+        case "enableChatHistory":
+            boolVar := param.Input.Value.Content.(bool)
+            p.EnableChatHistory = boolVar
+        case "systemPrompt":
+            strVal := param.Input.Value.Content.(string)
+            p.SystemPrompt = strVal
+        case "chatHistoryRound", "generationDiversity", "frequencyPenalty", "presencePenalty":
+            // do nothing
+        case "topP":
+            strVal := param.Input.Value.Content.(string)
+            floatVar, err := strconv.ParseFloat(strVal, 64)
+            if err != nil {
+                return nil, err
+            }
+            p.TopP = &floatVar
+        default:
+            return nil, fmt.Errorf("invalid LLMParam name: %s", param.Name)
+        }
+    }
+
+    return p, nil
+}
+
+func simpleLLMParamsToLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
+    p := &crossmodel.LLMParams{}
+    p.ModelName = params.ModelName
+    p.ModelType = params.ModelType
+    p.Temperature = &params.Temperature
+    p.MaxTokens = params.MaxTokens
+    p.TopP = &params.TopP
+    p.ResponseFormat = params.ResponseFormat
+    p.SystemPrompt = params.SystemPrompt
+    return p, nil
+}

 func getReasoningContent(message *schema.Message) string {
     return message.ReasoningContent
 }
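llmParamsToLLMParam above walks the name/value pairs the canvas stores for an LLM node; numeric settings arrive as strings and are parsed here, and the model-name key really is spelled "modleName" in the data this switch matches against. A plausible shape for that payload, inferred from the accessors above (the exact JSON casing is an assumption, not confirmed by this commit):

    [
      {"name": "temperature",       "input": {"value": {"content": "0.7"}}},
      {"name": "maxTokens",         "input": {"value": {"content": "1024"}}},
      {"name": "modleName",         "input": {"value": {"content": "some-model"}}},
      {"name": "enableChatHistory", "input": {"value": {"content": true}}}
    ]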
-type Options struct {
-    nested         []nodes.NestedWorkflowOption
-    toolWorkflowSW *schema.StreamWriter[*entity.Message]
-}
-
-type Option func(o *Options)
-
-func WithNestedWorkflowOptions(nested ...nodes.NestedWorkflowOption) Option {
-    return func(o *Options) {
-        o.nested = append(o.nested, nested...)
-    }
-}
-
-func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) Option {
-    return func(o *Options) {
-        o.toolWorkflowSW = sw
-    }
-}
-
-type llmState = map[string]any
-
-const agentModelName = "agent_model"
+func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
+    var (
+        err                   error
+        chatModel, fallbackM  model.BaseChatModel
+        info, fallbackI       *modelmgr.Model
+        modelWithInfo         ModelWithInfo
+        tools                 []tool.BaseTool
+        toolsReturnDirectly   map[string]bool
+        knowledgeRecallConfig *KnowledgeRecallConfig
+    )
+
+    chatModel, info, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
+    if err != nil {
+        return nil, err
+    }
+
+    exceptionConf := ns.ExceptionConfigs
+    if exceptionConf != nil && exceptionConf.MaxRetry > 0 {
+        backupModelParams := c.BackupLLMParams
+        if backupModelParams != nil {
+            fallbackM, fallbackI, err = crossmodel.GetManager().GetModel(ctx, backupModelParams)
+            if err != nil {
+                return nil, err
+            }
+        }
+    }
+
+    if fallbackM == nil {
+        modelWithInfo = NewModel(chatModel, info)
+    } else {
+        modelWithInfo = NewModelWithFallback(chatModel, fallbackM, info, fallbackI)
+    }
+
+    fcParams := c.FCParam
+    if fcParams != nil {
+        if fcParams.WorkflowFCParam != nil {
+            for _, wf := range fcParams.WorkflowFCParam.WorkflowList {
+                wfIDStr := wf.WorkflowID
+                wfID, err := strconv.ParseInt(wfIDStr, 10, 64)
+                if err != nil {
+                    return nil, fmt.Errorf("invalid workflow id: %s", wfIDStr)
+                }
+
+                workflowToolConfig := vo.WorkflowToolConfig{}
+                if wf.FCSetting != nil {
+                    workflowToolConfig.InputParametersConfig = wf.FCSetting.RequestParameters
+                    workflowToolConfig.OutputParametersConfig = wf.FCSetting.ResponseParameters
+                }
+
+                locator := vo.FromDraft
+                if wf.WorkflowVersion != "" {
+                    locator = vo.FromSpecificVersion
+                }
+
+                wfTool, err := workflow.GetRepository().WorkflowAsTool(ctx, vo.GetPolicy{
+                    ID:      wfID,
+                    QType:   locator,
+                    Version: wf.WorkflowVersion,
+                }, workflowToolConfig)
+                if err != nil {
+                    return nil, err
+                }
+                tools = append(tools, wfTool)
+                if wfTool.TerminatePlan() == vo.UseAnswerContent {
+                    if toolsReturnDirectly == nil {
+                        toolsReturnDirectly = make(map[string]bool)
+                    }
+                    toolInfo, err := wfTool.Info(ctx)
+                    if err != nil {
+                        return nil, err
+                    }
+                    toolsReturnDirectly[toolInfo.Name] = true
+                }
+            }
+        }
+
+        if fcParams.PluginFCParam != nil {
+            pluginToolsInvokableReq := make(map[int64]*plugin.ToolsInvokableRequest)
+            for _, p := range fcParams.PluginFCParam.PluginList {
+                pid, err := strconv.ParseInt(p.PluginID, 10, 64)
+                if err != nil {
+                    return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
+                }
+                toolID, err := strconv.ParseInt(p.ApiId, 10, 64)
+                if err != nil {
+                    return nil, fmt.Errorf("invalid plugin id: %s", p.PluginID)
+                }
+
+                var (
+                    requestParameters  []*workflow3.APIParameter
+                    responseParameters []*workflow3.APIParameter
+                )
+                if p.FCSetting != nil {
+                    requestParameters = p.FCSetting.RequestParameters
+                    responseParameters = p.FCSetting.ResponseParameters
+                }
+
+                if req, ok := pluginToolsInvokableReq[pid]; ok {
+                    req.ToolsInvokableInfo[toolID] = &plugin.ToolsInvokableInfo{
+                        ToolID:                      toolID,
+                        RequestAPIParametersConfig:  requestParameters,
+                        ResponseAPIParametersConfig: responseParameters,
+                    }
+                } else {
+                    pluginToolsInfoRequest := &plugin.ToolsInvokableRequest{
+                        PluginEntity: plugin.Entity{
+                            PluginID:      pid,
+                            PluginVersion: ptr.Of(p.PluginVersion),
+                        },
+                        ToolsInvokableInfo: map[int64]*plugin.ToolsInvokableInfo{
+                            toolID: {
+                                ToolID:                      toolID,
+                                RequestAPIParametersConfig:  requestParameters,
+                                ResponseAPIParametersConfig: responseParameters,
+                            },
+                        },
+                        IsDraft: p.IsDraft,
+                    }
+                    pluginToolsInvokableReq[pid] = pluginToolsInfoRequest
+                }
+            }
+            inInvokableTools := make([]tool.BaseTool, 0, len(fcParams.PluginFCParam.PluginList))
+            for _, req := range pluginToolsInvokableReq {
+                toolMap, err := plugin.GetPluginService().GetPluginInvokableTools(ctx, req)
+                if err != nil {
+                    return nil, err
+                }
+                for _, t := range toolMap {
+                    inInvokableTools = append(inInvokableTools, plugin.NewInvokableTool(t))
+                }
+            }
+            if len(inInvokableTools) > 0 {
+                tools = append(tools, inInvokableTools...)
+            }
+        }
+
+        if fcParams.KnowledgeFCParam != nil && len(fcParams.KnowledgeFCParam.KnowledgeList) > 0 {
+            kwChatModel := workflow.GetRepository().GetKnowledgeRecallChatModel()
+            if kwChatModel == nil {
+                return nil, fmt.Errorf("workflow builtin chat model for knowledge recall not configured")
+            }
+
+            knowledgeOperator := knowledge.GetKnowledgeOperator()
+            setting := fcParams.KnowledgeFCParam.GlobalSetting
+            knowledgeRecallConfig = &KnowledgeRecallConfig{
+                ChatModel: kwChatModel,
+                Retriever: knowledgeOperator,
+            }
+            searchType, err := toRetrievalSearchType(setting.SearchMode)
+            if err != nil {
+                return nil, err
+            }
+            knowledgeRecallConfig.RetrievalStrategy = &RetrievalStrategy{
+                RetrievalStrategy: &knowledge.RetrievalStrategy{
+                    TopK:               ptr.Of(setting.TopK),
+                    MinScore:           ptr.Of(setting.MinScore),
+                    SearchType:         searchType,
+                    EnableNL2SQL:       setting.UseNL2SQL,
+                    EnableQueryRewrite: setting.UseRewrite,
+                    EnableRerank:       setting.UseRerank,
+                },
+                NoReCallReplyMode:            NoReCallReplyMode(setting.NoRecallReplyMode),
+                NoReCallReplyCustomizePrompt: setting.NoRecallReplyCustomizePrompt,
+            }
+
+            knowledgeIDs := make([]int64, 0, len(fcParams.KnowledgeFCParam.KnowledgeList))
+            for _, kw := range fcParams.KnowledgeFCParam.KnowledgeList {
+                kid, err := strconv.ParseInt(kw.ID, 10, 64)
+                if err != nil {
+                    return nil, err
+                }
+                knowledgeIDs = append(knowledgeIDs, kid)
+            }
+
+            detailResp, err := knowledgeOperator.ListKnowledgeDetail(ctx,
+                &knowledge.ListKnowledgeDetailRequest{
+                    KnowledgeIDs: knowledgeIDs,
+                })
+            if err != nil {
+                return nil, err
+            }
+            knowledgeRecallConfig.SelectedKnowledgeDetails = detailResp.KnowledgeDetails
+        }
+    }
-func New(ctx context.Context, cfg *Config) (*LLM, error) {
     g := compose.NewGraph[map[string]any, map[string]any](compose.WithGenLocalState(func(ctx context.Context) (state llmState) {
         return llmState{}
     }))

-    var (
-        hasReasoning bool
-        canStream    = true
-    )
+    var hasReasoning bool

-    format := cfg.OutputFormat
+    format := c.OutputFormat
     if format == FormatJSON {
-        if len(cfg.OutputFields) == 1 {
-            for _, v := range cfg.OutputFields {
+        if len(ns.OutputTypes) == 1 {
+            for _, v := range ns.OutputTypes {
                 if v.Type == vo.DataTypeString {
                     format = FormatText
                     break
                 }
             }
-        } else if len(cfg.OutputFields) == 2 {
-            if _, ok := cfg.OutputFields[ReasoningOutputKey]; ok {
-                for k, v := range cfg.OutputFields {
+        } else if len(ns.OutputTypes) == 2 {
+            if _, ok := ns.OutputTypes[ReasoningOutputKey]; ok {
+                for k, v := range ns.OutputTypes {
                     if k != ReasoningOutputKey && v.Type == vo.DataTypeString {
                         format = FormatText
                         break
@@ -272,10 +561,10 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
             }
         }

-    userPrompt := cfg.UserPrompt
+    userPrompt := c.UserPrompt
     switch format {
     case FormatJSON:
-        jsonSchema, err := vo.TypeInfoToJSONSchema(cfg.OutputFields, nil)
+        jsonSchema, err := vo.TypeInfoToJSONSchema(ns.OutputTypes, nil)
         if err != nil {
             return nil, err
         }
@@ -287,20 +576,20 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
     case FormatText:
     }

-    if cfg.KnowledgeRecallConfig != nil {
-        err := injectKnowledgeTool(ctx, g, cfg.UserPrompt, cfg.KnowledgeRecallConfig)
+    if knowledgeRecallConfig != nil {
+        err := injectKnowledgeTool(ctx, g, c.UserPrompt, knowledgeRecallConfig)
         if err != nil {
             return nil, err
         }
         userPrompt = fmt.Sprintf("{{%s}}%s", knowledgeUserPromptTemplateKey, userPrompt)

-        inputs := maps.Clone(cfg.InputFields)
+        inputs := maps.Clone(ns.InputTypes)
         inputs[knowledgeUserPromptTemplateKey] = &vo.TypeInfo{
             Type: vo.DataTypeString,
         }
-        sp := newPromptTpl(schema.System, cfg.SystemPrompt, inputs, nil)
+        sp := newPromptTpl(schema.System, c.SystemPrompt, inputs, nil)
         up := newPromptTpl(schema.User, userPrompt, inputs, []string{knowledgeUserPromptTemplateKey})
-        template := newPrompts(sp, up, cfg.ChatModel)
+        template := newPrompts(sp, up, modelWithInfo)

         _ = g.AddChatTemplateNode(templateNodeKey, template,
             compose.WithStatePreHandler(func(ctx context.Context, in map[string]any, state llmState) (map[string]any, error) {
@@ -312,28 +601,28 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
         _ = g.AddEdge(knowledgeLambdaKey, templateNodeKey)

     } else {
-        sp := newPromptTpl(schema.System, cfg.SystemPrompt, cfg.InputFields, nil)
-        up := newPromptTpl(schema.User, userPrompt, cfg.InputFields, nil)
-        template := newPrompts(sp, up, cfg.ChatModel)
+        sp := newPromptTpl(schema.System, c.SystemPrompt, ns.InputTypes, nil)
+        up := newPromptTpl(schema.User, userPrompt, ns.InputTypes, nil)
+        template := newPrompts(sp, up, modelWithInfo)
         _ = g.AddChatTemplateNode(templateNodeKey, template)

         _ = g.AddEdge(compose.START, templateNodeKey)
     }

-    if len(cfg.Tools) > 0 {
-        m, ok := cfg.ChatModel.(model.ToolCallingChatModel)
+    if len(tools) > 0 {
+        m, ok := modelWithInfo.(model.ToolCallingChatModel)
         if !ok {
             return nil, errors.New("requires a ToolCallingChatModel to use with tools")
         }
         reactConfig := react.AgentConfig{
             ToolCallingModel: m,
-            ToolsConfig:      compose.ToolsNodeConfig{Tools: cfg.Tools},
+            ToolsConfig:      compose.ToolsNodeConfig{Tools: tools},
             ModelNodeName:    agentModelName,
         }

-        if len(cfg.ToolsReturnDirectly) > 0 {
-            reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(cfg.ToolsReturnDirectly))
-            for k := range cfg.ToolsReturnDirectly {
+        if len(toolsReturnDirectly) > 0 {
+            reactConfig.ToolReturnDirectly = make(map[string]struct{}, len(toolsReturnDirectly))
+            for k := range toolsReturnDirectly {
                 reactConfig.ToolReturnDirectly[k] = struct{}{}
             }
         }
@@ -347,28 +636,26 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
         opts = append(opts, compose.WithNodeName("workflow_llm_react_agent"))
         _ = g.AddGraphNode(llmNodeKey, agentNode, opts...)
     } else {
-        _ = g.AddChatModelNode(llmNodeKey, cfg.ChatModel)
+        _ = g.AddChatModelNode(llmNodeKey, modelWithInfo)
     }

     _ = g.AddEdge(templateNodeKey, llmNodeKey)

     if format == FormatJSON {
         iConvert := func(ctx context.Context, msg *schema.Message) (map[string]any, error) {
-            return jsonParse(ctx, msg.Content, cfg.OutputFields)
+            return jsonParse(ctx, msg.Content, ns.OutputTypes)
         }

         convertNode := compose.InvokableLambda(iConvert)

         _ = g.AddLambdaNode(outputConvertNodeKey, convertNode)

-        canStream = false
     } else {
         var outputKey string
-        if len(cfg.OutputFields) != 1 && len(cfg.OutputFields) != 2 {
+        if len(ns.OutputTypes) != 1 && len(ns.OutputTypes) != 2 {
             panic("impossible")
         }

-        for k, v := range cfg.OutputFields {
+        for k, v := range ns.OutputTypes {
             if v.Type != vo.DataTypeString {
                 panic("impossible")
             }
@@ -443,17 +730,17 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
     _ = g.AddEdge(outputConvertNodeKey, compose.END)

     requireCheckpoint := false
-    if len(cfg.Tools) > 0 {
+    if len(tools) > 0 {
         requireCheckpoint = true
     }

-    var opts []compose.GraphCompileOption
+    var compileOpts []compose.GraphCompileOption
     if requireCheckpoint {
-        opts = append(opts, compose.WithCheckPointStore(workflow.GetRepository()))
+        compileOpts = append(compileOpts, compose.WithCheckPointStore(workflow.GetRepository()))
     }
-    opts = append(opts, compose.WithGraphName("workflow_llm_node_graph"))
+    compileOpts = append(compileOpts, compose.WithGraphName("workflow_llm_node_graph"))

-    r, err := g.Compile(ctx, opts...)
+    r, err := g.Compile(ctx, compileOpts...)
     if err != nil {
         return nil, err
     }
@@ -461,15 +748,132 @@ func New(ctx context.Context, cfg *Config) (*LLM, error) {
     llm := &LLM{
         r:                 r,
         outputFormat:      format,
-        canStream:         canStream,
         requireCheckpoint: requireCheckpoint,
-        fullSources:       cfg.FullSources,
+        fullSources:       ns.FullSources,
     }

     return llm, nil
 }

-func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
+func (c *Config) RequireCheckpoint() bool {
+    if c.FCParam != nil {
+        if c.FCParam.WorkflowFCParam != nil || c.FCParam.PluginFCParam != nil {
+            return true
+        }
+    }
+
+    return false
+}
+
+func (c *Config) FieldStreamType(path compose.FieldPath, ns *schema2.NodeSchema,
+    sc *schema2.WorkflowSchema) (schema2.FieldStreamType, error) {
+    if !sc.RequireStreaming() {
+        return schema2.FieldNotStream, nil
+    }
+
+    if len(path) != 1 {
+        return schema2.FieldNotStream, nil
+    }
+
+    outputs := ns.OutputTypes
+    if len(outputs) != 1 && len(outputs) != 2 {
+        return schema2.FieldNotStream, nil
+    }
+
+    var outputKey string
+    for key, output := range outputs {
+        if output.Type != vo.DataTypeString {
+            return schema2.FieldNotStream, nil
+        }
+
+        if key != ReasoningOutputKey {
+            if len(outputKey) > 0 {
+                return schema2.FieldNotStream, nil
+            }
+            outputKey = key
+        }
+    }
+
+    field := path[0]
+    if field == ReasoningOutputKey || field == outputKey {
+        return schema2.FieldIsStream, nil
+    }
+
+    return schema2.FieldNotStream, nil
+}
+
+func toRetrievalSearchType(s int64) (knowledge.SearchType, error) {
+    switch s {
+    case 0:
+        return knowledge.SearchTypeSemantic, nil
+    case 1:
+        return knowledge.SearchTypeHybrid, nil
+    case 20:
+        return knowledge.SearchTypeFullText, nil
+    default:
+        return "", fmt.Errorf("invalid retrieval search type %v", s)
+    }
+}
+
+type LLM struct {
+    r                 compose.Runnable[map[string]any, map[string]any]
+    outputFormat      Format
+    requireCheckpoint bool
+    fullSources       map[string]*schema2.SourceInfo
+}
+
+const (
+    rawOutputKey = "llm_raw_output_%s"
+    warningKey   = "llm_warning_%s"
+)
+
+func jsonParse(ctx context.Context, data string, schema_ map[string]*vo.TypeInfo) (map[string]any, error) {
+    data = nodes.ExtractJSONString(data)
+
+    var result map[string]any
+
+    err := sonic.UnmarshalString(data, &result)
+    if err != nil {
+        c := execute.GetExeCtx(ctx)
+        if c != nil {
+            logs.CtxErrorf(ctx, "failed to parse json: %v, data: %s", err, data)
+            rawOutputK := fmt.Sprintf(rawOutputKey, c.NodeCtx.NodeKey)
+            warningK := fmt.Sprintf(warningKey, c.NodeCtx.NodeKey)
+            ctxcache.Store(ctx, rawOutputK, data)
+            ctxcache.Store(ctx, warningK, vo.WrapWarn(errno.ErrLLMStructuredOutputParseFail, err))
+            return map[string]any{}, nil
+        }
+
+        return nil, err
+    }
+
+    r, ws, err := nodes.ConvertInputs(ctx, result, schema_)
+    if err != nil {
+        return nil, vo.WrapError(errno.ErrLLMStructuredOutputParseFail, err)
+    }
+
+    if ws != nil {
+        logs.CtxWarnf(ctx, "convert inputs warnings: %v", *ws)
+    }
+
+    return r, nil
+}
+
+type llmOptions struct {
+    toolWorkflowSW *schema.StreamWriter[*entity.Message]
+}
+
+func WithToolWorkflowMessageWriter(sw *schema.StreamWriter[*entity.Message]) nodes.NodeOption {
+    return nodes.WrapImplSpecificOptFn(func(o *llmOptions) {
+        o.toolWorkflowSW = sw
+    })
+}
+
+type llmState = map[string]any
+
+const agentModelName = "agent_model"
+
+func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...nodes.NodeOption) (composeOpts []compose.Option, resumingEvent *entity.InterruptEvent, err error) {
     c := execute.GetExeCtx(ctx)
     if c != nil {
         resumingEvent = c.NodeCtx.ResumingEvent
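Two of the new Config methods above are how the engine interrogates a node type without building it. RequireCheckpoint reports whether the node needs a checkpoint store (here: whenever workflow or plugin tools are configured, since tool calls can interrupt and resume), and FieldStreamType reports, per output field, whether the node can stream that field. For the LLM node the rule amounts to: every output must be a string and there can be at most one non-reasoning output, so an output map like {"output": string, "reasoning_content": string} streams both fields, while any non-string output falls back to FieldNotStream.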
@@ -502,17 +906,9 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
         composeOpts = append(composeOpts, compose.WithCheckPointID(checkpointID))
     }

-    llmOpts := &Options{}
-    for _, opt := range opts {
-        opt(llmOpts)
-    }
+    options := nodes.GetCommonOptions(&nodes.NodeOptions{}, opts...)

-    nestedOpts := &nodes.NestedWorkflowOptions{}
-    for _, opt := range llmOpts.nested {
-        opt(nestedOpts)
-    }
-
-    composeOpts = append(composeOpts, nestedOpts.GetOptsForNested()...)
+    composeOpts = append(composeOpts, options.GetOptsForNested()...)

     if resumingEvent != nil {
         var (
@@ -580,6 +976,7 @@ func (l *LLM) prepare(ctx context.Context, _ map[string]any, opts ...Option) (co
         composeOpts = append(composeOpts, compose.WithToolsNodeOption(compose.WithToolOption(execute.WithExecuteConfig(exeCfg))))
     }

+    llmOpts := nodes.GetImplSpecificOptions(&llmOptions{}, opts...)
     if llmOpts.toolWorkflowSW != nil {
         toolMsgOpt, toolMsgSR := execute.WithMessagePipe()
         composeOpts = append(composeOpts, toolMsgOpt)
@@ -697,7 +1094,7 @@ func handleInterrupt(ctx context.Context, err error, resumingEvent *entity.Inter
     return compose.NewInterruptAndRerunErr(ie)
 }

-func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out map[string]any, err error) {
+func (l *LLM) Invoke(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out map[string]any, err error) {
     composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
     if err != nil {
         return nil, err
@@ -712,7 +1109,7 @@ func (l *LLM) Chat(ctx context.Context, in map[string]any, opts ...Option) (out
     return out, nil
 }

-func (l *LLM) ChatStream(ctx context.Context, in map[string]any, opts ...Option) (out *schema.StreamReader[map[string]any], err error) {
+func (l *LLM) Stream(ctx context.Context, in map[string]any, opts ...nodes.NodeOption) (out *schema.StreamReader[map[string]any], err error) {
     composeOpts, resumingEvent, err := l.prepare(ctx, in, opts...)
     if err != nil {
         return nil, err
@@ -745,7 +1142,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
     _ = g.AddLambdaNode(knowledgeLambdaKey, compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (output map[string]any, err error) {
         modelPredictionIDs := strings.Split(input.Content, ",")
-        selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *crossknowledge.KnowledgeDetail) (string, int64) {
+        selectKwIDs := slices.ToMap(cfg.SelectedKnowledgeDetails, func(e *knowledge.KnowledgeDetail) (string, int64) {
             return strconv.Itoa(int(e.ID)), e.ID
         })
         recallKnowledgeIDs := make([]int64, 0)
@@ -759,7 +1156,7 @@ func injectKnowledgeTool(_ context.Context, g *compose.Graph[map[string]any, map
             return make(map[string]any), nil
         }

-        docs, err := cfg.Retriever.Retrieve(ctx, &crossknowledge.RetrieveRequest{
+        docs, err := cfg.Retriever.Retrieve(ctx, &knowledge.RetrieveRequest{
             Query:             userPrompt,
             KnowledgeIDs:      recallKnowledgeIDs,
             RetrievalStrategy: cfg.RetrievalStrategy.RetrievalStrategy,
@@ -26,6 +26,7 @@ import (
     "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
     "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
+    schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
     "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
     "github.com/coze-dev/coze-studio/backend/pkg/ctxcache"
     "github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -107,7 +108,7 @@ func newPrompts(sp, up *promptTpl, model ModelWithInfo) *prompts {
 }

 func (pl *promptTpl) render(ctx context.Context, vs map[string]any,
-    sources map[string]*nodes.SourceInfo,
+    sources map[string]*schema2.SourceInfo,
     supportedModals map[modelmgr.Modal]bool,
 ) (*schema.Message, error) {
     if !pl.hasMultiModal || len(supportedModals) == 0 {
@@ -247,7 +248,7 @@ func (p *prompts) Format(ctx context.Context, vs map[string]any, _ ...prompt.Opt
     }
     sk := fmt.Sprintf(sourceKey, nodeKey)

-    sources, ok := ctxcache.Get[map[string]*nodes.SourceInfo](ctx, sk)
+    sources, ok := ctxcache.Get[map[string]*schema2.SourceInfo](ctx, sk)
     if !ok {
         return nil, fmt.Errorf("resolved sources not found llm node, key: %s", sk)
     }
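Read end to end, the life cycle of a node after this refactor is: canvas JSON is adapted into a NodeSchema, the schema is built into a runnable value, and the engine drives it through Invoke or Stream. A hedged sketch of that flow, with the engine plumbing (which this commit does not show) reduced to a single hypothetical function:

    // Sketch only: how a caller might exercise the new Config surface.
    func runLLMNode(ctx context.Context, canvasNode *vo.Node, input map[string]any) (map[string]any, error) {
        cfg := &Config{}                      // the LLM node's Config from this diff
        ns, err := cfg.Adapt(ctx, canvasNode) // canvas node -> *schema2.NodeSchema
        if err != nil {
            return nil, err
        }
        built, err := cfg.Build(ctx, ns)      // NodeSchema -> executable node
        if err != nil {
            return nil, err
        }
        return built.(*LLM).Invoke(ctx, input) // streaming callers use Stream instead
    }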