refactor: how to add a node type in workflow (#558)

This commit is contained in:
shentongmartin
2025-08-05 14:02:33 +08:00
committed by GitHub
parent 5dafd81a3f
commit bb6ff0026b
96 changed files with 8305 additions and 8717 deletions

View File

@@ -32,8 +32,10 @@ import (
"github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/execute"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa"
"github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/receiver"
schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema"
"github.com/coze-dev/coze-studio/backend/pkg/sonic"
)
@@ -46,9 +48,9 @@ type State struct {
InterruptEvents map[vo.NodeKey]*entity.InterruptEvent `json:"interrupt_events,omitempty"`
NestedWorkflowStates map[vo.NodeKey]*nodes.NestedWorkflowState `json:"nested_workflow_states,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*nodes.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ExecutedNodes map[vo.NodeKey]bool `json:"executed_nodes,omitempty"`
SourceInfos map[vo.NodeKey]map[string]*schema2.SourceInfo `json:"source_infos,omitempty"`
GroupChoices map[vo.NodeKey]map[string]int `json:"group_choices,omitempty"`
ToolInterruptEvents map[vo.NodeKey]map[string] /*ToolCallID*/ *entity.ToolInterruptEvent `json:"tool_interrupt_events,omitempty"`
LLMToResumeData map[vo.NodeKey]string `json:"llm_to_resume_data,omitempty"`
@@ -71,8 +73,8 @@ func init() {
_ = compose.RegisterSerializableType[*model.TokenUsage]("model_token_usage")
_ = compose.RegisterSerializableType[*nodes.NestedWorkflowState]("composite_state")
_ = compose.RegisterSerializableType[*compose.InterruptInfo]("interrupt_info")
_ = compose.RegisterSerializableType[*nodes.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[nodes.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[*schema2.SourceInfo]("source_info")
_ = compose.RegisterSerializableType[schema2.FieldStreamType]("field_stream_type")
_ = compose.RegisterSerializableType[compose.FieldPath]("field_path")
_ = compose.RegisterSerializableType[*entity.WorkflowBasic]("workflow_basic")
_ = compose.RegisterSerializableType[vo.TerminatePlan]("terminate_plan")
@@ -162,41 +164,41 @@ func (s *State) GetDynamicChoice(nodeKey vo.NodeKey) map[string]int {
return s.GroupChoices[nodeKey]
}
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.FieldStreamType, error) {
func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (schema2.FieldStreamType, error) {
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s", nodeKey)
}
choice, ok := choices[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("choice not found for node %s and group %s", nodeKey, group)
}
if choice == -1 { // this group picks none of the elements
return nodes.FieldNotStream, nil
return schema2.FieldNotStream, nil
}
sInfos, ok := s.SourceInfos[nodeKey]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s", nodeKey)
}
groupInfo, ok := sInfos[group]
if !ok {
return nodes.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
return schema2.FieldMaybeStream, fmt.Errorf("source infos not found for node %s and group %s", nodeKey, group)
}
if groupInfo.SubSources == nil {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain any sub sources", group, nodeKey)
}
subInfo, ok := groupInfo.SubSources[strconv.Itoa(choice)]
if !ok {
return nodes.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
return schema2.FieldNotStream, fmt.Errorf("dynamic group %s of node %s does not contain sub source for choice %d", group, nodeKey, choice)
}
if subInfo.FieldType != nodes.FieldMaybeStream {
if subInfo.FieldType != schema2.FieldMaybeStream {
return subInfo.FieldType, nil
}
@@ -211,8 +213,8 @@ func (s *State) GetDynamicStreamType(nodeKey vo.NodeKey, group string) (nodes.Fi
return s.GetDynamicStreamType(subInfo.FromNodeKey, subInfo.FromPath[0])
}
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]nodes.FieldStreamType, error) {
result := make(map[string]nodes.FieldStreamType)
func (s *State) GetAllDynamicStreamTypes(nodeKey vo.NodeKey) (map[string]schema2.FieldStreamType, error) {
result := make(map[string]schema2.FieldStreamType)
choices, ok := s.GroupChoices[nodeKey]
if !ok {
return result, nil
@@ -269,7 +271,7 @@ func GenState() compose.GenLocalState[*State] {
InterruptEvents: make(map[vo.NodeKey]*entity.InterruptEvent),
NestedWorkflowStates: make(map[vo.NodeKey]*nodes.NestedWorkflowState),
ExecutedNodes: make(map[vo.NodeKey]bool),
SourceInfos: make(map[vo.NodeKey]map[string]*nodes.SourceInfo),
SourceInfos: make(map[vo.NodeKey]map[string]*schema2.SourceInfo),
GroupChoices: make(map[vo.NodeKey]map[string]int),
ToolInterruptEvents: make(map[vo.NodeKey]map[string]*entity.ToolInterruptEvent),
LLMToResumeData: make(map[vo.NodeKey]string),
@@ -277,7 +279,7 @@ func GenState() compose.GenLocalState[*State] {
}
}
func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
func statePreHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePreHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePreHandler[map[string]any, *State]
@@ -314,7 +316,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
return in, nil
})
} else if s.Type == entity.NodeTypeBatch || s.Type == entity.NodeTypeLoop {
} else if entity.NodeMetaByNodeType(s.Type).IsComposite {
handlers = append(handlers, func(ctx context.Context, in map[string]any, state *State) (map[string]any, error) {
if _, ok := state.Inputs[s.Key]; !ok { // first execution, store input for potential resume later
state.Inputs[s.Key] = in
@@ -329,7 +331,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
}
if len(handlers) > 0 || !stream {
handlerForVars := s.statePreHandlerForVars()
handlerForVars := statePreHandlerForVars(s)
if handlerForVars != nil {
handlers = append(handlers, handlerForVars)
}
@@ -349,12 +351,12 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
if s.Type == entity.NodeTypeVariableAggregator {
streamHandlers = append(streamHandlers, func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
state.SourceInfos[s.Key] = mustGetKey[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
state.SourceInfos[s.Key] = s.FullSources
return in, nil
})
}
handlerForVars := s.streamStatePreHandlerForVars()
handlerForVars := streamStatePreHandlerForVars(s)
if handlerForVars != nil {
streamHandlers = append(streamHandlers, handlerForVars)
}
@@ -381,7 +383,7 @@ func (s *NodeSchema) StatePreHandler(stream bool) compose.GraphAddNodeOpt {
return nil
}
func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string]any, *State] {
func statePreHandlerForVars(s *schema2.NodeSchema) compose.StatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variable, use the state's variableHandler to get the variables and set them to the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@@ -456,7 +458,7 @@ func (s *NodeSchema) statePreHandlerForVars() compose.StatePreHandler[map[string
}
}
func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// checkout the node's inputs, if it has any variables, get the variables and merge them with the input
var vars []*vo.FieldInfo
for _, input := range s.InputSources {
@@ -533,7 +535,7 @@ func (s *NodeSchema) streamStatePreHandlerForVars() compose.StreamStatePreHandle
}
}
func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamStatePreHandler[map[string]any, *State] {
func streamStatePreHandlerForStreamSources(s *schema2.NodeSchema) compose.StreamStatePreHandler[map[string]any, *State] {
// if it does not have source info, do not add this pre handler
if s.Configs == nil {
return nil
@@ -543,7 +545,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
case entity.NodeTypeVariableAggregator, entity.NodeTypeOutputEmitter:
return nil
case entity.NodeTypeExit:
terminatePlan := mustGetKey[vo.TerminatePlan]("TerminalPlan", s.Configs)
terminatePlan := s.Configs.(*exit.Config).TerminatePlan
if terminatePlan != vo.ReturnVariables {
return nil
}
@@ -551,7 +553,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
// all other node can only accept non-stream inputs, relying on Eino's automatically stream concatenation.
}
sourceInfo := getKeyOrZero[map[string]*nodes.SourceInfo]("FullSources", s.Configs)
sourceInfo := s.FullSources
if len(sourceInfo) == 0 {
return nil
}
@@ -566,10 +568,10 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
var (
anyStream bool
checker func(source *nodes.SourceInfo) bool
checker func(source *schema2.SourceInfo) bool
)
checker = func(source *nodes.SourceInfo) bool {
if source.FieldType != nodes.FieldNotStream {
checker = func(source *schema2.SourceInfo) bool {
if source.FieldType != schema2.FieldNotStream {
return true
}
for _, subSource := range source.SubSources {
@@ -594,8 +596,8 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
return func(ctx context.Context, in *schema.StreamReader[map[string]any], state *State) (*schema.StreamReader[map[string]any], error) {
resolved := map[string]resolvedStreamSource{}
var resolver func(source nodes.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source nodes.SourceInfo) (result *resolvedStreamSource, err error) {
var resolver func(source schema2.SourceInfo) (result *resolvedStreamSource, err error)
resolver = func(source schema2.SourceInfo) (result *resolvedStreamSource, err error) {
if source.IsIntermediate {
result = &resolvedStreamSource{
intermediate: true,
@@ -615,14 +617,14 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
streamType := source.FieldType
if streamType == nodes.FieldMaybeStream {
if streamType == schema2.FieldMaybeStream {
streamType, err = state.GetDynamicStreamType(source.FromNodeKey, source.FromPath[0])
if err != nil {
return nil, err
}
}
if streamType == nodes.FieldNotStream {
if streamType == schema2.FieldNotStream {
return nil, nil
}
@@ -690,7 +692,7 @@ func (s *NodeSchema) streamStatePreHandlerForStreamSources() compose.StreamState
}
}
func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
func statePostHandler(s *schema2.NodeSchema, stream bool) compose.GraphAddNodeOpt {
var (
handlers []compose.StatePostHandler[map[string]any, *State]
streamHandlers []compose.StreamStatePostHandler[map[string]any, *State]
@@ -702,7 +704,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.streamStatePostHandlerForVars()
forVars := streamStatePostHandlerForVars(s)
if forVars != nil {
streamHandlers = append(streamHandlers, forVars)
}
@@ -725,7 +727,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return out, nil
})
forVars := s.statePostHandlerForVars()
forVars := statePostHandlerForVars(s)
if forVars != nil {
handlers = append(handlers, forVars)
}
@@ -745,7 +747,7 @@ func (s *NodeSchema) StatePostHandler(stream bool) compose.GraphAddNodeOpt {
return compose.WithStatePostHandler(handler)
}
func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[string]any, *State] {
func statePostHandlerForVars(s *schema2.NodeSchema) compose.StatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variable,
// use the state's variableHandler to get the variables and set them to the output
var vars []*vo.FieldInfo
@@ -823,7 +825,7 @@ func (s *NodeSchema) statePostHandlerForVars() compose.StatePostHandler[map[stri
}
}
func (s *NodeSchema) streamStatePostHandlerForVars() compose.StreamStatePostHandler[map[string]any, *State] {
func streamStatePostHandlerForVars(s *schema2.NodeSchema) compose.StreamStatePostHandler[map[string]any, *State] {
// checkout the node's output sources, if it has any variables, get the variables and merge them with the output
var vars []*vo.FieldInfo
for _, output := range s.OutputSources {