refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -29,9 +29,12 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/coze-dev/coze-studio/backend/domain/workflow"
crossmodel "github.com/coze-dev/coze-studio/backend/domain/workflow/crossdomain/model"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity"
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/canvas/convert"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/lang/ternary"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
@@ -39,8 +42,21 @@ import (
)
type QuestionAnswer struct {
config *Config
model model.BaseChatModel
nodeMeta entity.NodeTypeMeta
questionTpl string
answerType AnswerType
choiceType ChoiceType
fixedChoices []string
needExtractFromAnswer bool
additionalSystemPromptTpl string
maxAnswerCount int
nodeKey vo.NodeKey
outputFields map[string]*vo.TypeInfo
}
type Config struct {
@@ -51,15 +67,249 @@ type Config struct {
FixedChoices []string
// used for intent recognize if answer by choices and given a custom answer, as well as for extracting structured output from user response
Model model.BaseChatModel
LLMParams *crossmodel.LLMParams
// the following are required if AnswerType is AnswerDirectly and needs to extract from answer
ExtractFromAnswer bool
AdditionalSystemPromptTpl string
MaxAnswerCount int
OutputFields map[string]*vo.TypeInfo
}
NodeKey vo.NodeKey
func (c *Config) Adapt(_ context.Context, n *vo.Node, _ ...nodes.AdaptOption) (*schema2.NodeSchema, error) {
ns := &schema2.NodeSchema{
Key: vo.NodeKey(n.ID),
Type: entity.NodeTypeQuestionAnswer,
Name: n.Data.Meta.Title,
Configs: c,
}
qaConf := n.Data.Inputs.QA
if qaConf == nil {
return nil, fmt.Errorf("qa config is nil")
}
c.QuestionTpl = qaConf.Question
var llmParams *crossmodel.LLMParams
if n.Data.Inputs.LLMParam != nil {
llmParamBytes, err := sonic.Marshal(n.Data.Inputs.LLMParam)
if err != nil {
return nil, err
}
var qaLLMParams vo.SimpleLLMParam
err = sonic.Unmarshal(llmParamBytes, &qaLLMParams)
if err != nil {
return nil, err
}
llmParams, err = convertLLMParams(qaLLMParams)
if err != nil {
return nil, err
}
c.LLMParams = llmParams
}
answerType, err := convertAnswerType(qaConf.AnswerType)
if err != nil {
return nil, err
}
c.AnswerType = answerType
var choiceType ChoiceType
if len(qaConf.OptionType) > 0 {
choiceType, err = convertChoiceType(qaConf.OptionType)
if err != nil {
return nil, err
}
c.ChoiceType = choiceType
}
if answerType == AnswerByChoices {
switch choiceType {
case FixedChoices:
var options []string
for _, option := range qaConf.Options {
options = append(options, option.Name)
}
c.FixedChoices = options
case DynamicChoices:
inputSources, err := convert.CanvasBlockInputToFieldInfo(qaConf.DynamicOption, compose.FieldPath{DynamicChoicesKey}, n.Parent())
if err != nil {
return nil, err
}
ns.AddInputSource(inputSources...)
inputTypes, err := convert.CanvasBlockInputToTypeInfo(qaConf.DynamicOption)
if err != nil {
return nil, err
}
ns.SetInputType(DynamicChoicesKey, inputTypes)
default:
return nil, fmt.Errorf("qa node is answer by options, but option type not provided")
}
} else if answerType == AnswerDirectly {
c.ExtractFromAnswer = qaConf.ExtractOutput
if qaConf.ExtractOutput {
if llmParams == nil {
return nil, fmt.Errorf("qa node needs to extract from answer, but LLMParams not provided")
}
c.AdditionalSystemPromptTpl = llmParams.SystemPrompt
c.MaxAnswerCount = qaConf.Limit
if err = convert.SetOutputTypesForNodeSchema(n, ns); err != nil {
return nil, err
}
}
}
if err = convert.SetInputsForNodeSchema(n, ns); err != nil {
return nil, err
}
return ns, nil
}
func convertLLMParams(params vo.SimpleLLMParam) (*crossmodel.LLMParams, error) {
p := &crossmodel.LLMParams{}
p.ModelName = params.ModelName
p.ModelType = params.ModelType
p.Temperature = &params.Temperature
p.MaxTokens = params.MaxTokens
p.TopP = &params.TopP
p.ResponseFormat = params.ResponseFormat
p.SystemPrompt = params.SystemPrompt
return p, nil
}
func convertAnswerType(t vo.QAAnswerType) (AnswerType, error) {
switch t {
case vo.QAAnswerTypeOption:
return AnswerByChoices, nil
case vo.QAAnswerTypeText:
return AnswerDirectly, nil
default:
return "", fmt.Errorf("invalid QAAnswerType: %s", t)
}
}
func convertChoiceType(t vo.QAOptionType) (ChoiceType, error) {
switch t {
case vo.QAOptionTypeStatic:
return FixedChoices, nil
case vo.QAOptionTypeDynamic:
return DynamicChoices, nil
default:
return "", fmt.Errorf("invalid QAOptionType: %s", t)
}
}
func (c *Config) Build(ctx context.Context, ns *schema2.NodeSchema, _ ...schema2.BuildOption) (any, error) {
if c.AnswerType == AnswerDirectly {
if c.ExtractFromAnswer {
if c.LLMParams == nil {
return nil, errors.New("model is required when extract from answer")
}
if len(ns.OutputTypes) == 0 {
return nil, errors.New("output fields is required when extract from answer")
}
}
} else if c.AnswerType == AnswerByChoices {
if c.ChoiceType == FixedChoices {
if len(c.FixedChoices) == 0 {
return nil, errors.New("fixed choices is required when extract from answer")
}
}
} else {
return nil, fmt.Errorf("unknown answer type: %s", c.AnswerType)
}
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
if nodeMeta == nil {
return nil, errors.New("node meta not found for question answer")
}
var (
m model.BaseChatModel
err error
)
if c.LLMParams != nil {
m, _, err = crossmodel.GetManager().GetModel(ctx, c.LLMParams)
if err != nil {
return nil, err
}
}
return &QuestionAnswer{
model: m,
nodeMeta: *nodeMeta,
questionTpl: c.QuestionTpl,
answerType: c.AnswerType,
choiceType: c.ChoiceType,
fixedChoices: c.FixedChoices,
needExtractFromAnswer: c.ExtractFromAnswer,
additionalSystemPromptTpl: c.AdditionalSystemPromptTpl,
maxAnswerCount: c.MaxAnswerCount,
nodeKey: ns.Key,
outputFields: ns.OutputTypes,
}, nil
}
func (c *Config) BuildBranch(_ context.Context) (
func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error), bool) {
if c.AnswerType != AnswerByChoices {
return nil, false
}
return func(ctx context.Context, nodeOutput map[string]any) (int64, bool, error) {
optionID, ok := nodeOutput[OptionIDKey]
if !ok {
return -1, false, fmt.Errorf("failed to take option id from input map: %v", nodeOutput)
}
if c.ChoiceType == DynamicChoices {
if optionID.(string) == "other" {
return -1, true, nil
} else {
return 0, false, nil
}
}
if optionID.(string) == "other" {
return -1, true, nil
}
optionIDInt, ok := AlphabetToInt(optionID.(string))
if !ok {
return -1, false, fmt.Errorf("failed to convert option id from input map: %v", optionID)
}
return optionIDInt, false, nil
}, true
}
func (c *Config) ExpectPorts(ctx context.Context, n *vo.Node) (expects []string) {
if n.Data.Inputs.QA.AnswerType != vo.QAAnswerTypeOption {
return expects
}
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeStatic {
for index := range n.Data.Inputs.QA.Options {
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, index))
}
expects = append(expects, schema2.PortDefault)
return expects
}
if n.Data.Inputs.QA.OptionType == vo.QAOptionTypeDynamic {
expects = append(expects, fmt.Sprintf(schema2.PortBranchFormat, 0))
expects = append(expects, schema2.PortDefault)
}
return expects
}
func (c *Config) RequireCheckpoint() bool {
return true
}
type AnswerType string
@@ -126,41 +376,6 @@ Strictly identify the intention and select the most suitable option. You can onl
Note: You can only output the id or -1. Your output can only be a pure number and no other content (including the reason)!`
)
func NewQuestionAnswer(_ context.Context, conf *Config) (*QuestionAnswer, error) {
if conf == nil {
return nil, errors.New("config is nil")
}
if conf.AnswerType == AnswerDirectly {
if conf.ExtractFromAnswer {
if conf.Model == nil {
return nil, errors.New("model is required when extract from answer")
}
if len(conf.OutputFields) == 0 {
return nil, errors.New("output fields is required when extract from answer")
}
}
} else if conf.AnswerType == AnswerByChoices {
if conf.ChoiceType == FixedChoices {
if len(conf.FixedChoices) == 0 {
return nil, errors.New("fixed choices is required when extract from answer")
}
}
} else {
return nil, fmt.Errorf("unknown answer type: %s", conf.AnswerType)
}
nodeMeta := entity.NodeMetaByNodeType(entity.NodeTypeQuestionAnswer)
if nodeMeta == nil {
return nil, errors.New("node meta not found for question answer")
}
return &QuestionAnswer{
config: conf,
nodeMeta: *nodeMeta,
}, nil
}
type Question struct {
Question string
Choices []string
@@ -182,10 +397,10 @@ type message struct {
ID string `json:"id,omitempty"`
}
// Execute formats the question (optionally with choices), interrupts, then extracts the answer.
// Invoke formats the question (optionally with choices), interrupts, then extracts the answer.
// input: the references by input fields, as well as the dynamic choices array if needed.
// output: USER_RESPONSE for direct answer, structured output if needs to extract from answer, and option ID / content for answer by choices.
func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out map[string]any, err error) {
func (q *QuestionAnswer) Invoke(ctx context.Context, in map[string]any) (out map[string]any, err error) {
var (
questions []*Question
answers []string
@@ -206,11 +421,11 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
out[QuestionsKey] = questions
out[AnswersKey] = answers
switch q.config.AnswerType {
switch q.answerType {
case AnswerDirectly:
if isFirst { // first execution, ask the question
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
@@ -218,7 +433,7 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
return nil, q.interrupt(ctx, firstQuestion, nil, nil, nil)
}
if q.config.ExtractFromAnswer {
if q.needExtractFromAnswer {
return q.extractFromAnswer(ctx, in, questions, answers)
}
@@ -253,15 +468,15 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
}
// format the question. Which is common to all use cases
firstQuestion, err := nodes.TemplateRender(q.config.QuestionTpl, in)
firstQuestion, err := nodes.TemplateRender(q.questionTpl, in)
if err != nil {
return nil, err
}
var formattedChoices []string
switch q.config.ChoiceType {
switch q.choiceType {
case FixedChoices:
for _, choice := range q.config.FixedChoices {
for _, choice := range q.fixedChoices {
formattedChoice, err := nodes.TemplateRender(choice, in)
if err != nil {
return nil, err
@@ -283,18 +498,18 @@ func (q *QuestionAnswer) Execute(ctx context.Context, in map[string]any) (out ma
formattedChoices = append(formattedChoices, c)
}
default:
return nil, fmt.Errorf("unknown choice type: %s", q.config.ChoiceType)
return nil, fmt.Errorf("unknown choice type: %s", q.choiceType)
}
return nil, q.interrupt(ctx, firstQuestion, formattedChoices, nil, nil)
default:
return nil, fmt.Errorf("unknown answer type: %s", q.config.AnswerType)
return nil, fmt.Errorf("unknown answer type: %s", q.answerType)
}
}
func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]any, questions []*Question, answers []string) (map[string]any, error) {
fieldInfo := "FieldInfo"
s, err := vo.TypeInfoToJSONSchema(q.config.OutputFields, &fieldInfo)
s, err := vo.TypeInfoToJSONSchema(q.outputFields, &fieldInfo)
if err != nil {
return nil, err
}
@@ -302,15 +517,15 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
sysPrompt := fmt.Sprintf(extractSystemPrompt, s)
var requiredFields []string
for fName, tInfo := range q.config.OutputFields {
for fName, tInfo := range q.outputFields {
if tInfo.Required {
requiredFields = append(requiredFields, fName)
}
}
var formattedAdditionalPrompt string
if len(q.config.AdditionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.config.AdditionalSystemPromptTpl, in)
if len(q.additionalSystemPromptTpl) > 0 {
additionalPrompt, err := nodes.TemplateRender(q.additionalSystemPromptTpl, in)
if err != nil {
return nil, err
}
@@ -336,7 +551,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
messages = append(messages, schema.UserMessage(answer))
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return nil, err
}
@@ -353,8 +568,8 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
if ok {
nextQuestionStr, ok := nextQuestion.(string)
if ok && len(nextQuestionStr) > 0 {
if len(answers) >= q.config.MaxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.config.MaxAnswerCount)
if len(answers) >= q.maxAnswerCount {
return nil, fmt.Errorf("max answer count= %d exceeded", q.maxAnswerCount)
}
return nil, q.interrupt(ctx, nextQuestionStr, nil, questions, answers)
@@ -366,7 +581,7 @@ func (q *QuestionAnswer) extractFromAnswer(ctx context.Context, in map[string]an
return nil, fmt.Errorf("field %s not found", fieldInfo)
}
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.config.OutputFields, nodes.SkipRequireCheck())
realOutput, ws, err := nodes.ConvertInputs(ctx, fields.(map[string]any), q.outputFields, nodes.SkipRequireCheck())
if err != nil {
return nil, err
}
@@ -431,7 +646,7 @@ func (q *QuestionAnswer) intentDetect(ctx context.Context, answer string, choice
schema.UserMessage(answer),
}
out, err := q.config.Model.Generate(ctx, messages)
out, err := q.model.Generate(ctx, messages)
if err != nil {
return -1, err
}
@@ -468,7 +683,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
event := &entity.InterruptEvent{
ID: eventID,
NodeKey: q.config.NodeKey,
NodeKey: q.nodeKey,
NodeType: entity.NodeTypeQuestionAnswer,
NodeTitle: q.nodeMeta.Name,
NodeIcon: q.nodeMeta.IconURL,
@@ -477,7 +692,7 @@ func (q *QuestionAnswer) interrupt(ctx context.Context, newQuestion string, choi
}
_ = compose.ProcessState(ctx, func(ctx context.Context, setter QuestionAnswerAware) error {
setter.AddQuestion(q.config.NodeKey, &Question{
setter.AddQuestion(q.nodeKey, &Question{
Question: newQuestion,
Choices: choices,
})
@@ -495,14 +710,14 @@ func intToAlphabet(num int) string {
return ""
}
func AlphabetToInt(str string) (int, bool) {
func AlphabetToInt(str string) (int64, bool) {
if len(str) != 1 {
return 0, false
}
char := rune(str[0])
char = unicode.ToUpper(char)
if char >= 'A' && char <= 'Z' {
return int(char - 'A'), true
return int64(char - 'A'), true
}
return 0, false
}
@@ -521,14 +736,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
for i := 0; i < len(oldQuestions); i++ {
oldQuestion := oldQuestions[i]
oldAnswer := oldAnswers[i]
contentType := ternary.IFElse(q.config.AnswerType == AnswerByChoices, "option", "text")
contentType := ternary.IFElse(q.answerType == AnswerByChoices, "option", "text")
questionMsg := &message{
Type: "question",
ContentType: contentType,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i*2),
}
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
questionMsg.Content = optionContent{
Options: conv(oldQuestion.Choices),
Question: oldQuestion.Question,
@@ -541,14 +756,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Type: "answer",
ContentType: contentType,
Content: oldAnswer,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, i+1),
ID: fmt.Sprintf("%s_%d", q.nodeKey, i+1),
}
history = append(history, questionMsg, answerMsg)
}
if newQuestion != nil {
if q.config.AnswerType == AnswerByChoices {
if q.answerType == AnswerByChoices {
history = append(history, &message{
Type: "question",
ContentType: "option",
@@ -556,14 +771,14 @@ func (q *QuestionAnswer) generateHistory(oldQuestions []*Question, oldAnswers []
Options: conv(choices),
Question: *newQuestion,
},
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
} else {
history = append(history, &message{
Type: "question",
ContentType: "text",
Content: *newQuestion,
ID: fmt.Sprintf("%s_%d", q.config.NodeKey, len(oldQuestions)*2),
ID: fmt.Sprintf("%s_%d", q.nodeKey, len(oldQuestions)*2),
})
}
}